diff --git a/areal/experimental/openai/client.py b/areal/experimental/openai/client.py index 2ebba72199..44ef605927 100644 --- a/areal/experimental/openai/client.py +++ b/areal/experimental/openai/client.py @@ -127,6 +127,61 @@ def _normalize(item: Any): return normalized +def _normalize_messages_for_chat_template(messages: list[dict[str, Any]]) -> None: + """Normalize messages in-place to align with SGLang's preprocessing before + ``apply_chat_template``. + + Two normalizations are applied (mirroring SGLang's ``serving_chat.py``): + + 1. **Content flattening**: Many chat templates (e.g. Qwen3) only accept + ``message.content`` as a plain string (``{%- if message.content is string %}``). + List-formatted content like ``[{"type": "text", "text": "..."}]`` is silently + discarded. SGLang's ``process_content_for_template_format()`` detects the + template format and flattens content accordingly. We flatten single-text-part + lists to a plain string so the template renders them correctly. + + 2. **tool_calls arguments parsing**: SGLang (``serving_chat.py`` L449-465) parses + ``tool_calls[].function.arguments`` from JSON string to dict before passing to + ``apply_chat_template``. Jinja2's ``tojson`` filter uses ``sort_keys=True`` by + default, so a raw string vs a parsed dict produces different key ordering and + therefore different token sequences. + """ + for msg in messages: + if not isinstance(msg, dict): + continue + + # --- Content flattening --- + content = msg.get("content") + if content is not None and not isinstance(content, str): + if not isinstance(content, list): + content = list(content) + parts = [] + for part in content: + if not isinstance(part, dict): + part = ( + dict(part) + if hasattr(part, "items") + else {"type": "text", "text": str(part)} + ) + if "text" in part and "type" not in part: + part["type"] = "text" + parts.append(part) + if len(parts) == 1 and parts[0].get("type") == "text": + msg["content"] = parts[0]["text"] + else: + msg["content"] = parts + + # --- tool_calls arguments parsing --- + if msg.get("role") == "assistant" and isinstance(msg.get("tool_calls"), list): + for tool_call in msg["tool_calls"]: + func = tool_call.get("function", tool_call) + if isinstance(func.get("arguments"), str): + try: + func["arguments"] = json.loads(func["arguments"]) + except (json.JSONDecodeError, TypeError): + pass + + def _find_kth(lst: list, target, k: int) -> int: def target_indices(): for i, char in enumerate(lst): @@ -385,6 +440,7 @@ def concat_prompt_token_ids_with_parent( all_message_list += message_list + _normalize_messages_for_chat_template(all_message_list) all_tokens = tokenizer.apply_chat_template( all_message_list, tools=tools, @@ -606,6 +662,7 @@ async def create( has_images = len(image_data) > 0 tokenizer_messages = messages_for_tokenizer if has_images else messages_list + _normalize_messages_for_chat_template(tokenizer_messages) if self.chat_template_type == "hf": prompt_token_ids = self.tokenizer.apply_chat_template( tokenizer_messages, @@ -1010,6 +1067,7 @@ async def create( has_images = len(image_data) > 0 tokenizer_messages = messages_for_tokenizer if has_images else messages_list + _normalize_messages_for_chat_template(tokenizer_messages) if self.chat_template_type == "hf": prompt_token_ids = self.tokenizer.apply_chat_template( tokenizer_messages, diff --git a/examples/experimental/inference_service/batchmode/README.md b/examples/experimental/inference_service/batchmode/README.md new file mode 100644 index 0000000000..f181d1a748 --- /dev/null +++ 
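b/examples/experimental/inference_service/batchmode/README.md

To make the effect of `_normalize_messages_for_chat_template` above concrete, here is a minimal sketch (the helper is private and imported here purely for illustration; the message values are made up):

```python
from areal.experimental.openai.client import _normalize_messages_for_chat_template

messages = [
    # List-form content: flattened to a plain string so templates like Qwen3's render it
    {"role": "user", "content": [{"type": "text", "text": "Book me a flight."}]},
    # tool_calls arguments: JSON string parsed to a dict, so Jinja2's `tojson`
    # (sort_keys=True) renders identical key order on client and server
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {"name": "search", "arguments": '{"to": "SFO", "from": "JFK"}'},
            }
        ],
    },
]

_normalize_messages_for_chat_template(messages)  # mutates in place

assert messages[0]["content"] == "Book me a flight."
assert messages[1]["tool_calls"][0]["function"]["arguments"] == {"to": "SFO", "from": "JFK"}
```
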
@@ -0,0 +1,224 @@
+# Inference Service Benchmark (Target Experiment)
+
+Measures AReaL inference service full-stack overhead on TAU²-bench agent tasks.
+
+```
+Request path:
+  OpenClaw CLI → IS Gateway (:30098) → Router (:8081) → DataProxy (:8082) → ArealOpenAI Client → SGLang /generate (:30000)
+
+Two SGLang instances:
+  Agent SGLang (port 30000) — benchmark target, --disable-radix-cache, TP=8
+  User SGLang  (port 30001) — simulates user, NOT measured, TP=8
+```
+
+## Prerequisites
+
+| Dependency            | Description                                                              |
+| --------------------- | ------------------------------------------------------------------------ |
+| Singularity container | AReaL dev image with SGLang, PyTorch, etc.                               |
+| Model weights         | Qwen3-235B-A22B-Instruct-2507 (local path)                               |
+| tau2-bench            | pip-installable TAU²-bench source                                        |
+| openclaw-benchmark    | pip-installable OpenClaw TAU² integration package                        |
+| Slurm cluster         | 2 nodes × 8 GPUs (Agent + User SGLang); IS + sweep run on the Agent node |
+
+## Step 1: Start SGLang Servers
+
+Edit `start_servers.sh` to set your container image, model path, and log directory,
+then:
+
+```bash
+bash start_servers.sh
+```
+
+This submits two Slurm jobs (Agent + User SGLang). Wait for both to start:
+
+```bash
+squeue -u $(whoami)
+# Note the node names from the NODELIST column, e.g.:
+#   Agent → node-A (port 30000)
+#   User  → node-B (port 30001)
+```
+
+Verify servers are healthy:
+
+```bash
+curl -sf http://<node-A>:30000/v1/models | python3 -c "import sys,json; print(json.load(sys.stdin)['data'][0]['id'])"
+curl -sf http://<node-B>:30001/v1/models | python3 -c "import sys,json; print(json.load(sys.stdin)['data'][0]['id'])"
+```
+
+## Step 2: Configure sweep.sh
+
+Edit the top of `sweep.sh` to match your environment:
+
+```bash
+# ── Paths (MUST update) ──
+CONTAINER=""
+PROJECT=""
+AREAL_PATCH=""
+MODEL_PATH=""
+
+# ── Endpoints (MUST update) ──
+SGLANG_PORT=30000                          # Agent SGLang, must be on same node as IS
+USER_ENDPOINT="http://<node-B>:30001/v1"   # User SGLang node from Step 1
+```
+
+## Step 3: Run the Sweep
+
+SSH into the **Agent SGLang node** (IS processes must co-locate with Agent SGLang on
+localhost):
+
+```bash
+ssh <node-A>
+```
+
+Run the sweep (arguments are positional; see the table below):
+
+```bash
+bash sweep.sh "5,10,15,20,25,30" 50 4 reproduce
+```
+
+| Argument | Default                  | Description                            |
+| -------- | ------------------------ | -------------------------------------- |
+| `$1`     | `5,10,15,20,25,30`       | Comma-separated concurrency levels     |
+| `$2`     | `50`                     | Number of TAU²-bench tasks per trial   |
+| `$3`     | `4`                      | Number of trials per concurrency level |
+| `$4`     | `$(date +%Y%m%d_%H%M%S)` | Tag for output directory               |
+
+The script automatically:
+
+1. Enters Singularity container
+1. Installs dependencies (openclaw-benchmark, tau2-bench)
+1. Patches IS code into container's AReaL installation
+1. Starts Router → DataProxy → Gateway (registers DataProxy with Router)
+1. Runs `collect_trajectories.py` for each (concurrency, trial) combination
+1. Prints summary table on completion
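+
+For debugging, the per-task session lifecycle that `collect_trajectories.py` drives can
+also be exercised by hand. A minimal sketch (endpoints and the `dummy:0` admin key match
+`sweep.sh`; the task id and reward are illustrative):
+
+```python
+import httpx
+
+GATEWAY = "http://127.0.0.1:30098"           # IS Gateway, co-located with Agent SGLang
+ADMIN = {"Authorization": "Bearer dummy:0"}  # --admin-api-key from sweep.sh
+
+with httpx.Client(timeout=600.0) as client:
+    # 1. Reserve capacity, then open an RL session for one task
+    client.post(f"{GATEWAY}/grant_capacity", headers=ADMIN).raise_for_status()
+    sess = client.post(
+        f"{GATEWAY}/rl/start_session", json={"task_id": "demo-task"}, headers=ADMIN
+    ).json()
+
+    # 2. The agent (OpenClaw CLI) would now run against the Gateway using
+    #    sess["api_key"]; afterwards, attach the task reward to the session.
+    client.post(
+        f"{GATEWAY}/rl/set_reward",
+        json={"reward": 1.0},
+        headers={"Authorization": f"Bearer {sess['api_key']}"},
+    ).raise_for_status()
+
+    # 3. Export the recorded trajectory (with logprobs) for RL training
+    traj = client.post(
+        f"{GATEWAY}/export_trajectories",
+        json={"session_id": sess["session_id"]},
+        headers=ADMIN,
+    ).json()
+```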
+
+## What Happens Inside
+
+```
+┌─ Singularity Container ──────────────────────────────────────────────┐
+│                                                                      │
+│  Router (:8081) ←─ register ─ DataProxy (:8082) ←─ SGLang (:30000)
+│       ↑                            ↑ (localhost)
+│       │                            │
+│  Gateway (:30098)                  │
+│       ↑                       ArealOpenAI
+│       │                  (tokenize → /generate)
+│  collect_trajectories.py
+│       │
+│  N × OpenClaw CLI (subprocess)
+│       │
+│  worker.py (per task)
+│       │
+│  tau2-bench orchestrator + evaluator
+│       │
+│  User SGLang (remote, :30001) ────────────────────────── External Node
+│                                                                      │
+└──────────────────────────────────────────────────────────────────────┘
+```
+
+Per-task flow:
+
+1. `collect_trajectories.py` calls `POST /grant_capacity` then `POST /rl/start_session`
+   → gets session API key
+1. Spawns `worker.py` subprocess with OpenClaw CLI pointed at Gateway
+1. OpenClaw runs TAU²-bench task (multi-turn: agent calls tools via Gateway, user sim
+   via remote SGLang)
+1. On completion, calls `POST /rl/set_reward` with task reward
+1. Calls `POST /export_trajectories` → saves trajectory JSON to disk
+
+## Output
+
+Results are saved to `$PROJECT/trajectories/sweep_<tag>/`:
+
+```
+sweep_<tag>/
+├── c5/
+│   ├── trial_1/
+│   │   ├── collection_summary.json   # pass rate, wall clock, tasks/min
+│   │   ├── task_0_session_0-0.json   # per-task trajectory
+│   │   └── ...
+│   ├── trial_2/
+│   └── ...
+├── c10/
+└── ...
+```
+
+`collection_summary.json` fields:
+
+| Field               | Description                 |
+| ------------------- | --------------------------- |
+| `completed`         | Total tasks finished        |
+| `passed` / `failed` | Tasks with reward > 0 / = 0 |
+| `errors`            | Tasks that hit errors       |
+| `pass_rate`         | passed / completed          |
+| `total_time_s`      | Wall clock seconds          |
+| `tasks_per_min`     | Throughput                  |
+
+## Configuration Reference
+
+### Benchmark Parameters (in collect_trajectories.py)
+
+| Parameter            | Value     | Description                   |
+| -------------------- | --------- | ----------------------------- |
+| `--domain`           | `airline` | TAU²-bench domain             |
+| `--num-tasks`        | `50`      | Tasks per trial               |
+| `--max-steps`        | `200`     | Max agent turns per task      |
+| `--max-errors`       | `10`      | Max errors before abort       |
+| `--seed`             | `300`     | Random seed for task ordering |
+| `--openclaw-timeout` | `3000`    | Subprocess timeout (seconds)  |
+
+### IS Component Ports
+
+| Component    | Port  | Flag                 |
+| ------------ | ----- | -------------------- |
+| Router       | 8081  | `--port`             |
+| DataProxy    | 8082  | `--port`             |
+| Gateway      | 30098 | `--port`             |
+| Agent SGLang | 30000 | Must be on localhost |
+
+### SGLang Flags
+
+| Flag                      | Agent | User | Reason                                       |
+| ------------------------- | ----- | ---- | -------------------------------------------- |
+| `--disable-radix-cache`   | ✅    | ❌   | Consistent no-cache for IS benchmark         |
+| `--tool-call-parser`      | ✅    | ✅   | Model-specific, e.g. `qwen25` for Qwen3      |
+| `--enable-metrics`        | ✅    | ✅   | Prometheus endpoint for `collect_metrics.py` |
+| `--context-length 262144` | ✅    | ✅   | Qwen3-235B max context                       |
+| `--tp 8`                  | ✅    | ✅   | Tensor parallelism across 8 GPUs             |
+
+## Reference Results: Qwen3-235B-A22B-Instruct-2507
+
+Tested on TAU²-bench airline domain, 50 tasks × 4 trials per concurrency, 2 nodes ×
+8×H200 GPUs.
+
+> Results below are from a single experiment run. Exact numbers may vary slightly across
+> runs due to non-determinism in concurrent GPU inference and system scheduling.
+
+### Baseline B (OpenClaw → SGLang direct) vs Target (OpenClaw → IS → SGLang)
+
+Each cell: `Baseline B / Target (Δ)`.
+ +| Metric | c=5 | c=10 | c=15 | c=20 | c=25 | c=30 | +| ----------- | --------------------- | --------------------- | --------------------- | --------------------- | --------------------- | --------------------- | +| Pass@1 | 38% / 30% (-8pp) | 38% / 38% (+1pp) | 34% / 32% (-2pp) | 36% / 34% (-2pp) | 34% / 34% (0pp) | 36% / 38% (+3pp) | +| Avg E2E (s) | 4.82 / 4.57 (-5%) | 8.99 / 8.57 (-5%) | 13.05 / 12.58 (-4%) | 17.03 / 16.33 (-4%) | 20.73 / 20.16 (-3%) | 24.72 / 23.41 (-5%) | +| Input Tok/s | 15,207 / 16,017 (+5%) | 18,204 / 19,204 (+5%) | 18,820 / 20,138 (+7%) | 19,281 / 19,388 (+1%) | 19,433 / 20,474 (+5%) | 19,480 / 20,780 (+7%) | +| Req/s | 0.69 / 0.72 (+4%) | 0.82 / 0.87 (+6%) | 0.85 / 0.90 (+6%) | 0.87 / 0.87 (0%) | 0.87 / 0.93 (+7%) | 0.89 / 0.95 (+7%) | +| Tasks/min | 2.5 / 2.6 (+6%) | 3.0 / 3.2 (+5%) | 3.1 / 3.1 (+1%) | 3.0 / 3.0 (0%) | 3.0 / 3.4 (+13%) | 3.4 / 3.7 (+8%) | + +### Target SGLang Metrics (per concurrency) + +| Metric | c=5 | c=10 | c=15 | c=20 | c=25 | c=30 | +| ------------- | ------ | ------ | ------ | ------ | ------ | ------ | +| Input Tok/s | 16,017 | 19,204 | 20,138 | 19,388 | 20,474 | 20,780 | +| Output Tok/s | 75 | 88 | 89 | 88 | 97 | 93 | +| Avg E2E (s) | 4.57 | 8.57 | 12.58 | 16.33 | 20.16 | 23.41 | +| Avg TTFT (s) | 2.76 | 4.85 | 7.08 | 9.13 | 11.41 | 14.18 | +| Avg Queue (s) | 0.31 | 0.77 | 1.34 | 1.92 | 2.70 | 3.58 | +| Total Reqs | 3,318 | 3,332 | 3,492 | 3,539 | 3,309 | 3,162 | +| Avg InTok/Req | 22,099 | 22,040 | 22,283 | 22,324 | 22,132 | 21,766 | diff --git a/examples/experimental/inference_service/batchmode/collect_metrics.py b/examples/experimental/inference_service/batchmode/collect_metrics.py new file mode 100644 index 0000000000..606830e834 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/collect_metrics.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +"""Snapshot and diff SGLang Prometheus metrics. 
+
+Usage:
+    # Snapshot current metrics
+    python collect_metrics.py snapshot http://127.0.0.1:30000
+
+    # Diff two snapshot files → JSON with deltas + derived throughput
+    python collect_metrics.py diff pre.json post.json --wall-clock 120
+
+    # Poll /metrics into a CSV, then summarize it
+    python collect_metrics.py monitor http://127.0.0.1:30000 -o run.csv
+    python collect_metrics.py analyze run.csv
+"""
+
+import argparse
+import json
+import re
+import sys
+import urllib.request
+
+COUNTER_KEYS = [
+    "sglang:prompt_tokens_total",
+    "sglang:generation_tokens_total",
+    "sglang:num_requests_total",
+    "sglang:e2e_request_latency_seconds_sum",
+    "sglang:e2e_request_latency_seconds_count",
+    "sglang:time_to_first_token_seconds_sum",
+    "sglang:time_to_first_token_seconds_count",
+    "sglang:queue_time_seconds_sum",
+    "sglang:queue_time_seconds_count",
+]
+
+
+def fetch_metrics(base_url: str) -> dict:
+    url = base_url.rstrip("/")
+    if "/v1" in url:
+        url = url.split("/v1")[0]
+    url = url + "/metrics"
+
+    req = urllib.request.Request(url, method="GET")
+    with urllib.request.urlopen(req, timeout=10) as resp:
+        text = resp.read().decode()
+
+    metrics = {}
+    for line in text.splitlines():
+        if line.startswith("#") or not line.strip():
+            continue
+        m = re.match(r"([\w:]+)(?:\{([^}]*)\})?\s+([\d.eE+\-]+)", line)
+        if not m:
+            continue
+        name, labels, value = m.group(1), m.group(2) or "", float(m.group(3))
+        key = f"{name}{{{labels}}}" if labels else name
+        metrics[key] = value
+
+    return metrics
+
+
+def extract_scalar(metrics: dict, prefix: str) -> float:
+    for key, val in metrics.items():
+        if key.startswith(prefix) and "{" not in key:
+            return val
+        if key.startswith(prefix + "{"):
+            return val
+    return 0.0
+
+
+def snapshot(base_url: str) -> dict:
+    raw = fetch_metrics(base_url)
+    result = {}
+    for k in COUNTER_KEYS:
+        result[k] = extract_scalar(raw, k)
+    return result
+
+
+def diff(pre: dict, post: dict, wall_clock: float) -> dict:
+    delta = {}
+    for k in COUNTER_KEYS:
+        delta[k] = post.get(k, 0) - pre.get(k, 0)
+
+    prompt_tokens = delta.get("sglang:prompt_tokens_total", 0)
+    gen_tokens = delta.get("sglang:generation_tokens_total", 0)
+    total_tokens = prompt_tokens + gen_tokens
+    n_requests = delta.get("sglang:num_requests_total", 0)
+    e2e_sum = delta.get("sglang:e2e_request_latency_seconds_sum", 0)
+    e2e_count = delta.get("sglang:e2e_request_latency_seconds_count", 0)
+    ttft_sum = delta.get("sglang:time_to_first_token_seconds_sum", 0)
+    ttft_count = delta.get("sglang:time_to_first_token_seconds_count", 0)
+    queue_sum = delta.get("sglang:queue_time_seconds_sum", 0)
+    queue_count = delta.get("sglang:queue_time_seconds_count", 0)
+
+    return {
+        "prompt_tokens": int(prompt_tokens),
+        "generation_tokens": int(gen_tokens),
+        "total_tokens": int(total_tokens),
+        "num_requests": int(n_requests),
+        "wall_clock_seconds": round(wall_clock, 1),
+        "input_throughput_tok_per_sec": round(prompt_tokens / wall_clock, 1)
+        if wall_clock > 0
+        else 0,
+        "output_throughput_tok_per_sec": round(gen_tokens / wall_clock, 1)
+        if wall_clock > 0
+        else 0,
+        "total_throughput_tok_per_sec": round(total_tokens / wall_clock, 1)
+        if wall_clock > 0
+        else 0,
+        "request_throughput_per_sec": round(n_requests / wall_clock, 2)
+        if wall_clock > 0
+        else 0,
+        "avg_prompt_tokens_per_req": round(prompt_tokens / n_requests, 1)
+        if n_requests > 0
+        else 0,
+        "avg_gen_tokens_per_req": round(gen_tokens / n_requests, 1)
+        if n_requests > 0
+        else 0,
+        "avg_e2e_latency_seconds": round(e2e_sum / e2e_count, 2)
+        if e2e_count > 0
+        else 0,
+        "avg_ttft_seconds": round(ttft_sum / ttft_count, 3) if ttft_count > 0 else 0,
+        "avg_queue_time_seconds": round(queue_sum / queue_count, 3)
+        if
queue_count > 0 + else 0, + "total_llm_time_seconds": round(e2e_sum, 1), + } + + +def monitor(base_url: str, interval: float, output_file: str): + """Poll SGLang /metrics every `interval` seconds, write CSV for peak throughput analysis.""" + import csv + import signal + import time + + fieldnames = [ + "timestamp", + "elapsed", + "prompt_tokens", + "generation_tokens", + "num_requests", + "running_reqs", + "queue_reqs", + "input_tok_per_sec", + "output_tok_per_sec", + ] + + stop = [False] + + def _handle(sig, frame): + stop[0] = True + + signal.signal(signal.SIGTERM, _handle) + signal.signal(signal.SIGINT, _handle) + + prev = None + prev_time = None + t0 = time.time() + + with open(output_file, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + + while not stop[0]: + try: + raw = fetch_metrics(base_url) + now = time.time() + cur_prompt = extract_scalar(raw, "sglang:prompt_tokens_total") + cur_gen = extract_scalar(raw, "sglang:generation_tokens_total") + cur_reqs = extract_scalar(raw, "sglang:num_requests_total") + running = extract_scalar(raw, "sglang:num_running_reqs") + queue = extract_scalar(raw, "sglang:num_queue_reqs") + + input_rate = 0.0 + output_rate = 0.0 + if prev is not None and prev_time is not None: + dt = now - prev_time + if dt > 0: + input_rate = (cur_prompt - prev["prompt_tokens"]) / dt + output_rate = (cur_gen - prev["generation_tokens"]) / dt + + row = { + "timestamp": round(now, 2), + "elapsed": round(now - t0, 1), + "prompt_tokens": int(cur_prompt), + "generation_tokens": int(cur_gen), + "num_requests": int(cur_reqs), + "running_reqs": int(running), + "queue_reqs": int(queue), + "input_tok_per_sec": round(input_rate, 1), + "output_tok_per_sec": round(output_rate, 1), + } + writer.writerow(row) + f.flush() + + prev = {"prompt_tokens": cur_prompt, "generation_tokens": cur_gen} + prev_time = now + except Exception: + pass + + time.sleep(interval) + + +def percentile(sorted_vals: list, p: float) -> float: + """Compute p-th percentile from a sorted list (0 <= p <= 100).""" + if not sorted_vals: + return 0.0 + k = (len(sorted_vals) - 1) * p / 100.0 + f = int(k) + c = f + 1 if f + 1 < len(sorted_vals) else f + d = k - f + return sorted_vals[f] + d * (sorted_vals[c] - sorted_vals[f]) + + +def analyze_monitor_csv(csv_path: str) -> dict: + """Analyze monitor CSV to extract throughput stats and concurrency distribution.""" + import csv as csvmod + + rows = [] + with open(csv_path) as f: + for r in csvmod.DictReader(f): + rows.append({k: float(v) for k, v in r.items()}) + + if len(rows) < 2: + return {} + + input_rates = [r["input_tok_per_sec"] for r in rows if r["input_tok_per_sec"] > 0] + output_rates = [ + r["output_tok_per_sec"] for r in rows if r["output_tok_per_sec"] > 0 + ] + running = [r["running_reqs"] for r in rows] + queued = [r["queue_reqs"] for r in rows] + + total_prompt = rows[-1]["prompt_tokens"] - rows[0]["prompt_tokens"] + total_gen = rows[-1]["generation_tokens"] - rows[0]["generation_tokens"] + total_reqs = rows[-1]["num_requests"] - rows[0]["num_requests"] + total_time = rows[-1]["elapsed"] - rows[0]["elapsed"] + + sorted_running = sorted(running) + sorted_queued = sorted(queued) + sorted_input = sorted(input_rates) if input_rates else [] + + return { + "monitor_duration_seconds": round(total_time, 1), + "monitor_samples": len(rows), + "total_prompt_tokens": int(total_prompt), + "total_generation_tokens": int(total_gen), + "total_requests": int(total_reqs), + # Throughput + "avg_input_throughput_tok_per_sec": 
round(total_prompt / total_time, 1) + if total_time > 0 + else 0, + "avg_output_throughput_tok_per_sec": round(total_gen / total_time, 1) + if total_time > 0 + else 0, + "peak_input_throughput_tok_per_sec": round(max(input_rates), 1) + if input_rates + else 0, + "peak_output_throughput_tok_per_sec": round(max(output_rates), 1) + if output_rates + else 0, + "p50_input_throughput_tok_per_sec": round(percentile(sorted_input, 50), 1) + if sorted_input + else 0, + "p95_input_throughput_tok_per_sec": round(percentile(sorted_input, 95), 1) + if sorted_input + else 0, + "p99_input_throughput_tok_per_sec": round(percentile(sorted_input, 99), 1) + if sorted_input + else 0, + # Concurrency distribution (running requests at SGLang) + "avg_running_reqs": round(sum(running) / len(running), 1) if running else 0, + "max_running_reqs": int(max(running)) if running else 0, + "min_running_reqs": int(min(running)) if running else 0, + "p50_running_reqs": round(percentile(sorted_running, 50), 1), + "p75_running_reqs": round(percentile(sorted_running, 75), 1), + "p95_running_reqs": round(percentile(sorted_running, 95), 1), + "p99_running_reqs": round(percentile(sorted_running, 99), 1), + # Queue distribution + "avg_queue_reqs": round(sum(queued) / len(queued), 1) if queued else 0, + "max_queue_reqs": int(max(queued)) if queued else 0, + "p50_queue_reqs": round(percentile(sorted_queued, 50), 1), + "p95_queue_reqs": round(percentile(sorted_queued, 95), 1), + "p99_queue_reqs": round(percentile(sorted_queued, 99), 1), + } + + +def main(): + parser = argparse.ArgumentParser() + sub = parser.add_subparsers(dest="cmd") + + snap_p = sub.add_parser("snapshot") + snap_p.add_argument("url", help="SGLang base URL (e.g. http://127.0.0.1:30000)") + snap_p.add_argument("-o", "--output", help="Output file (default: stdout)") + + diff_p = sub.add_parser("diff") + diff_p.add_argument("pre", help="Pre-snapshot JSON") + diff_p.add_argument("post", help="Post-snapshot JSON") + diff_p.add_argument( + "--wall-clock", type=float, required=True, help="Wall clock seconds" + ) + diff_p.add_argument("-o", "--output", help="Output file (default: stdout)") + + mon_p = sub.add_parser("monitor") + mon_p.add_argument("url", help="SGLang base URL") + mon_p.add_argument( + "-i", "--interval", type=float, default=5.0, help="Poll interval (default: 5s)" + ) + mon_p.add_argument("-o", "--output", required=True, help="Output CSV file") + + ana_p = sub.add_parser("analyze") + ana_p.add_argument("csv", help="Monitor CSV file") + ana_p.add_argument("-o", "--output", help="Output JSON file (default: stdout)") + + args = parser.parse_args() + + if args.cmd == "snapshot": + result = snapshot(args.url) + out = json.dumps(result, indent=2) + if args.output: + with open(args.output, "w") as f: + f.write(out) + else: + print(out) + + elif args.cmd == "diff": + with open(args.pre) as f: + pre = json.load(f) + with open(args.post) as f: + post = json.load(f) + result = diff(pre, post, args.wall_clock) + out = json.dumps(result, indent=2) + if args.output: + with open(args.output, "w") as f: + f.write(out) + else: + print(out) + + elif args.cmd == "monitor": + monitor(args.url, args.interval, args.output) + + elif args.cmd == "analyze": + result = analyze_monitor_csv(args.csv) + out = json.dumps(result, indent=2) + if args.output: + with open(args.output, "w") as f: + f.write(out) + else: + print(out) + + else: + parser.print_help() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git 
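a/examples/experimental/inference_service/batchmode/collect_trajectories.py b/examples/experimental/inference_service/batchmode/collect_trajectories.py
new file mode 100644
index 0000000000..77dcb91d34
--- /dev/null
+++ b/examples/experimental/inference_service/batchmode/collect_trajectories.py

The `snapshot()`/`diff()` pair shown above can also be driven from Python to bracket a
benchmark window. A sketch (the sleep stands in for the workload; the URL assumes the
Agent SGLang server on localhost):

```python
import time

from collect_metrics import diff, snapshot  # the tool above, importable from its directory

AGENT = "http://127.0.0.1:30000"

pre = snapshot(AGENT)            # counter values before the run
t0 = time.time()
time.sleep(120)                  # ... drive the workload here instead ...
post = snapshot(AGENT)           # counter values after the run

# Deltas plus derived rates: input/output tok/s, req/s, avg E2E / TTFT / queue time
print(diff(pre, post, wall_clock=time.time() - t0))
```
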
@@ -0,0 +1,478 @@
+#!/usr/bin/env python3
+"""
+Online Trajectory Collector
+
+Runs tau2-bench tasks through AReaL Inference Service, collecting trajectories
+with logprobs for RL training. Each concurrent worker gets its own IS session.
+
+Flow per worker:
+    1. POST /rl/start_session → session_api_key
+    2. Run tau2 task (OpenClaw CLI with dynamic api_key → IS Gateway → DataProxy → SGLang)
+    3. POST /rl/set_reward {reward}
+    4. POST /export_trajectories → save trajectory to disk
+
+Usage:
+    python collect_trajectories.py \
+        --gateway-url http://127.0.0.1:30098 \
+        --admin-api-key "dummy:0" \
+        --user-endpoint http://<user-node>:30001/v1 \
+        --domain airline \
+        --concurrency 5 \
+        --num-tasks 50 \
+        --max-steps 200 \
+        --output-dir /storage/.../trajectories
+"""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+import httpx
+
+try:
+    from tau2.data_model.tasks import Task  # noqa: F401
+    from tau2.evaluator.evaluator import EvaluationType  # noqa: F401
+    from tau2.run import load_tasks
+except ImportError:
+    print("ERROR: tau2 not installed. Run: pip install -e /path/to/tau2-bench")
+    sys.exit(1)
+
+WORKER_SCRIPT = str(Path(__file__).resolve().parent / "worker.py")
+
+
+async def grant_capacity(
+    client: httpx.AsyncClient,
+    gateway_url: str,
+    admin_key: str,
+) -> None:
+    resp = await client.post(
+        f"{gateway_url}/grant_capacity",
+        headers={"Authorization": f"Bearer {admin_key}"},
+    )
+    resp.raise_for_status()
+
+
+async def start_session(
+    client: httpx.AsyncClient,
+    gateway_url: str,
+    admin_key: str,
+    task_id: str,
+) -> dict:
+    await grant_capacity(client, gateway_url, admin_key)
+    resp = await client.post(
+        f"{gateway_url}/rl/start_session",
+        json={"task_id": task_id},
+        headers={"Authorization": f"Bearer {admin_key}"},
+    )
+    resp.raise_for_status()
+    return resp.json()
+
+
+async def set_reward(
+    client: httpx.AsyncClient,
+    gateway_url: str,
+    session_api_key: str,
+    reward: float,
+) -> dict:
+    resp = await client.post(
+        f"{gateway_url}/rl/set_reward",
+        json={"reward": reward},
+        headers={"Authorization": f"Bearer {session_api_key}"},
+    )
+    resp.raise_for_status()
+    return resp.json()
+
+
+async def export_trajectories(
+    client: httpx.AsyncClient,
+    gateway_url: str,
+    admin_key: str,
+    session_id: str,
+) -> dict:
+    resp = await client.post(
+        f"{gateway_url}/export_trajectories",
+        json={"session_id": session_id},
+        headers={"Authorization": f"Bearer {admin_key}"},
+    )
+    resp.raise_for_status()
+    return resp.json()
+
+
+def _ensure_agent_home(agent_home: Path, gateway_url: str, api_key: str, model: str):
+    oc_dir = agent_home / ".openclaw"
+    (oc_dir / "workspace").mkdir(parents=True, exist_ok=True)
+    (oc_dir / "agents" / "main" / "agent").mkdir(parents=True, exist_ok=True)
+    oc_config = {
+        "models": {
+            "providers": {
+                "sglang": {
+                    "baseUrl": gateway_url,
+                    "apiKey": api_key,
+                    "api": "openai-completions",
+                    "models": [{"id": model, "name": model}],
+                }
+            }
+        }
+    }
+    with open(oc_dir / "openclaw.json", "w") as f:
+        json.dump(oc_config, f, indent=2)
+
+
+def run_tau2_task_subprocess(
+    domain: str,
+    task_index: int,
+    gateway_url: str,
+    session_api_key: str,
+
user_endpoint: str,
+    model_name: str,
+    max_steps: int,
+    max_errors: int,
+    seed: int | None,
+    openclaw_cli: str,
+    openclaw_timeout: int,
+    worker_id: int,
+    work_dir: Path,
+) -> dict:
+    """Run one tau2 task in a `worker.py` subprocess.
+
+    HOME is set via env prefix so each worker gets an isolated OpenClaw home.
+    """
+    agent_home = work_dir / f"agent_{worker_id}"
+    results_dir = work_dir / f"results_{worker_id}"
+    log_file = work_dir / f"worker_{worker_id}_task_{task_index}.log"
+    results_dir.mkdir(parents=True, exist_ok=True)
+
+    env = os.environ.copy()
+    env["HOME"] = str(agent_home)
+    env["OPENCLAW_CLI_COMMAND"] = openclaw_cli
+    env["OPENCLAW_API_BASE"] = gateway_url
+    env["OPENCLAW_API_KEY"] = session_api_key
+    env["OPENCLAW_MODEL"] = model_name
+    env["OPENCLAW_TIMEOUT"] = str(openclaw_timeout)
+    env["OPENAI_API_BASE"] = user_endpoint
+    env["OPENAI_API_KEY"] = "dummy"
+
+    cmd = [
+        sys.executable,
+        WORKER_SCRIPT,
+        "--domain",
+        domain,
+        "--task-index",
+        str(task_index),
+        "--agent-endpoint",
+        gateway_url,
+        "--user-endpoint",
+        user_endpoint,
+        "--model",
+        model_name,
+        "--user-llm",
+        f"openai/{model_name}",
+        "--max-steps",
+        str(max_steps),
+        "--max-errors",
+        str(max_errors),
+        "--output-dir",
+        str(results_dir),
+        "--worker-id",
+        str(worker_id),
+        "--user-llm-args",
+        json.dumps({"temperature": 0.0}),
+    ]
+    if seed is not None:
+        cmd.extend(["--seed", str(seed)])
+
+    with open(log_file, "a") as lf:
+        result = subprocess.run(
+            cmd,
+            capture_output=True,
+            text=True,
+            timeout=openclaw_timeout + 120,
+            env=env,
+        )
+        lf.write(result.stderr)
+
+    if not result.stdout.strip():
+        raise RuntimeError(
+            f"Worker-{worker_id} returned no stdout (exit={result.returncode}). "
+            f"Check {log_file}"
+        )
+
+    return json.loads(result.stdout.strip())
+
+
+async def worker(
+    worker_id: int,
+    task_queue: asyncio.Queue,
+    gateway_url: str,
+    admin_key: str,
+    user_endpoint: str,
+    model_name: str,
+    domain: str,
+    max_steps: int,
+    max_errors: int,
+    seed: int | None,
+    openclaw_cli: str,
+    openclaw_timeout: int,
+    output_dir: Path,
+    work_dir: Path,
+    results: list,
+):
+    async with httpx.AsyncClient(timeout=httpx.Timeout(3600.0)) as client:
+        while True:
+            try:
+                task_idx, task = task_queue.get_nowait()
+            except asyncio.QueueEmpty:
+                break
+
+            task_start = time.time()
+            session_info = None
+
+            try:
+                # 1.
Start IS session + print( + f" [worker-{worker_id}] task={task.id} (idx={task_idx}) → start_session" + ) + session_info = await start_session( + client, gateway_url, admin_key, task.id + ) + session_api_key = session_info["api_key"] + session_id = session_info["session_id"] + print(f" [worker-{worker_id}] session={session_id[:12]}...") + + _ensure_agent_home( + work_dir / f"agent_{worker_id}", + gateway_url, + session_api_key, + model_name, + ) + + loop = asyncio.get_event_loop() + task_result = await loop.run_in_executor( + None, + run_tau2_task_subprocess, + domain, + task_idx, + gateway_url, + session_api_key, + user_endpoint, + model_name, + max_steps, + max_errors, + seed, + openclaw_cli, + openclaw_timeout, + worker_id, + work_dir, + ) + + reward = task_result.get("reward", 0.0) + + print( + f" [worker-{worker_id}] task={task.id} reward={reward} → set_reward" + ) + sr_resp = await set_reward(client, gateway_url, session_api_key, reward) + print( + f" [worker-{worker_id}] set_reward → {sr_resp.get('message', 'ok')}" + ) + + print(f" [worker-{worker_id}] task={task.id} → export_trajectories") + trajectory_data = await export_trajectories( + client, gateway_url, admin_key, session_id + ) + + traj_file = output_dir / f"task_{task.id}_session_{session_id}.json" + with open(traj_file, "w") as f: + json.dump( + { + "task_id": task.id, + "session_id": session_id, + "reward": reward, + "duration": time.time() - task_start, + "num_turns": task_result.get("num_steps", 0), + "termination_reason": task_result.get("termination_reason"), + "trajectory": trajectory_data, + }, + f, + indent=2, + default=str, + ) + + symbol = "✔" if reward > 0 else "✘" + elapsed = time.time() - task_start + print( + f" {symbol} [worker-{worker_id}] task={task.id} reward={reward} dur={elapsed:.1f}s saved={traj_file.name}" + ) + + results.append( + { + "task_id": task.id, + "session_id": session_id, + "reward": reward, + "duration": elapsed, + "status": "ok", + } + ) + + except Exception as e: + elapsed = time.time() - task_start + print( + f" ✗ [worker-{worker_id}] task={task.id} ERROR: {e} dur={elapsed:.1f}s" + ) + results.append( + { + "task_id": task.id, + "session_id": session_info["session_id"] + if session_info + else None, + "reward": 0.0, + "duration": elapsed, + "status": f"error: {e}", + } + ) + + finally: + task_queue.task_done() + + +async def run_collection(args): + """Main collection loop.""" + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + work_dir = output_dir / "workdir" + work_dir.mkdir(parents=True, exist_ok=True) + + # Load tau2 tasks + print(f"Loading {args.domain} tasks...") + tasks = load_tasks(task_set_name=args.domain) + if args.num_tasks != "all": + tasks = tasks[: int(args.num_tasks)] + print(f" {len(tasks)} tasks loaded") + + # Build task queue + task_queue: asyncio.Queue = asyncio.Queue() + for idx, task in enumerate(tasks): + task_queue.put_nowait((idx, task)) + + results: list = [] + + print("\nStarting trajectory collection:") + print(f" Gateway: {args.gateway_url}") + print(f" Domain: {args.domain}") + print(f" Tasks: {len(tasks)}") + print(f" Concurrency: {args.concurrency}") + print(f" Output: {output_dir}") + print() + + start_time = time.time() + + # Launch N workers + workers = [ + asyncio.create_task( + worker( + worker_id=i, + task_queue=task_queue, + gateway_url=args.gateway_url, + admin_key=args.admin_api_key, + user_endpoint=args.user_endpoint, + model_name=args.model, + domain=args.domain, + max_steps=args.max_steps, + 
max_errors=args.max_errors, + seed=args.seed, + openclaw_cli=args.openclaw_cli, + openclaw_timeout=args.openclaw_timeout, + output_dir=output_dir, + work_dir=work_dir, + results=results, + ) + ) + for i in range(args.concurrency) + ] + + await asyncio.gather(*workers) + + total_time = time.time() - start_time + n_pass = sum(1 for r in results if r["reward"] > 0) + n_fail = sum(1 for r in results if r["reward"] == 0 and r["status"] == "ok") + n_error = sum(1 for r in results if "error" in r["status"]) + + # Save summary + summary = { + "domain": args.domain, + "concurrency": args.concurrency, + "total_tasks": len(tasks), + "completed": len(results), + "passed": n_pass, + "failed": n_fail, + "errors": n_error, + "pass_rate": n_pass / max(len(results), 1), + "total_time_s": total_time, + "tasks_per_min": len(results) / (total_time / 60) if total_time > 0 else 0, + "avg_duration_s": sum(r["duration"] for r in results) / max(len(results), 1), + "results": results, + } + + summary_file = output_dir / "collection_summary.json" + with open(summary_file, "w") as f: + json.dump(summary, f, indent=2) + + print(f"\n{'=' * 60}") + print("Collection Complete") + print(f"{'=' * 60}") + print(f" Tasks: {len(results)}/{len(tasks)}") + print(f" Pass: {n_pass} ({summary['pass_rate']:.1%})") + print(f" Fail: {n_fail}") + print(f" Error: {n_error}") + print(f" Time: {total_time:.0f}s ({summary['tasks_per_min']:.1f} tasks/min)") + print(f" Output: {output_dir}") + print(f" Summary: {summary_file}") + + +def main(): + parser = argparse.ArgumentParser( + description="Collect trajectories via AReaL Inference Service" + ) + parser.add_argument( + "--gateway-url", + required=True, + help="IS Gateway URL (e.g. http://127.0.0.1:30098)", + ) + parser.add_argument("--admin-api-key", default="dummy:0", help="IS admin API key") + parser.add_argument("--user-endpoint", required=True, help="User sim LLM endpoint") + parser.add_argument( + "--model", default="Qwen3-235B-A22B-Instruct-2507", help="Model name" + ) + parser.add_argument("--domain", default="airline", help="tau2 domain") + parser.add_argument( + "--concurrency", type=int, default=5, help="Number of concurrent workers" + ) + parser.add_argument("--num-tasks", default="all", help="Number of tasks (or 'all')") + parser.add_argument("--max-steps", type=int, default=200, help="Max steps per task") + parser.add_argument( + "--max-errors", type=int, default=10, help="Max errors per task" + ) + parser.add_argument("--seed", type=int, default=300, help="Random seed") + parser.add_argument("--openclaw-cli", default="openclaw", help="OpenClaw CLI path") + parser.add_argument( + "--openclaw-timeout", + type=int, + default=600, + help="OpenClaw subprocess timeout (s)", + ) + parser.add_argument( + "--output-dir", required=True, help="Directory to save trajectories" + ) + args = parser.parse_args() + + asyncio.run(run_collection(args)) + + +if __name__ == "__main__": + main() diff --git a/examples/experimental/inference_service/batchmode/start_servers.sh b/examples/experimental/inference_service/batchmode/start_servers.sh new file mode 100755 index 0000000000..eca984dc92 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/start_servers.sh @@ -0,0 +1,133 @@ +#!/bin/bash +set -euo pipefail + +# ──────────────────────────────────────────────────────────────────────────── +# start_servers.sh — Launch Agent + User SGLang servers via Slurm +# +# Both servers use Qwen3-235B with production-verified flags. 
+
+# Usage:
+#   bash start_servers.sh                 # default 48 hours
+#   bash start_servers.sh --hours 8       # 8 hours
+#   bash start_servers.sh --agent-only    # agent server only
+# ────────────────────────────────────────────────────────────────────────────
+
+CONTAINER="${CONTAINER}"
+MODEL_PATH="${MODEL_PATH}"
+MODEL_NAME="Qwen3-235B-A22B-Instruct-2507"
+LOG_DIR="${PROJECT}/logs"
+
+AGENT_PORT=30000
+USER_PORT=30001
+HOURS=48
+CONTEXT_LENGTH=262144
+AGENT_ONLY=false
+
+while [[ $# -gt 0 ]]; do
+  case "$1" in
+    --hours) HOURS="$2"; shift 2 ;;
+    --agent-port) AGENT_PORT="$2"; shift 2 ;;
+    --user-port) USER_PORT="$2"; shift 2 ;;
+    --context-length) CONTEXT_LENGTH="$2"; shift 2 ;;
+    --agent-only) AGENT_ONLY=true; shift ;;
+    *) echo "Unknown arg: $1"; exit 1 ;;
+  esac
+done
+
+mkdir -p "$LOG_DIR"
+
+# ════════════════════════════════════════════════════════════════════════════
+# Agent SGLang — benchmark target
+#   - --disable-radix-cache: simulate no-cache scenario for pressure testing
+#   - --tool-call-parser qwen25: REQUIRED for Qwen3 tool calling
+#   - --context-length 262144: Qwen3-235B max supported
+# ════════════════════════════════════════════════════════════════════════════
+echo "════════════════════════════════════════════════════════════════"
+echo " Agent SGLang: ${MODEL_NAME}"
+echo "   TP=8, port=${AGENT_PORT}, context=${CONTEXT_LENGTH}"
+echo "   Flags: --disable-radix-cache --tool-call-parser qwen25"
+echo "════════════════════════════════════════════════════════════════"
+
+AGENT_JOBID=$(sbatch --parsable \
+  --job-name=agent-sglang \
+  --nodes=1 \
+  --cpus-per-task=100 \
+  --gres=gpu:8 \
+  --mem=1500G \
+  --time="${HOURS}:00:00" \
+  --output="${LOG_DIR}/agent-sglang-%j.log" \
+  --wrap "singularity exec --nv --no-home --writable-tmpfs --bind /storage:/storage ${CONTAINER} bash -c '
+python -m sglang.launch_server \
+  --model-path ${MODEL_PATH} \
+  --served-model-name ${MODEL_NAME} \
+  --tp 8 \
+  --port ${AGENT_PORT} \
+  --host 0.0.0.0 \
+  --context-length ${CONTEXT_LENGTH} \
+  --tool-call-parser qwen25 \
+  --disable-radix-cache \
+  --enable-metrics \
+  --enable-deterministic-inference
+'")
+echo "  Job submitted: ${AGENT_JOBID}"
+
+# ════════════════════════════════════════════════════════════════════════════
+# User Sim SGLang — NOT benchmarked (radix cache ON)
+# ════════════════════════════════════════════════════════════════════════════
+USER_JOBID=""
+if [[ "$AGENT_ONLY" == "false" ]]; then
+  echo ""
+  echo "════════════════════════════════════════════════════════════════"
+  echo " User SGLang: ${MODEL_NAME}"
+  echo "   TP=8, port=${USER_PORT}, context=${CONTEXT_LENGTH}"
+  echo "   Flags: --tool-call-parser qwen25 (radix cache ON)"
+  echo "════════════════════════════════════════════════════════════════"
+
+  USER_JOBID=$(sbatch --parsable \
+    --job-name=user-sglang \
+    --nodes=1 \
+    --cpus-per-task=100 \
+    --gres=gpu:8 \
+    --mem=1500G \
+    --time="${HOURS}:00:00" \
+    --output="${LOG_DIR}/user-sglang-%j.log" \
+    --wrap "singularity exec --nv --no-home --writable-tmpfs --bind /storage:/storage ${CONTAINER} bash -c '
+python -m sglang.launch_server \
+  --model-path ${MODEL_PATH} \
+  --served-model-name ${MODEL_NAME} \
+  --tp 8 \
+  --port ${USER_PORT} \
+  --host 0.0.0.0 \
+  --context-length ${CONTEXT_LENGTH} \
+  --tool-call-parser qwen25 \
+  --enable-metrics \
+  --enable-deterministic-inference
+'")
+  echo "  Job submitted: ${USER_JOBID}"
+fi
+
+# ════════════════════════════════════════════════════════════════════════════
+# Summary
+# ════════════════════════════════════════════════════════════════════════════
+echo ""
+echo "╔══════════════════════════════════════════════════════════════╗"
+echo "║  SGLang Servers Submitted                                     ║"
+echo "╠══════════════════════════════════════════════════════════════╣"
+echo "║  Agent: job=${AGENT_JOBID} (TP=8, port=${AGENT_PORT}, --disable-radix-cache)"
+if [[ -n "$USER_JOBID" ]]; then
+echo "║  User:  job=${USER_JOBID} (TP=8, port=${USER_PORT}, radix-cache ON)"
+fi
+echo "╠══════════════════════════════════════════════════════════════╣"
+echo "║  Model:   ${MODEL_NAME}"
+echo "║  Context: ${CONTEXT_LENGTH}"
+echo "║  Hours:   ${HOURS}"
+echo "╠══════════════════════════════════════════════════════════════╣"
+echo "║  Next steps:                                                  ║"
+echo "║   1. Wait for jobs:  squeue -u \$(whoami)                      ║"
+echo "║   2. Get node names from squeue output                        ║"
+echo "║   3. Run benchmark:                                           ║"
+echo "║      bash scripts/srun_baseline.sh \\                          ║"
+echo "║        --agent-jobid ${AGENT_JOBID} \\                         "
+echo "║        --user-endpoint http://<user-node>:${USER_PORT}/v1     "
+echo "╚══════════════════════════════════════════════════════════════╝"
diff --git a/examples/experimental/inference_service/batchmode/sweep.sh b/examples/experimental/inference_service/batchmode/sweep.sh
new file mode 100755
index 0000000000..b81ae3391e
--- /dev/null
+++ b/examples/experimental/inference_service/batchmode/sweep.sh
@@ -0,0 +1,159 @@
+#!/bin/bash
+set -euo pipefail
+
+CONTAINER="${CONTAINER}"
+PROJECT="${PROJECT}"
+AREAL_PATCH="${AREAL_PATCH}"
+MODEL_PATH="${MODEL_PATH}"
+
+ADMIN_KEY="dummy:0"
+ROUTER_PORT=8081
+DATAPROXY_PORT=8082
+GATEWAY_PORT=30098
+SGLANG_PORT=30000
+USER_ENDPOINT="http://<user-node>:30001/v1"
+
+CONCURRENCIES="${1:-5,10,15,20,25,30}"
+NUM_TASKS="${2:-50}"
+NUM_TRIALS="${3:-4}"
+SWEEP_TAG="${4:-$(date +%Y%m%d_%H%M%S)}"
+RESULTS_BASE="$PROJECT/trajectories/sweep_${SWEEP_TAG}"
+
+singularity exec --writable-tmpfs --no-home --nv \
+  -B /storage/openpsi \
+  "$CONTAINER" bash -c '
+set -euo pipefail
+cd /AReaL && source .venv/bin/activate
+export PATH="/root/.fnm/aliases/default/bin:$PATH"
+
+python3 -m ensurepip 2>/dev/null
+python3 -m pip install -q pdm-backend toml litellm 2>/dev/null
+python3 -m pip install -e '"$PROJECT"' 2>/dev/null
+python3 -m pip install -e ${TAU2_DIR} 2>/dev/null
+
+cp -a '"$AREAL_PATCH"'/areal/experimental/inference_service/* /AReaL/areal/experimental/inference_service/
+cp -a '"$AREAL_PATCH"'/areal/experimental/openai/* /AReaL/areal/experimental/openai/
+cp '"$AREAL_PATCH"'/areal/api/cli_args.py /AReaL/areal/api/cli_args.py
+find /AReaL/areal/experimental/inference_service /AReaL/areal/experimental/openai /AReaL/areal/api -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null
+
+which openclaw || { echo "ERROR: openclaw not found"; exit 1; }
+
+cd '"$PROJECT"'
+
+PIDS=()
+cleanup() {
+  echo "Cleaning up IS processes..."
+  for pid in "${PIDS[@]}"; do kill "$pid" 2>/dev/null || true; done
+  wait 2>/dev/null || true
+}
+trap cleanup EXIT
+
+python3 -m areal.experimental.inference_service.router \
+  --port '"$ROUTER_PORT"' --admin-api-key "'"$ADMIN_KEY"'" --log-level warning &
+PIDS+=($!)
+for i in $(seq 1 60); do + curl -sf --max-time 3 http://127.0.0.1:'"$ROUTER_PORT"'/health >/dev/null 2>&1 && { echo "✓ Router OK (${i}s)"; break; } + [ "$i" -eq 60 ] && { echo "✗ Router FAILED"; exit 1; } + sleep 1 +done + +python3 -m areal.experimental.inference_service.data_proxy \ + --port '"$DATAPROXY_PORT"' \ + --backend-addr http://127.0.0.1:'"$SGLANG_PORT"' \ + --backend-type sglang \ + --tokenizer-path '"$MODEL_PATH"' \ + --admin-api-key "'"$ADMIN_KEY"'" \ + --request-timeout 600 --log-level warning & +PIDS+=($!) +for i in $(seq 1 60); do + curl -sf --max-time 3 http://127.0.0.1:'"$DATAPROXY_PORT"'/health >/dev/null 2>&1 && { echo "✓ DataProxy OK (${i}s)"; break; } + [ "$i" -eq 60 ] && { echo "✗ DataProxy FAILED"; exit 1; } + sleep 1 +done + +curl -sf -X POST http://127.0.0.1:'"$ROUTER_PORT"'/register \ + -H "Authorization: Bearer '"$ADMIN_KEY"'" \ + -H "Content-Type: application/json" \ + -d '"'"'{"worker_addr": "http://127.0.0.1:'"$DATAPROXY_PORT"'"}'"'"' >/dev/null +echo "✓ DataProxy registered" + +python3 -m areal.experimental.inference_service.gateway \ + --port '"$GATEWAY_PORT"' \ + --router-addr http://127.0.0.1:'"$ROUTER_PORT"' \ + --admin-api-key "'"$ADMIN_KEY"'" \ + --forward-timeout 600 --log-level warning & +PIDS+=($!) +for i in $(seq 1 60); do + curl -sf --max-time 3 http://127.0.0.1:'"$GATEWAY_PORT"'/health >/dev/null 2>&1 && { echo "✓ Gateway OK (${i}s)"; break; } + [ "$i" -eq 60 ] && { echo "✗ Gateway FAILED"; exit 1; } + sleep 1 +done + +echo "" +echo "╔══════════════════════════════════════════════════════════╗" +echo "║ Inference Service ready ║" +echo "║ Sweep: concurrencies='"$CONCURRENCIES"' ║" +echo "║ Tasks: '"$NUM_TASKS"' × '"$NUM_TRIALS"' trials ║" +echo "║ Output: '"$RESULTS_BASE"' ║" +echo "╚══════════════════════════════════════════════════════════╝" +echo "" + +export TAU2_DATA_DIR=${TAU2_DIR}/data +SWEEP_START=$(date +%s) + +IFS="," read -ra CONCS <<< "'"$CONCURRENCIES"'" +for C in "${CONCS[@]}"; do + for TRIAL in $(seq 1 '"$NUM_TRIALS"'); do + RUN_DIR="'"$RESULTS_BASE"'/c${C}/trial_${TRIAL}" + echo "" + echo "================================================================" + echo " Concurrency=${C} Trial=${TRIAL} → ${RUN_DIR}" + echo "================================================================" + + python3 '"$PROJECT"'/scripts/collect_trajectories.py \ + --gateway-url http://127.0.0.1:'"$GATEWAY_PORT"' \ + --admin-api-key "'"$ADMIN_KEY"'" \ + --user-endpoint '"$USER_ENDPOINT"' \ + --model Qwen3-235B-A22B-Instruct-2507 \ + --domain airline \ + --concurrency "$C" \ + --num-tasks '"$NUM_TASKS"' \ + --max-steps 200 \ + --max-errors 10 \ + --seed 300 \ + --openclaw-cli $(which openclaw) \ + --openclaw-timeout 3000 \ + --output-dir "$RUN_DIR" || { + echo " ✗ FAILED: c=${C} trial=${TRIAL}" + continue + } + done +done + +SWEEP_END=$(date +%s) +SWEEP_DUR=$(( SWEEP_END - SWEEP_START )) + +echo "" +echo "╔══════════════════════════════════════════════════════════╗" +echo "║ Sweep Complete ║" +echo "║ Duration: $(( SWEEP_DUR / 3600 ))h $(( (SWEEP_DUR % 3600) / 60 ))m ║" +echo "║ Results: '"$RESULTS_BASE"' ║" +echo "╚══════════════════════════════════════════════════════════╝" + +python3 -c " +import json, os, glob + +base = \"'"$RESULTS_BASE"'\" +print() +print(\"Concurrency | Trial | Tasks | Pass | Fail | Error | Rate | Dur(s) | tasks/min\") +print(\"-\" * 85) +for summary_path in sorted(glob.glob(os.path.join(base, \"*/*/collection_summary.json\"))): + with open(summary_path) as f: + s = json.load(f) + parts = summary_path.split(\"/\") + c_dir = [p for 
p in parts if p.startswith(\"c\")][0] if any(p.startswith(\"c\") for p in parts) else \"?\"
+    t_dir = [p for p in parts if p.startswith(\"trial\")][0] if any(p.startswith(\"trial\") for p in parts) else \"?\"
+    # Pull fields into a tuple first: quotes inside f-string expressions need
+    # Python 3.12+, and single quotes would end the enclosing shell string.
+    row = (s[\"completed\"], s[\"passed\"], s[\"failed\"], s[\"errors\"], s[\"pass_rate\"], s[\"total_time_s\"], s[\"tasks_per_min\"])
+    print(f\"{c_dir:>11} | {t_dir:>5} | {row[0]:>5} | {row[1]:>4} | {row[2]:>4} | {row[3]:>5} | {row[4]:>5.1%} | {row[5]:>6.0f} | {row[6]:>9.1f}\")
+print()
+"
+'
diff --git a/examples/experimental/inference_service/batchmode/tau2/__init__.py b/examples/experimental/inference_service/batchmode/tau2/__init__.py
new file mode 100644
index 0000000000..760a8ea588
--- /dev/null
+++ b/examples/experimental/inference_service/batchmode/tau2/__init__.py
@@ -0,0 +1,55 @@
+import sys
+import types
+
+from loguru import logger
+
+from .task_runner import run_task
+from .task_runner_socket import run_task_with_socket_server
+from .tau2_env import (
+    EnvironmentSocketServer,
+    OpenClawEnvironmentEvaluator,
+    create_openclaw_tool_script,
+    evaluate_simulation_with_environment,
+)
+
+__version__ = "0.2.0"
+
+
+def register_openclaw_agent() -> bool:
+    try:
+        from tau2.registry import registry
+
+        from .openclaw import OpenClawAgent
+
+        registry.register_agent(OpenClawAgent, "openclaw_agent")
+        logger.info("Registered OpenClaw agent with TAU²: 'openclaw_agent'")
+        return True
+    except ImportError as exc:
+        logger.debug("TAU² not available, skipping registration: {}", exc)
+    except ValueError as exc:
+        if "already registered" in str(exc):
+            logger.debug("OpenClaw agent already registered")
+            return True
+        logger.warning("Failed to register OpenClaw agent: {}", exc)
+    except Exception as exc:
+        logger.warning("Unexpected error registering OpenClaw agent: {}", exc)
+    return False
+
+
+def register_plugin() -> bool:
+    return register_openclaw_agent()
+
+
+_register_module = types.ModuleType(f"{__name__}.register")
+_register_module.register_openclaw_agent = register_openclaw_agent
+_register_module.register_plugin = register_plugin
+sys.modules.setdefault(_register_module.__name__, _register_module)
+_registration_success = register_openclaw_agent()
+__all__ = [
+    "run_task",
+    "run_task_with_socket_server",
+    "evaluate_simulation_with_environment",
+    "OpenClawEnvironmentEvaluator",
+    "EnvironmentSocketServer",
+    "create_openclaw_tool_script",
+]
diff --git a/examples/experimental/inference_service/batchmode/tau2/openclaw/__init__.py b/examples/experimental/inference_service/batchmode/tau2/openclaw/__init__.py
new file mode 100644
index 0000000000..66febb4466
--- /dev/null
+++ b/examples/experimental/inference_service/batchmode/tau2/openclaw/__init__.py
@@ -0,0 +1,11 @@
+import sys
+import types
+
+from .agent import OpenClawAgent
+from .service import OpenClawConfig, OpenClawService
+from .workspace_manager import OpenClawWorkspaceManager
+
+_config_module = types.ModuleType(f"{__name__}.config")
+_config_module.OpenClawConfig = OpenClawConfig
+sys.modules.setdefault(_config_module.__name__, _config_module)
+__all__ = ["OpenClawAgent", "OpenClawService", "OpenClawWorkspaceManager", "OpenClawConfig"]
diff --git a/examples/experimental/inference_service/batchmode/tau2/openclaw/agent.py b/examples/experimental/inference_service/batchmode/tau2/openclaw/agent.py
new file mode 100644
index 0000000000..48c31cabc0
--- /dev/null
+++ b/examples/experimental/inference_service/batchmode/tau2/openclaw/agent.py
@@ -0,0 +1,233 @@
+import json
+import uuid
+from typing import Any
+
+from loguru import logger
+from pydantic import BaseModel
+
+try:
+    from
tau2.agent.base import LocalAgent, ValidAgentInputMessage + from tau2.agent.llm_agent import ( + AGENT_INSTRUCTION, + is_valid_agent_history_message, + ) + from tau2.agent.llm_agent import ( + SYSTEM_PROMPT as AGENT_SYSTEM_PROMPT, + ) + from tau2.data_model.message import ( + APICompatibleMessage, + AssistantMessage, + Message, + SystemMessage, + ToolCall, + ToolMessage, + UserMessage, + ) + from tau2.environment.tool import Tool +except ImportError as exc: + logger.error( + "Failed to import tau2: {}\nPlease install tau2-bench: pip install -e ../tau2-bench", exc + ) + raise + +from .service import OpenClawConfig, OpenClawService +from .workspace_manager import OpenClawWorkspaceManager + + +class MessageTranslator: + @staticmethod + def to_openclaw(messages: list[Message]) -> list[dict[str, Any]]: + translated: list[dict[str, Any]] = [] + for msg in messages: + if isinstance(msg, UserMessage): + translated.append({"role": "user", "content": msg.content}) + elif isinstance(msg, ToolMessage): + translated.append( + {"role": "tool", "content": msg.content, "tool_call_id": msg.tool_call_id} + ) + elif isinstance(msg, SystemMessage): + translated.append({"role": "system", "content": msg.content}) + elif isinstance(msg, AssistantMessage): + payload = {"role": "assistant", "content": msg.content or ""} + if msg.tool_calls: + payload["tool_calls"] = [ + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": json.dumps(call.function.arguments), + }, + } + for call in msg.tool_calls + ] + translated.append(payload) + else: + logger.warning("Unknown message type: {}", type(msg)) + return translated + + @staticmethod + def from_openclaw(openclaw_msg: dict[str, Any]) -> AssistantMessage: + tool_calls = openclaw_msg.get("tool_calls") or None + return AssistantMessage( + role="assistant", + content=openclaw_msg.get("content") or None, + tool_calls=[ + ToolCall( + id=call["id"], + type="function", + function={ + "name": call["function"]["name"], + "arguments": json.loads(call["function"]["arguments"]) + if isinstance(call["function"]["arguments"], str) + else call["function"]["arguments"], + }, + ) + for call in tool_calls + ] + if tool_calls + else None, + ) + + +class OpenClawAgentState(BaseModel): + system_messages: list[SystemMessage] + messages: list[APICompatibleMessage] + openclaw_session_id: str | None = None + + +class OpenClawAgent(LocalAgent[OpenClawAgentState]): + def __init__( + self, + tools: list[Tool], + domain_policy: str, + socket_server_config: dict | None = None, + **kwargs, + ): + super().__init__(tools=tools, domain_policy=domain_policy) + self.socket_server_config = socket_server_config + self.config = OpenClawConfig.from_env() + self.message_translator = MessageTranslator() + self.service = OpenClawService( + cli_command=self.config.cli_command, + timeout=self.config.timeout, + api_base=self.config.api_base, + api_key=self.config.api_key, + model=self.config.model, + ) + self.workspace_manager: OpenClawWorkspaceManager | None = OpenClawWorkspaceManager( + cli_command=self.config.cli_command + ) + self.agent_id: str | None = f"tau2-{uuid.uuid4().hex[:8]}" + try: + self.workspace_manager.create_agent_workspace( + agent_id=self.agent_id, tools=tools, agent_name="TAU² Evaluation Agent" + ) + if self.socket_server_config: + self._inject_socket_tools(self.socket_server_config) + except Exception as exc: + logger.warning("Failed to create workspace, continuing without isolation: {}", exc) + self.workspace_manager = None + self.agent_id = None 
+ logger.info( + "OpenClawAgent initialized: agent_id={} tools={} cli={} timeout={} socket_server={}", + self.agent_id or "default", + len(tools), + self.config.cli_command, + self.config.timeout, + bool(self.socket_server_config), + ) + + def _inject_socket_tools(self, server_config: dict) -> None: + if not self.workspace_manager or not self.agent_id: + return + from ..tau2_env import create_openclaw_tool_script + + tools_dir = self.workspace_manager.get_workspace_path(self.agent_id) / "socket_tools" + tools_dir.mkdir(exist_ok=True) + for tool in self.tools: + script_path = tools_dir / f"{tool.name}.py" + script_path.write_text( + create_openclaw_tool_script(tool_name=tool.name, server_config=server_config), + encoding="utf-8", + ) + script_path.chmod(0o755) + (tools_dir / "server_config.json").write_text( + json.dumps(server_config, indent=2), encoding="utf-8" + ) + logger.info("Injected {} socket tools into {}", len(self.tools), tools_dir) + + @property + def system_prompt(self) -> str: + return AGENT_SYSTEM_PROMPT.format( + domain_policy=self.domain_policy, agent_instruction=AGENT_INSTRUCTION + ) + ( + "\n\n## Available Tools\n\n" + "Before using tools, read `skills/tau2-tools/SKILL.md`. " + "Tools are exposed as Python scripts in `socket_tools/` and should be invoked like " + '`python socket_tools/.py \'{"param": "value"}\'`.' + ) + + def get_init_state(self, message_history: list[Message] | None = None) -> OpenClawAgentState: + message_history = message_history or [] + assert all(is_valid_agent_history_message(message) for message in message_history), ( + "Message history must contain only AssistantMessage, UserMessage, or ToolMessage to Agent." + ) + return OpenClawAgentState( + system_messages=[SystemMessage(role="system", content=self.system_prompt)], + messages=message_history, + openclaw_session_id=str(uuid.uuid4()), + ) + + def generate_next_message( + self, message: ValidAgentInputMessage, state: OpenClawAgentState + ) -> tuple[AssistantMessage, OpenClawAgentState]: + state.messages.append(message) + try: + response = self.service.chat( + messages=self.message_translator.to_openclaw( + state.system_messages + state.messages + ), + session_id=state.openclaw_session_id, + agent_id=self.agent_id, + ) + state.openclaw_session_id = response.get("session_id", state.openclaw_session_id) + assistant_msg = self.message_translator.from_openclaw(response["message"]) + except Exception: + logger.exception("Error in OpenClaw agent") + assistant_msg = AssistantMessage( + role="assistant", + content="I apologize, but I encountered an error processing your request. 
Please try again or rephrase your question.", + ) + state.messages.append(assistant_msg) + return assistant_msg, state + + def stop( + self, + message: ValidAgentInputMessage | None = None, + state: OpenClawAgentState | None = None, + ) -> None: + _ = message + cleanups = [] + if state and state.openclaw_session_id: + cleanups.append( + ("session", lambda: self.service.cleanup_session(state.openclaw_session_id)) + ) + if self.workspace_manager and self.agent_id: + cleanups.append( + ("workspace", lambda: self.workspace_manager.delete_agent_workspace(self.agent_id)) + ) + for label, cleanup in cleanups: + try: + cleanup() + except Exception as exc: + logger.warning("Failed to cleanup {}: {}", label, exc) + + @classmethod + def is_stop(cls, message: AssistantMessage) -> bool: + return False + + def set_seed(self, seed: int) -> None: + logger.warning( + "set_seed({}) called but OpenClaw may not support deterministic seeding", seed + ) diff --git a/examples/experimental/inference_service/batchmode/tau2/openclaw/service.py b/examples/experimental/inference_service/batchmode/tau2/openclaw/service.py new file mode 100644 index 0000000000..ed0251bd2b --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/openclaw/service.py @@ -0,0 +1,228 @@ +import json +import os +import subprocess +import time +from typing import Any + +from dotenv import load_dotenv +from loguru import logger +from pydantic import BaseModel + +load_dotenv() + + +def _message_text(content: Any) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + return "\n".join( + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ) + return "" + + +class OpenClawConfig(BaseModel): + cli_command: str = "openclaw" + timeout: int = 600 + max_retries: int = 3 + api_base: str | None = None + api_key: str | None = None + model: str | None = None + + @classmethod + def from_env(cls) -> "OpenClawConfig": + env = os.getenv + return cls( + cli_command=env("OPENCLAW_CLI_COMMAND", "openclaw"), + timeout=int(env("OPENCLAW_TIMEOUT", "120")), + max_retries=int(env("OPENCLAW_MAX_RETRIES", "3")), + api_base=env("OPENCLAW_API_BASE"), + api_key=env("OPENCLAW_API_KEY"), + model=env("OPENCLAW_MODEL"), + ) + + class Config: + extra = "forbid" + + +class OpenClawServiceError(Exception): + pass + + +class OpenClawService: + def __init__( + self, + cli_command: str = "openclaw", + timeout: int = 60, + api_base: str | None = None, + api_key: str | None = None, + model: str | None = None, + ): + self.cli_command = cli_command + self.timeout = timeout + self.api_base = api_base + self.api_key = api_key + self.model = model + self._verify_cli() + + def _verify_cli(self) -> None: + try: + result = subprocess.run( + [self.cli_command, "--version"], capture_output=True, text=True, timeout=5 + ) + except FileNotFoundError: + logger.error("OpenClaw CLI not found: {}", self.cli_command) + return + except Exception as exc: + logger.error("Error checking OpenClaw CLI: {}", exc) + return + if result.returncode: + logger.warning("OpenClaw CLI check returned code {}", result.returncode) + return + version = result.stdout.strip().split() + logger.info("OpenClaw CLI available: {}", version[1] if len(version) > 1 else "unknown") + + def _build_message_text(self, messages: list[dict[str, Any]]) -> str: + system_parts, user_messages = [], [] + for message in messages: + text = _message_text(message.get("content", "")) + if not text.strip(): + continue + if message.get("role") 
== "system": + system_parts.append(text) + elif message.get("role") == "user": + user_messages.append(text) + return ( + "\n\n".join((*system_parts, "---", user_messages[0])) + if len(user_messages) == 1 and system_parts + else (user_messages[-1] if user_messages else "") + ) + + def _build_env(self) -> dict[str, str]: + env = os.environ.copy() + env.pop("OPENAI_BASE_URL", None) + env.pop("OPENAI_API_BASE", None) + if self.api_key: + env["OPENAI_API_KEY"] = self.api_key + return env + + def _parse_result( + self, result: subprocess.CompletedProcess[str], session_id: str + ) -> dict[str, Any]: + try: + output = json.loads(result.stdout) + except json.JSONDecodeError as exc: + logger.error("Failed to parse JSON: {}", result.stdout[:500]) + raise OpenClawServiceError(f"Invalid JSON response: {exc}") from exc + payloads = output.get("payloads") + if payloads is None and output.get("status") == "ok": + payloads = output.get("result", {}).get("payloads") + if payloads is None: + raise OpenClawServiceError( + f"OpenClaw returned error: {output.get('error', str(output)[:500])}" + ) + if not payloads: + logger.error( + "[No payloads] output_keys={} output_preview={} stderr_preview={}", + list(output), + json.dumps(output, default=str)[:1000], + (result.stderr or "")[:500], + ) + raise OpenClawServiceError("No payloads in response") + payload = payloads[0] + if not payload.get("text", ""): + logger.warning( + "[Empty text] payload={} output_preview={}", + json.dumps(payload, default=str)[:500], + json.dumps(output, default=str)[:500], + ) + return { + "message": { + "role": "assistant", + "content": payload.get("text", ""), + "tool_calls": None, + }, + "session_id": session_id, + "raw_output": output, + } + + def _retry_enoent(self, error: Any, attempt: int, attempts: int) -> bool: + if "ENOENT" not in str(error) or attempt >= attempts - 1: + return False + logger.warning("Session file not found, retrying in 1.0s ({}/{})", attempt + 1, attempts) + time.sleep(1.0) + return True + + def chat( + self, + messages: list[dict[str, Any]], + session_id: str | None = None, + agent_id: str | None = None, + ) -> dict[str, Any]: + message_text = self._build_message_text(messages) + if not message_text: + raise OpenClawServiceError("No user message found in messages") + session_id = session_id or "tau2-bench-session" + cmd = [ + self.cli_command, + "agent", + "--message", + message_text, + "--session-id", + session_id, + "--local", + "--json", + ] + if agent_id: + cmd.extend(["--agent", agent_id]) + logger.info( + "[OpenClaw CLI] Running: {} | api_base={} model={}", + " ".join(cmd), + self.api_base, + self.model, + ) + attempts = 3 + last_error: Any = None + for attempt in range(attempts): + try: + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=self.timeout, env=self._build_env() + ) + if not result.returncode: + return self._parse_result(result, session_id) + last_error = result.stderr or result.stdout + if self._retry_enoent(last_error, attempt, attempts): + continue + raise OpenClawServiceError(f"OpenClaw CLI failed: {last_error}") + except subprocess.TimeoutExpired as exc: + self._cleanup_lock_files(agent_id) + raise OpenClawServiceError(f"OpenClaw CLI timeout after {self.timeout}s") from exc + except OpenClawServiceError: + raise + except Exception as exc: + last_error = exc + if self._retry_enoent(exc, attempt, attempts): + continue + logger.exception("Error calling OpenClaw CLI") + raise OpenClawServiceError(f"OpenClaw CLI call failed: {exc}") from exc + raise 
OpenClawServiceError(f"OpenClaw CLI failed after {attempts} attempts: {last_error}") + + def _cleanup_lock_files(self, agent_id: str | None) -> None: + if not agent_id: + return + try: + from .workspace_manager import OpenClawWorkspaceManager + + for lock_path in ( + OpenClawWorkspaceManager(cli_command=self.cli_command).get_workspace_path(agent_id) + / ".openclaw" + ).rglob("*.lock"): + lock_path.unlink(missing_ok=True) + logger.warning("Cleaned up stale lock: {}", lock_path) + except Exception as exc: + logger.debug("Lock cleanup failed (non-fatal): {}", exc) + + def cleanup_session(self, session_id: str) -> None: + logger.debug("Session cleanup not needed for CLI mode: {}", session_id) diff --git a/examples/experimental/inference_service/batchmode/tau2/openclaw/workspace_manager.py b/examples/experimental/inference_service/batchmode/tau2/openclaw/workspace_manager.py new file mode 100644 index 0000000000..53d7beab13 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/openclaw/workspace_manager.py @@ -0,0 +1,201 @@ +import json +import os +import shutil +import subprocess +import tempfile +from pathlib import Path +from typing import Any + +from loguru import logger + +from tau2.environment.tool import Tool + + +class WorkspaceManagerError(Exception): + pass + + +class OpenClawWorkspaceManager: + def __init__( + self, + cli_command: str = "openclaw", + base_workspace_dir: Path | None = None, + llm_api_key: str | None = None, + llm_base_url: str | None = None, + llm_model: str | None = None, + ): + self.cli_command = cli_command + self.llm_api_key = llm_api_key or os.environ.get("OPENCLAW_API_KEY") + self.llm_base_url = llm_base_url or os.environ.get("OPENCLAW_API_BASE") + self.llm_model = llm_model or os.environ.get("OPENCLAW_MODEL", "gpt-4o") + self.base_workspace_dir = Path( + base_workspace_dir or Path(tempfile.gettempdir()) / "openclaw-tau2-workspaces" + ) + self.base_workspace_dir.mkdir(parents=True, exist_ok=True) + self.created_agents: set[str] = set() + logger.info("Workspace manager initialized: {}", self.base_workspace_dir) + + @property + def openclaw_dir(self) -> Path: + return Path.home() / ".openclaw" + + @staticmethod + def _write_json(path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + def _provider_config(self) -> dict[str, Any]: + model_name = self.llm_model or "model" + return { + "baseUrl": self.llm_base_url or "http://127.0.0.1:30000", + "apiKey": self.llm_api_key or "dummy", + "api": "openai-completions", + "models": [{"id": model_name, "name": model_name}], + } + + def _tool_markdown(self, tool: Tool) -> str: + schema = tool.openai_schema.get("function", {}) + params = schema.get("parameters", {}) + properties = params.get("properties", {}) + required = set(params.get("required", [])) + parameter_lines = "\n".join( + f"- `{name}` ({info.get('type', 'string')}, {'required' if name in required else 'optional'}): {info.get('description', '')}" + for name, info in properties.items() + ) + example = json.dumps({name: f"<{name}>" for name in properties if name in required} or {}) + return "\n".join( + filter( + None, + [ + f"### {tool.name}", + "", + schema.get("description", tool.short_desc), + "", + f"**Script**: `socket_tools/{tool.name}.py`", + "", + "**Parameters**:" if properties else "", + parameter_lines, + "" if properties else "", + "**Example**:", + "```bash", + f"python socket_tools/{tool.name}.py '{example}'", + "```", + 
"", + ], + ) + ) + + def _build_skill_markdown(self, tools: list[Tool]) -> str: + header = '---\nname: tau2-tools\ndescription: TAU²-Bench Environment Tools - Execute tools via Socket Server\nmetadata: {"openclaw":{"emoji":"🔧","requires":{"bins":["python"]}}}\n---\n\n# TAU²-Bench Environment Tools\n\nUse `python socket_tools/.py \'\'` to call the shared environment.\nGenerated scripts live in `socket_tools/` and require only the Python standard library.\n\n## Available Tools\n' + footer = "## Notes\n\n- Socket tools are generated automatically when the socket server is enabled.\n- Check `socket_tools/server_config.json` for connection details.\n- Changes are applied to the shared environment immediately.\n" + body = "\n".join(self._tool_markdown(tool) for tool in tools) + return f"{header}\n{body}\n{footer}" + + def _setup_workspace_tools(self, workspace_path: Path, tools: list[Tool]) -> None: + skill_dir = workspace_path / "skills" / "tau2-tools" + skill_dir.mkdir(parents=True, exist_ok=True) + (skill_dir / "SKILL.md").write_text(self._build_skill_markdown(tools), encoding="utf-8") + tools_dir = workspace_path / "tools" + tools_dir.mkdir(exist_ok=True) + self._write_json( + tools_dir / "tau2_tools.json", + { + "version": "1.0", + "tools": [tool.openai_schema for tool in tools], + "metadata": { + "source": "tau2-bench-socket-server", + "count": len(tools), + "socket_tools_dir": "socket_tools/", + }, + }, + ) + + def create_agent_workspace( + self, agent_id: str, tools: list[Tool], agent_name: str | None = None + ) -> str: + workspace_path = self.get_workspace_path(agent_id) + agent_dir = self.openclaw_dir / "agents" / agent_id / "agent" + workspace_path.mkdir(parents=True, exist_ok=True) + try: + agent_dir.mkdir(parents=True, exist_ok=True) + (self.openclaw_dir / "workspace").mkdir(exist_ok=True) + self._write_json( + self.openclaw_dir / "openclaw.json", + { + "models": {"providers": {"sglang": self._provider_config()}}, + "agents": { + "defaults": {"maxConcurrent": 4}, + "list": [ + {"id": "main"}, + { + "id": agent_id, + "name": agent_name or agent_id, + "workspace": str(workspace_path), + "agentDir": str(agent_dir), + }, + ], + }, + }, + ) + self.created_agents.add(agent_id) + self._setup_workspace_tools(workspace_path, tools) + except Exception as exc: + logger.exception("Error creating agent workspace") + raise WorkspaceManagerError(f"Failed to create agent: {exc}") from exc + logger.info("Agent '{}' created with {} tools", agent_id, len(tools)) + return agent_id + + def delete_agent_workspace(self, agent_id: str) -> None: + try: + result = subprocess.run( + [self.cli_command, "agents", "delete", agent_id, "--force"], + capture_output=True, + text=True, + timeout=30, + input="y\n", + ) + if result.returncode: + logger.warning( + "Failed to delete agent '{}': {}", agent_id, result.stderr or result.stdout + ) + except Exception as exc: + logger.debug("Agent deletion via CLI failed for '{}': {}", agent_id, exc) + shutil.rmtree(self.get_workspace_path(agent_id), ignore_errors=True) + shutil.rmtree(self.openclaw_dir / "agents" / agent_id, ignore_errors=True) + self.created_agents.discard(agent_id) + logger.info("Agent '{}' deleted", agent_id) + + def cleanup_all(self) -> None: + for agent_id in tuple(self.created_agents): + self.delete_agent_workspace(agent_id) + try: + self.base_workspace_dir.rmdir() + except OSError: + pass + except Exception as exc: + logger.error("Error removing base workspace: {}", exc) + + def list_agents(self) -> list[dict[str, Any]]: + try: + result = 
subprocess.run( + [self.cli_command, "agents", "list", "--json"], + capture_output=True, + text=True, + timeout=10, + ) + if result.returncode: + logger.warning("Failed to list agents: {}", result.stderr) + return [] + return json.loads(result.stdout) + except Exception as exc: + logger.error("Error listing agents: {}", exc) + return [] + + def get_workspace_path(self, agent_id: str) -> Path: + return self.base_workspace_dir / agent_id + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup_all() diff --git a/examples/experimental/inference_service/batchmode/tau2/pyproject.toml b/examples/experimental/inference_service/batchmode/tau2/pyproject.toml new file mode 100644 index 0000000000..f987eff80d --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/pyproject.toml @@ -0,0 +1,82 @@ +[project] +name = "openclaw-tau2-integration" +version = "0.1.0" +description = "Integration project for testing OpenClaw agent on TAU²-Bench" +authors = [ + {name = "Agent Benchmark Team", email = "dev@example.com"} ] +readme = "README.md" +requires-python = ">=3.10" +license = {text = "MIT"} + +dependencies = [ + # Core dependencies + "pydantic>=2.0.0", + "pyyaml>=6.0", + "loguru>=0.7.0", + "requests>=2.31.0", + "python-dotenv>=1.0.0", + + # TAU²-Bench - installed from local source + # At install time, use: pip install -e ".[tau2]" +] + +[project.optional-dependencies] +# TAU²-Bench dependencies +tau2 = [ + "tau2 @ git+ssh://git@github.com/sierra-research/tau2-bench.git" +] + +# Development dependencies +dev = [ + "pytest>=7.0.0", + "pytest-cov>=4.0.0", + "pytest-asyncio>=0.21.0", + "black>=23.0.0", + "ruff>=0.1.0", + "mypy>=1.0.0", + "ipython>=8.0.0", +] + +# Entry point for tau2 plugin system +[project.entry-points."tau2.plugins"] +openclaw = "openclaw_tau2.register:register_plugin" + +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["."] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["."] +addopts = "-v" + +[tool.black] +line-length = 100 +target-version = ['py310'] +include = '\.pyi?$' +extend-exclude = ''' +/( + \.git + | \.venv + | build + | dist +)/ +''' + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] +ignore = ["E501"] # line too long (handled by black) + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = false diff --git a/examples/experimental/inference_service/batchmode/tau2/task_runner.py b/examples/experimental/inference_service/batchmode/tau2/task_runner.py new file mode 100644 index 0000000000..b6cc4cfe94 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/task_runner.py @@ -0,0 +1,83 @@ +import sys + +from loguru import logger + +from tau2.data_model.simulation import SimulationRun +from tau2.data_model.tasks import Task +from tau2.evaluator.evaluator import EvaluationType + +from .task_runner_socket import ( + _build_cli_parser, + _configure_logging, + _load_selected_task, + _log_simulation_summary, + _run_task_impl, +) + + +def run_task( + domain: str, + task: Task, + agent: str, + user: str, + llm_agent: str | None = None, + llm_args_agent: dict | None = None, + llm_user: str | None = None, + llm_args_user: dict | None = None, + max_steps: int = 100, + max_errors: int = 10, + evaluation_type: EvaluationType = EvaluationType.ALL, + seed: int | None = None, + enforce_communication_protocol: bool = False, +) -> SimulationRun: + 
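# Thin wrapper: delegates to _run_task_impl with use_socket_server, initialize_environment, and save_simulation all disabled. + 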
return _run_task_impl( + domain, + task, + agent, + user, + llm_agent, + llm_args_agent, + llm_user, + llm_args_user, + max_steps, + max_errors, + evaluation_type, + seed, + enforce_communication_protocol, + use_socket_server=False, + initialize_environment=False, + save_simulation=False, + )[0] + + +def main() -> int: + args = _build_cli_parser("Test OpenClaw agent with TAU² tasks", 20).parse_args() + _configure_logging(args.log_level) + task = _load_selected_task(args.domain, args.task_id) + if task is None: + return 1 + try: + simulation = run_task( + args.domain, + task, + args.agent, + args.user, + args.llm_agent, + None, + args.llm_user, + None, + args.max_steps, + 5, + EvaluationType.ALL, + args.seed, + False, + ) + except Exception: + logger.exception("Error running task") + return 1 + _log_simulation_summary(simulation, debug=args.log_level == "DEBUG") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/experimental/inference_service/batchmode/tau2/task_runner_socket.py b/examples/experimental/inference_service/batchmode/tau2/task_runner_socket.py new file mode 100644 index 0000000000..113b4703ac --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/task_runner_socket.py @@ -0,0 +1,366 @@ +import inspect +import json +import sys +from pathlib import Path + +from loguru import logger + +from tau2.data_model.simulation import SimulationRun +from tau2.data_model.tasks import Task +from tau2.environment.environment import Environment +from tau2.evaluator.evaluator import EvaluationType +from tau2.orchestrator.orchestrator import Orchestrator +from tau2.registry import registry + +from .tau2_env import EnvironmentSocketServer, evaluate_simulation_with_environment + + +def _initial_state_kwargs(task: Task) -> dict: + state = task.initial_state + return { + "initialization_data": getattr(state, "initialization_data", None), + "initialization_actions": getattr(state, "initialization_actions", None), + "message_history": getattr(state, "message_history", None) or [], + } + + +def _save_simulation(domain: str, task: Task, agent: str, simulation: SimulationRun) -> None: + output_file = Path("simulation_results") / f"simulation_{domain}_{task.id}_{agent}.json" + output_file.parent.mkdir(exist_ok=True) + try: + payload = ( + simulation.model_dump() + if hasattr(simulation, "model_dump") + else simulation.dict() + if hasattr(simulation, "dict") + else simulation.__dict__ + ) + output_file.write_text( + json.dumps(payload, indent=2, ensure_ascii=False, default=str), encoding="utf-8" + ) + logger.info("Simulation saved to: {}", output_file) + except Exception as exc: + logger.error("Failed to save simulation to file: {}", exc) + + +def _build_agent( + task: Task, + agent: str, + llm_agent: str | None, + llm_args_agent: dict | None, + environment: Environment, + environment_constructor, + env_server: EnvironmentSocketServer | None, + server_config: dict | None, + init_kwargs: dict, +): + from tau2.agent.llm_agent import LLMAgent, LLMGTAgent, LLMSoloAgent + + try: + from tau2.gym.gym_agent import GymAgent + except ImportError: + GymAgent = None # noqa: N806 + agent_constructor = registry.get_agent_constructor(agent) + tools = environment.get_tools() + policy = environment.get_policy() + if issubclass(agent_constructor, LLMAgent): + return ( + agent_constructor( + tools=tools, domain_policy=policy, llm=llm_agent, llm_args=llm_args_agent + ), + environment, + False, + ) + if issubclass(agent_constructor, LLMGTAgent): + return ( + agent_constructor( 
+ tools=tools, domain_policy=policy, llm=llm_agent, llm_args=llm_args_agent, task=task + ), + environment, + False, + ) + if issubclass(agent_constructor, LLMSoloAgent): + solo_environment = environment_constructor(solo_mode=True) + if task.initial_state: + solo_environment.set_state(**init_kwargs) + if env_server: + env_server.environment = solo_environment + user_tools = solo_environment.get_user_tools() if solo_environment.user_tools else [] + return ( + agent_constructor( + tools=solo_environment.get_tools() + user_tools, + domain_policy=solo_environment.get_policy(), + llm=llm_agent, + llm_args=llm_args_agent, + task=task, + ), + solo_environment, + True, + ) + if GymAgent is not None and issubclass(agent_constructor, GymAgent): + return agent_constructor(tools=tools, domain_policy=policy), environment, False + if agent == "openclaw_agent": + instance = agent_constructor( + tools=tools, + domain_policy=policy, + socket_server_config=server_config if env_server else None, + ) + logger.info("OpenClaw agent created with socket config: {}", server_config) + return instance, environment, False + return agent_constructor(tools=tools, domain_policy=policy), environment, False + + +def _run_task_impl( + domain: str, + task: Task, + agent: str, + user: str, + llm_agent: str | None, + llm_args_agent: dict | None, + llm_user: str | None, + llm_args_user: dict | None, + max_steps: int, + max_errors: int, + evaluation_type: EvaluationType, + seed: int | None, + enforce_communication_protocol: bool, + *, + use_socket_server: bool, + initialize_environment: bool, + save_simulation: bool, + socket_port: int | None = None, +) -> tuple[SimulationRun, Environment]: + if max_steps <= 0: + raise ValueError("Max steps must be greater than 0") + if max_errors <= 0: + raise ValueError("Max errors must be greater than 0") + logger.info( + "STARTING SIMULATION: domain={} task={} agent={} user={}", domain, task.id, agent, user + ) + environment_constructor = registry.get_env_constructor(domain) + environment = environment_constructor() + init_kwargs = _initial_state_kwargs(task) + if initialize_environment and task.initial_state: + environment.set_state(**init_kwargs) + env_server = server_config = None + if use_socket_server and agent == "openclaw_agent": + env_server = EnvironmentSocketServer( + environment=environment, task=task, host="127.0.0.1", port=socket_port or 0 + ) + env_server.start() + server_config = env_server.get_client_config() + try: + agent_instance, environment, solo_mode = _build_agent( + task, + agent, + llm_agent, + llm_args_agent, + environment, + environment_constructor, + env_server, + server_config, + init_kwargs, + ) + from tau2.user.user_simulator import DummyUser + + user_constructor = registry.get_user_constructor(user) + if issubclass(user_constructor, DummyUser): + from tau2.agent.llm_agent import LLMSoloAgent + + assert isinstance(agent_instance, LLMSoloAgent), ( + "Dummy user can only be used with solo agent" + ) + try: + user_tools = environment.get_user_tools() + except Exception: + user_tools = None + user_instance = user_constructor( + tools=user_tools, + instructions=str(task.user_scenario), + llm=llm_user, + llm_args=llm_args_user, + ) + orch_kwargs = { + "domain": domain, + "agent": agent_instance, + "user": user_instance, + "environment": environment, + "task": task, + "max_steps": max_steps, + "max_errors": max_errors, + "seed": seed, + "solo_mode": solo_mode, + } + if "validate_communication" in inspect.signature(Orchestrator.__init__).parameters: + 
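# Older tau2 versions lack this Orchestrator kwarg, so it is passed only when the signature supports it. + 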
orch_kwargs["validate_communication"] = enforce_communication_protocol + simulation = Orchestrator(**orch_kwargs).run() + simulation.reward_info = evaluate_simulation_with_environment( + simulation=simulation, + task=task, + environment=environment, + evaluation_type=evaluation_type, + solo_mode=solo_mode, + ) + if save_simulation: + _save_simulation(domain, task, agent, simulation) + logger.info( + "FINISHED SIMULATION: domain={} task={} agent={} user={} reward={}", + domain, + task.id, + agent_instance.__class__.__name__, + user_instance.__class__.__name__, + simulation.reward_info.reward, + ) + return simulation, environment + finally: + if env_server: + logger.info("Stopping Environment Socket Server...") + env_server.stop() + + +def run_task_with_socket_server( + domain: str, + task: Task, + agent: str, + user: str, + llm_agent: str | None = None, + llm_args_agent: dict | None = None, + llm_user: str | None = None, + llm_args_user: dict | None = None, + max_steps: int = 100, + max_errors: int = 10, + evaluation_type: EvaluationType = EvaluationType.ALL, + seed: int | None = None, + enforce_communication_protocol: bool = False, + socket_port: int | None = None, +) -> tuple[SimulationRun, Environment]: + return _run_task_impl( + domain, + task, + agent, + user, + llm_agent, + llm_args_agent, + llm_user, + llm_args_user, + max_steps, + max_errors, + evaluation_type, + seed, + enforce_communication_protocol, + use_socket_server=True, + initialize_environment=True, + save_simulation=True, + socket_port=socket_port, + ) + + +def _build_cli_parser(description: str, max_steps: int): + import argparse + + parser = argparse.ArgumentParser(description=description) + parser.add_argument("--domain", type=str, default="airline", choices=["airline", "retail"]) + parser.add_argument("--task-id", type=str) + parser.add_argument("--max-steps", type=int, default=max_steps) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument( + "--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"] + ) + for name, default in ( + ("--agent", "openclaw_agent"), + ("--user", "user_simulator"), + ("--llm-agent", "anthropic/claude-3-5-sonnet-20241022"), + ("--llm-user", "openai/gpt-4.1-2025-04-14"), + ): + parser.add_argument(name, type=str, default=default) + return parser + + +def _configure_logging(log_level: str) -> None: + logger.remove() + logger.add(sys.stderr, level=log_level) + + +def _load_selected_task(domain: str, task_id: str | None) -> Task | None: + from tau2.run import load_tasks + + tasks = load_tasks(task_set_name=domain) + if not tasks: + logger.error("No tasks found for domain: {}", domain) + return None + task = next((item for item in tasks if item.id == task_id), tasks[0]) if task_id else tasks[0] + if task_id and task.id != task_id: + logger.error("Task {} not found", task_id) + return None + return task + + +def _log_simulation_summary( + simulation: SimulationRun, *, environment: Environment | None = None, debug: bool = False +) -> None: + logger.info("Task ID: {}", simulation.task_id) + logger.info("Steps taken: {}", len(simulation.messages)) + logger.info("Termination reason: {}", simulation.termination_reason) + logger.info("Duration: {:.2f}s", simulation.duration) + if simulation.reward_info: + logger.info("Reward: {}", simulation.reward_info.reward) + if simulation.reward_info.db_check: + logger.info("DB Match: {}", simulation.reward_info.db_check.db_match) + if simulation.reward_info.action_checks: + logger.info( + "Action Checks: {}/{} 
matched", + sum(check.action_match for check in simulation.reward_info.action_checks), + len(simulation.reward_info.action_checks), + ) + if simulation.reward_info.env_assertions: + logger.info( + "Environment Assertions: {}/{} met", + sum(check.met for check in simulation.reward_info.env_assertions), + len(simulation.reward_info.env_assertions), + ) + if simulation.agent_cost: + logger.info("Agent Cost: ${:.4f}", simulation.agent_cost) + if simulation.user_cost: + logger.info("User Cost: ${:.4f}", simulation.user_cost) + if environment is not None: + logger.info("Environment DB Hash: {}", environment.get_db_hash()) + if debug: + for idx, turn in enumerate(simulation.messages, 1): + logger.debug("Turn {}: role={} content={}", idx, turn.role, (turn.content or "")[:200]) + + +def main() -> int: + args = _build_cli_parser( + "Run TAU² task with OpenClaw agent using Socket Server", 100 + ).parse_args() + _configure_logging(args.log_level) + logger.info("Loading tasks for domain: {}", args.domain) + task = _load_selected_task(args.domain, args.task_id) + if task is None: + return 1 + logger.info("Selected task: {}", task.id) + try: + simulation, environment = run_task_with_socket_server( + args.domain, + task, + args.agent, + args.user, + args.llm_agent, + None, + args.llm_user, + None, + args.max_steps, + 10, + EvaluationType.ALL, + args.seed, + False, + ) + except Exception: + logger.exception("Error running task") + return 1 + _log_simulation_summary(simulation, environment=environment, debug=args.log_level == "DEBUG") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/experimental/inference_service/batchmode/tau2/tau2_env/__init__.py b/examples/experimental/inference_service/batchmode/tau2/tau2_env/__init__.py new file mode 100644 index 0000000000..46bd799ac0 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/tau2_env/__init__.py @@ -0,0 +1,9 @@ +from .environment_socket import EnvironmentSocketServer, create_openclaw_tool_script +from .evaluator import OpenClawEnvironmentEvaluator, evaluate_simulation_with_environment + +__all__ = [ + "evaluate_simulation_with_environment", + "OpenClawEnvironmentEvaluator", + "EnvironmentSocketServer", + "create_openclaw_tool_script", +] diff --git a/examples/experimental/inference_service/batchmode/tau2/tau2_env/environment_socket.py b/examples/experimental/inference_service/batchmode/tau2/tau2_env/environment_socket.py new file mode 100644 index 0000000000..ebc023de10 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/tau2_env/environment_socket.py @@ -0,0 +1,165 @@ +import json +import socket +import threading +from typing import Any + +from loguru import logger + +from tau2.data_model.tasks import Task +from tau2.environment.environment import Environment + + +def _recv_json(sock: socket.socket) -> dict[str, Any] | None: + chunks: list[bytes] = [] + while True: + data = sock.recv(4096) + if not data: + break + chunks.append(data) + if b"\n" in data: + break + return json.loads(b"".join(chunks).split(b"\n", 1)[0].decode("utf-8")) if chunks else None + + +def _serialize_result(result: Any) -> Any: + if result is None or isinstance(result, (str, int, float, bool)): + return result + if hasattr(result, "model_dump"): + return result.model_dump() + if hasattr(result, "dict"): + return result.dict() + if isinstance(result, list): + return [_serialize_result(item) for item in result] + return ( + {key: _serialize_result(value) for key, value in result.items()} + if 
isinstance(result, dict) + else result + ) + + +class EnvironmentSocketServer: + def __init__( + self, environment: Environment, task: Task = None, host: str = "127.0.0.1", port: int = 0 + ): + self.environment = environment + self.task = task + self.host = host + self.port = port + self.server_socket: socket.socket | None = None + self.running = False + self.thread: threading.Thread | None = None + + @staticmethod + def _send_json(sock: socket.socket, payload: dict[str, Any]) -> None: + sock.sendall((json.dumps(payload) + "\n").encode("utf-8")) + + def start(self) -> int: + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server_socket.bind((self.host, self.port)) + self.server_socket.listen(5) + self.port = self.server_socket.getsockname()[1] + self.running = True + self.thread = threading.Thread(target=self._serve, daemon=True) + self.thread.start() + logger.info("Environment Server listening on {}:{}", self.host, self.port) + return self.port + + def _serve(self) -> None: + while self.running and self.server_socket: + try: + client_socket, client_addr = self.server_socket.accept() + logger.debug("Client connected: {}", client_addr) + threading.Thread( + target=self._handle_client, args=(client_socket,), daemon=True + ).start() + except Exception as exc: + if self.running: + logger.error("Server error: {}", exc) + + def _handle_client(self, client_socket: socket.socket) -> None: + try: + while request := _recv_json(client_socket): + action = request.get("action") + if action == "call_tool": + response = self._call_tool( + request["tool_name"], + request.get("requestor", "assistant"), + request.get("arguments", {}), + ) + elif action == "get_state": + response = self._get_state() + else: + response = {"success": False, "error": f"Unknown action: {action}"} + self._send_json(client_socket, response) + except Exception as exc: + logger.error("Client handler error: {}", exc) + finally: + client_socket.close() + + def _call_tool( + self, tool_name: str, requestor: str, arguments: dict[str, Any] + ) -> dict[str, Any]: + try: + result = self.environment.make_tool_call( + tool_name=tool_name, requestor=requestor, **arguments + ) + self.environment.sync_tools() + return {"success": True, "result": _serialize_result(result)} + except Exception as exc: + logger.error("Tool call error: {}", exc) + return {"success": False, "error": str(exc)} + + def _get_state(self) -> dict[str, Any]: + try: + state = { + "success": True, + "domain": self.environment.domain_name, + "db_hash": self.environment.get_db_hash(), + } + if self.task: + state["task_id"] = self.task.id + return state + except Exception as exc: + return {"success": False, "error": str(exc)} + + def stop(self) -> None: + self.running = False + if self.server_socket: + self.server_socket.close() + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=5) + logger.info("Environment Server stopped") + + def get_client_config(self) -> dict[str, Any]: + return { + "host": self.host, + "port": self.port, + **({"task_id": self.task.id} if self.task else {}), + } + + +def create_openclaw_tool_script(tool_name: str, server_config: dict) -> str: + host = json.dumps(server_config["host"]) + tool = json.dumps(tool_name) + return f"""#!/usr/bin/env python +import json, socket, sys +def _recv_json(sock): + chunks=[] + while True: + data=sock.recv(4096) + if not data: break + chunks.append(data) + if b"\\n" in data: break + return 
json.loads(b"".join(chunks).split(b"\\n",1)[0].decode("utf-8")) if chunks else None +try: + args=json.loads(sys.argv[1]) if len(sys.argv)>1 else {{}} + with socket.socket(socket.AF_INET,socket.SOCK_STREAM) as sock: + sock.connect(({host}, {server_config["port"]})) + sock.sendall((json.dumps({{"action":"call_tool","tool_name":{tool},"requestor":"assistant","arguments":args}})+"\\n").encode("utf-8")) + response=_recv_json(sock) or {{"success": False, "error": "Empty response"}} + if not response.get("success"): raise Exception(response.get("error", "Tool call failed")) + print(json.dumps({{"success": True, "result": response["result"]}})) +except Exception as exc: + print(json.dumps({{"success": False, "error": str(exc)}})) +""" diff --git a/examples/experimental/inference_service/batchmode/tau2/tau2_env/evaluator.py b/examples/experimental/inference_service/batchmode/tau2/tau2_env/evaluator.py new file mode 100644 index 0000000000..cb05dc654e --- /dev/null +++ b/examples/experimental/inference_service/batchmode/tau2/tau2_env/evaluator.py @@ -0,0 +1,172 @@ +from loguru import logger + +from tau2.data_model.simulation import ( + DBCheck, + EnvAssertionCheck, + RewardInfo, + SimulationRun, + TerminationReason, +) +from tau2.data_model.tasks import RewardType, Task +from tau2.environment.environment import Environment +from tau2.evaluator.evaluator import EvaluationType +from tau2.evaluator.evaluator_action import ActionEvaluator +from tau2.evaluator.evaluator_communicate import CommunicateEvaluator +from tau2.evaluator.evaluator_nl_assertions import NLAssertionsEvaluator +from tau2.registry import registry + +_STOP_REASONS = {TerminationReason.AGENT_STOP, TerminationReason.USER_STOP} + + +def _note_reward(note: str) -> RewardInfo: + return RewardInfo(reward=1.0, reward_basis=None, info={"note": note}) + + +def _initial_state_kwargs(task: Task) -> dict: + state = task.initial_state + return { + "initialization_data": getattr(state, "initialization_data", None), + "initialization_actions": getattr(state, "initialization_actions", None), + "message_history": getattr(state, "message_history", None) or [], + } + + +def _merge_reward_breakdowns(*infos: RewardInfo | None) -> dict: + return { + key: value + for info in infos + if info and info.reward_breakdown + for key, value in info.reward_breakdown.items() + } + + +class OpenClawEnvironmentEvaluator: + @classmethod + def calculate_reward( + cls, environment: Environment, task: Task, full_trajectory: list, solo_mode: bool = False + ) -> RewardInfo: + _ = full_trajectory, solo_mode + criteria = task.evaluation_criteria + if criteria is None: + return _note_reward("No evaluation criteria") + if not criteria.actions and not criteria.env_assertions: + return RewardInfo( + reward=1.0, + db_check=DBCheck(db_match=True, db_reward=1.0), + reward_basis=criteria.reward_basis, + info={"note": "No expected actions or env assertions"}, + ) + gold_environment = registry.get_env_constructor(environment.domain_name)() + gold_environment.set_state(**_initial_state_kwargs(task)) + for action in criteria.actions or []: + try: + gold_environment.make_tool_call( + tool_name=action.name, requestor=action.requestor, **action.arguments + ) + except Exception as exc: + logger.warning( + "Error in golden action {}({}): {}", action.name, action.arguments, exc + ) + db_reward = float( + gold_environment.get_db_hash() == environment.get_db_hash() + and gold_environment.get_user_db_hash() == environment.get_user_db_hash() + ) + env_checks: list[EnvAssertionCheck] = [] + 
env_reward = 1.0 + for assertion in criteria.env_assertions or []: + met = environment.run_env_assertion(assertion, raise_assertion_error=False) + check = EnvAssertionCheck(env_assertion=assertion, met=met, reward=float(met)) + env_checks.append(check) + env_reward *= check.reward + reward = 1.0 + reward_breakdown = {} + if RewardType.DB in criteria.reward_basis: + reward_breakdown[RewardType.DB] = db_reward + reward *= db_reward + if RewardType.ENV_ASSERTION in criteria.reward_basis: + reward_breakdown[RewardType.ENV_ASSERTION] = env_reward + reward *= env_reward + return RewardInfo( + reward=reward, + db_check=DBCheck(db_match=bool(db_reward), db_reward=db_reward), + env_assertions=env_checks, + reward_basis=criteria.reward_basis, + reward_breakdown=reward_breakdown, + ) + + +def evaluate_simulation_with_environment( + simulation: SimulationRun, + task: Task, + environment: Environment, + evaluation_type: EvaluationType, + solo_mode: bool = False, +) -> RewardInfo: + if simulation.termination_reason not in _STOP_REASONS: + return RewardInfo( + reward=0.0, + reward_basis=None, + info={ + "note": f"Simulation terminated prematurely. Termination reason: {simulation.termination_reason.value}" + }, + ) + criteria = task.evaluation_criteria + if criteria is None: + return _note_reward("No evaluation criteria") + if evaluation_type == EvaluationType.ENV: + return OpenClawEnvironmentEvaluator.calculate_reward( + environment, task, simulation.messages, solo_mode + ) + if evaluation_type == EvaluationType.NL_ASSERTIONS: + return NLAssertionsEvaluator.calculate_reward( + task=task, full_trajectory=simulation.messages + ) + if evaluation_type == EvaluationType.COMMUNICATE: + return CommunicateEvaluator.calculate_reward(task=task, full_trajectory=simulation.messages) + if evaluation_type == EvaluationType.ACTION: + return ActionEvaluator.calculate_reward(task=task, full_trajectory=simulation.messages) + if evaluation_type not in {EvaluationType.ALL, EvaluationType.ALL_WITH_NL_ASSERTIONS}: + raise ValueError(f"Unknown evaluation type: {evaluation_type}") + env_info = OpenClawEnvironmentEvaluator.calculate_reward( + environment, task, simulation.messages, solo_mode + ) + action_info = ActionEvaluator.calculate_reward(task=task, full_trajectory=simulation.messages) + communicate_info = CommunicateEvaluator.calculate_reward( + task=task, full_trajectory=simulation.messages + ) + nl_info = ( + NLAssertionsEvaluator.calculate_reward(task=task, full_trajectory=simulation.messages) + if evaluation_type == EvaluationType.ALL_WITH_NL_ASSERTIONS + else None + ) + reward = 1.0 + reward_basis = set(criteria.reward_basis) + for basis_group, info in ( + ({RewardType.DB, RewardType.ENV_ASSERTION}, env_info), + ({RewardType.ACTION}, action_info), + ({RewardType.COMMUNICATE}, communicate_info), + ): + if reward_basis & basis_group: + reward *= info.reward + if RewardType.NL_ASSERTION in reward_basis: + if nl_info is None: + raise ValueError( + "NL assertions are part of the reward basis, but they are not being evaluated." 
+ ) + reward *= nl_info.reward + return RewardInfo( + reward=reward, + db_check=env_info.db_check, + env_assertions=env_info.env_assertions, + action_checks=action_info.action_checks, + nl_assertions=nl_info.nl_assertions if nl_info else None, + communicate_checks=communicate_info.communicate_checks, + reward_basis=criteria.reward_basis, + reward_breakdown=_merge_reward_breakdowns(env_info, action_info, communicate_info, nl_info), + info={ + "env": env_info.info, + "nl": nl_info.info if nl_info else None, + "communicate": communicate_info.info, + "action": action_info.info, + }, + ) diff --git a/examples/experimental/inference_service/batchmode/worker.py b/examples/experimental/inference_service/batchmode/worker.py new file mode 100755 index 0000000000..4bb30639d8 --- /dev/null +++ b/examples/experimental/inference_service/batchmode/worker.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Single TAU²-Bench worker — runs one task with OpenClaw agent via socket server. + +Usage: + python worker.py \ + --domain retail \ + --task-index 0 \ + --agent-endpoint http://127.0.0.1:30000/v1 \ + --user-endpoint http://<user-node>:30001/v1 \ + --model Qwen3-235B-A22B-Instruct-2507 \ + --output-dir /tmp/results + +Output: writes <output-dir>/task_<task_id>.json with trajectory + reward. +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Run single TAU²-Bench task with OpenClaw" + ) + parser.add_argument("--domain", type=str, default="retail") + parser.add_argument("--task-index", type=int, required=True) + parser.add_argument("--agent-endpoint", type=str, required=True) + parser.add_argument("--user-endpoint", type=str, required=True) + parser.add_argument("--model", type=str, default="Qwen3-235B-A22B-Instruct-2507") + parser.add_argument("--user-llm", type=str, default="") + parser.add_argument("--max-steps", type=int, default=200) + parser.add_argument("--max-errors", type=int, default=10) + parser.add_argument("--output-dir", type=str, default="results") + parser.add_argument("--worker-id", type=int, default=0) + parser.add_argument("--seed", type=int, default=None) + parser.add_argument( + "--user-llm-args", + type=str, + default="", + help="JSON string of extra llm_args for user simulator, e.g. 
'{\"temperature\":0.0}'", + ) + args = parser.parse_args() + + os.environ["OPENCLAW_CLI_COMMAND"] = "openclaw" + os.environ["OPENCLAW_API_BASE"] = args.agent_endpoint + os.environ["OPENCLAW_API_KEY"] = os.environ.get("OPENCLAW_API_KEY", "dummy") + os.environ["OPENCLAW_MODEL"] = args.model + os.environ["OPENAI_API_BASE"] = args.user_endpoint + os.environ["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "dummy") + + real_stdout = sys.stdout + sys.stdout = sys.stderr + + import litellm + + litellm.drop_params = True + litellm.suppress_debug_info = True + + src_path = str(Path(__file__).resolve().parent.parent / "src") + if src_path not in sys.path: + sys.path.insert(0, src_path) + + tau2_src = os.environ.get("TAU2_SRC_PATH", "${TAU2_DIR}/src") + if os.path.isdir(tau2_src) and tau2_src not in sys.path: + sys.path.insert(0, tau2_src) + + import subprocess + + from openclaw_tau2 import run_task_with_socket_server + from tau2.registry import registry + from tau2.run import load_tasks + + openclaw_version = "unknown" + try: + r = subprocess.run( + ["openclaw", "--version"], capture_output=True, text=True, timeout=5 + ) + openclaw_version = ( + r.stdout.strip() if r.returncode == 0 else f"exit={r.returncode}" + ) + except Exception as e: + openclaw_version = f"not found: {e}" + + agent_constructor = registry.get_agent_constructor("openclaw_agent") + agent_class_name = f"{agent_constructor.__module__}.{agent_constructor.__name__}" + + print("=" * 72, file=sys.stderr) + print(" AGENT VERIFICATION", file=sys.stderr) + print(" Agent type: openclaw_agent", file=sys.stderr) + print(f" Agent class: {agent_class_name}", file=sys.stderr) + print(f" OpenClaw CLI: {openclaw_version}", file=sys.stderr) + print(f" Agent endpoint: {args.agent_endpoint}", file=sys.stderr) + print(f" Agent model: {args.model}", file=sys.stderr) + print(f" HOME: {os.environ.get('HOME', '?')}", file=sys.stderr) + print(f" Worker ID: {args.worker_id}", file=sys.stderr) + print("=" * 72, file=sys.stderr) + + tasks = load_tasks(task_set_name=args.domain) + if args.task_index >= len(tasks): + print( + json.dumps({"error": f"task_index {args.task_index} >= {len(tasks)} tasks"}) + ) + sys.exit(1) + + task = tasks[args.task_index] + user_llm = args.user_llm or f"openai/{args.model}" + + llm_args_user = None + if args.user_llm_args: + try: + llm_args_user = json.loads(args.user_llm_args) + except json.JSONDecodeError as e: + print(f"WARNING: Failed to parse --user-llm-args: {e}", file=sys.stderr) + + tag = f"[worker-{args.worker_id}][task-{task.id}]" + print(f"{tag} Starting: domain={args.domain} task={task.id}", file=sys.stderr) + + start = time.time() + try: + simulation, environment = run_task_with_socket_server( + domain=args.domain, + task=task, + agent="openclaw_agent", + user="user_simulator", + llm_user=user_llm, + llm_args_user=llm_args_user, + max_steps=args.max_steps, + max_errors=args.max_errors, + seed=args.seed, + socket_port=None, + ) + + duration = time.time() - start + reward = simulation.reward_info.reward if simulation.reward_info else 0.0 + + messages = [] + for msg in simulation.messages: + messages.append( + { + "role": msg.role + if hasattr(msg, "role") + else str(type(msg).__name__), + "content": msg.content if msg.content else "", + } + ) + + reward_detail = {} + if simulation.reward_info: + reward_detail["reward"] = simulation.reward_info.reward + if simulation.reward_info.db_check: + reward_detail["db_match"] = simulation.reward_info.db_check.db_match + if simulation.reward_info.action_checks: + matched = sum( 
+ 1 for ac in simulation.reward_info.action_checks if ac.action_match + ) + reward_detail["action_checks"] = ( + f"{matched}/{len(simulation.reward_info.action_checks)}" + ) + + result = { + "task_id": str(task.id), + "task_index": args.task_index, + "domain": args.domain, + "worker_id": args.worker_id, + "model": args.model, + "reward": reward, + "reward_detail": reward_detail, + "num_steps": len(simulation.messages), + "termination_reason": simulation.termination_reason.value, + "duration_seconds": round(duration, 1), + "messages": messages, + } + + status = "✓" if reward > 0 else "✗" + print( + f"{tag} {status} reward={reward:.2f} steps={len(simulation.messages)} dur={duration:.1f}s term={simulation.termination_reason.value}", + file=sys.stderr, + ) + + except Exception as e: + import traceback + + duration = time.time() - start + error_type = type(e).__name__ + result = { + "task_id": str(task.id), + "task_index": args.task_index, + "domain": args.domain, + "worker_id": args.worker_id, + "model": args.model, + "reward": 0.0, + "num_steps": 0, + "termination_reason": "error", + "error_type": error_type, + "duration_seconds": round(duration, 1), + "error": traceback.format_exc(), + } + print(f"{tag} ✗ ERROR ({error_type}) dur={duration:.1f}s: {e}", file=sys.stderr) + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + output_file = output_dir / f"task_{task.id}.json" + output_file.write_text(json.dumps(result, indent=2, ensure_ascii=False)) + + real_stdout.write( + json.dumps( + { + "task_id": str(task.id), + "reward": result["reward"], + "duration": result["duration_seconds"], + } + ) + + "\n" + ) + real_stdout.flush() + + +if __name__ == "__main__": + main() diff --git a/tests/test_openai_client_normalize.py b/tests/test_openai_client_normalize.py new file mode 100644 index 0000000000..5848854048 --- /dev/null +++ b/tests/test_openai_client_normalize.py @@ -0,0 +1,207 @@ +"""Tests for _normalize_messages_for_chat_template in ArealOpenAI client. + +Covers two normalizations that align ArealOpenAI with SGLang's native +/v1/chat/completions preprocessing: +1. Content flattening: list content → string for templates that expect string +2. 
tool_calls arguments: JSON string → dict for Jinja2 tojson sort_keys alignment +""" + +import json +from typing import Any + + +def _normalize_messages_for_chat_template(messages: list[dict[str, Any]]) -> None: + """Copied from areal.experimental.openai.client to avoid heavy import chain.""" + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if content is not None and not isinstance(content, str): + if not isinstance(content, list): + content = list(content) + parts = [] + for part in content: + if not isinstance(part, dict): + part = ( + dict(part) + if hasattr(part, "items") + else {"type": "text", "text": str(part)} + ) + if "text" in part and "type" not in part: + part["type"] = "text" + parts.append(part) + if len(parts) == 1 and parts[0].get("type") == "text": + msg["content"] = parts[0]["text"] + else: + msg["content"] = parts + if msg.get("role") == "assistant" and isinstance(msg.get("tool_calls"), list): + for tool_call in msg["tool_calls"]: + func = tool_call.get("function", tool_call) + if isinstance(func.get("arguments"), str): + try: + func["arguments"] = json.loads(func["arguments"]) + except (json.JSONDecodeError, TypeError): + pass + + +class TestContentFlattening: + def test_single_text_part_flattened_to_string(self): + msgs = [{"role": "user", "content": [{"type": "text", "text": "hello"}]}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] == "hello" + + def test_string_content_unchanged(self): + msgs = [{"role": "user", "content": "hello"}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] == "hello" + + def test_none_content_unchanged(self): + msgs = [{"role": "assistant", "content": None}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] is None + + def test_multi_part_content_kept_as_list(self): + parts = [ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": "http://example.com/img.png"}}, + ] + msgs = [{"role": "user", "content": parts}] + _normalize_messages_for_chat_template(msgs) + assert isinstance(msgs[0]["content"], list) + assert len(msgs[0]["content"]) == 2 + + def test_part_missing_type_field_gets_text_type(self): + msgs = [{"role": "user", "content": [{"text": "hello"}]}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] == "hello" + + def test_iterator_content_materialized(self): + def content_iter(): + yield {"type": "text", "text": "from iterator"} + + msgs = [{"role": "user", "content": content_iter()}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] == "from iterator" + + def test_non_dict_part_converted(self): + msgs = [{"role": "user", "content": [42]}] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["content"] == "42" + + +class TestToolCallsArgumentsParsing: + def test_string_arguments_parsed_to_dict(self): + msgs = [ + { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "search", + "arguments": '{"query": "flights", "limit": 5}', + } + } + ], + } + ] + _normalize_messages_for_chat_template(msgs) + args = msgs[0]["tool_calls"][0]["function"]["arguments"] + assert isinstance(args, dict) + assert args == {"query": "flights", "limit": 5} + + def test_dict_arguments_unchanged(self): + msgs = [ + { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "search", + "arguments": {"query": "flights"}, + } + } + ], + } + ] + _normalize_messages_for_chat_template(msgs) + assert 
msgs[0]["tool_calls"][0]["function"]["arguments"] == {"query": "flights"} + + def test_invalid_json_arguments_kept_as_string(self): + msgs = [ + { + "role": "assistant", + "tool_calls": [ + {"function": {"name": "fn", "arguments": "not valid json"}} + ], + } + ] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["tool_calls"][0]["function"]["arguments"] == "not valid json" + + def test_user_message_tool_calls_ignored(self): + msgs = [ + { + "role": "user", + "tool_calls": [{"function": {"name": "fn", "arguments": '{"a": 1}'}}], + } + ] + _normalize_messages_for_chat_template(msgs) + assert msgs[0]["tool_calls"][0]["function"]["arguments"] == '{"a": 1}' + + def test_jinja2_tojson_key_ordering_aligned(self): + """After parsing, Jinja2 tojson(sort_keys=True) produces alphabetically + sorted keys, matching SGLang's native /v1/chat/completions path.""" + import jinja2 + + args_str = '{"origin": "LAX", "destination": "SFO", "date": "2025-01-15"}' + msgs = [ + { + "role": "assistant", + "tool_calls": [{"function": {"name": "search", "arguments": args_str}}], + } + ] + _normalize_messages_for_chat_template(msgs) + args_dict = msgs[0]["tool_calls"][0]["function"]["arguments"] + + env = jinja2.Environment() + template = env.from_string("{{ val | tojson }}") + rendered = template.render(val=args_dict) + + assert '"date"' in rendered + assert rendered.index('"date"') < rendered.index('"destination"') + assert rendered.index('"destination"') < rendered.index('"origin"') + + +class TestMixedMessages: + def test_full_conversation_normalized(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": [{"type": "text", "text": "Book a flight"}]}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "function": { + "name": "search_flights", + "arguments": '{"origin": "LAX", "dest": "SFO"}', + } + } + ], + }, + {"role": "tool", "content": '{"flights": []}'}, + ] + _normalize_messages_for_chat_template(msgs) + + assert msgs[0]["content"] == "You are helpful." + assert msgs[1]["content"] == "Book a flight" + assert isinstance(msgs[2]["tool_calls"][0]["function"]["arguments"], dict) + assert msgs[3]["content"] == '{"flights": []}' + + def test_empty_messages_no_error(self): + _normalize_messages_for_chat_template([]) + + def test_non_dict_message_skipped(self): + msgs = ["not a dict", {"role": "user", "content": "hello"}] + _normalize_messages_for_chat_template(msgs) + assert msgs[1]["content"] == "hello"