Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/envs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ Executes Python code in a sandboxed environment. Demonstrates:

See: [`coding_env/README.md`](coding_env/README.md)

### Connect4 Environment
Location: `src/envs/connect4_env/`

Wraps the `gym-connect4` implementation to provide a turnkey board-game benchmark that follows the OpenEnv API, including typed models, an HTTP client, and a Docker image.

See: [`connect4_env/README.md`](connect4_env/README.md)

## Best Practices

### 1. Type Safety
Expand Down
38 changes: 38 additions & 0 deletions src/envs/connect4_env/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Connect4 Environment

This environment wraps the [`gym-connect4`](https://github.com/Danielhp95/gym-connect4) implementation inside OpenEnv. It exposes a turn-based 6x7 Connect Four board where the agent plays as player `+1` against the built-in opponent logic supplied by the Gym environment.

## Action, Observation, State

| Type | Fields | Description |
| --- | --- | --- |
| `Connect4Action` | `column: int` | 0-based column where the agent drops a disc. |
| `Connect4Observation` | `board: list[list[int]]`<br>`legal_actions: list[int]`<br>`current_player: int`<br>`last_move: Optional[int]`<br>`info: dict` | Board uses `1` for the agent, `-1` for the opponent, `0` for empty. Legal actions are the playable columns. When `done=True`, `legal_actions` is empty. Any metadata from Gym is forwarded through `info`. |
| `Connect4State` | `episode_id: str`<br>`step_count: int`<br>`rows: int`<br>`cols: int` | Mirrors the generic OpenEnv state and records the board geometry. |

Rewards from Gym can be scalars or a 2-element vector. The server always scalarizes them into an agent-centric `float` (`r_agent - r_opponent` when two values are supplied).

## Running the server

```bash
uvicorn envs.connect4_env.server.app:app --host 0.0.0.0 --port 8000
```

Set `GYM_CONNECT4_ID` if you need a custom Gym registration ID (default `Connect4-v0`).

## Client usage

```python
from envs.connect4_env import Connect4Env, Connect4Action

client = Connect4Env(base_url="http://localhost:8000")

result = client.reset()
print(result.observation.board)

while not result.done:
    action = Connect4Action(column=result.observation.legal_actions[0])
    result = client.step(action)

print("Episode reward:", result.reward)
```
13 changes: 13 additions & 0 deletions src/envs/connect4_env/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Connect4 OpenEnv package exports."""

from .client import Connect4Env
from .models import Connect4Action, Connect4Observation, Connect4State
from .server.connect4_environment import Connect4Environment

# Names re-exported as the public surface of the connect4_env package.
__all__ = (
    "Connect4Action",
    "Connect4Observation",
    "Connect4State",
    "Connect4Env",
    "Connect4Environment",
)
42 changes: 42 additions & 0 deletions src/envs/connect4_env/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""HTTP client for the Connect4 OpenEnv environment."""

from __future__ import annotations

from typing import Any, Dict

from core.client_types import StepResult
from core.http_env_client import HTTPEnvClient

from .models import Connect4Action, Connect4Observation, Connect4State


class Connect4Env(HTTPEnvClient[Connect4Action, Connect4Observation]):
    """HTTP client that agents use to talk to a running Connect4 server.

    The three hooks below translate between the typed Connect4 models and the
    plain dictionaries exchanged with the server.
    """

    def _step_payload(self, action: Connect4Action) -> Dict[str, Any]:
        """Serialize *action* into the dictionary sent with a step request."""
        body: Dict[str, Any] = {}
        body["column"] = action.column
        body["metadata"] = action.metadata
        return body

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[Connect4Observation]:
        """Turn a server response dictionary into a typed ``StepResult``.

        Keys that are absent fall back to neutral defaults: an empty board,
        no legal actions, player ``1``, ``done=False`` and ``reward=None``.
        """
        is_done = payload.get("done", False)
        reward = payload.get("reward")
        raw_obs = payload.get("observation", {})

        # The observation carries done/reward too, mirroring the StepResult.
        parsed = Connect4Observation(
            board=raw_obs.get("board", []),
            legal_actions=raw_obs.get("legal_actions", []),
            current_player=raw_obs.get("current_player", 1),
            last_move=raw_obs.get("last_move"),
            info=raw_obs.get("info", {}),
            done=is_done,
            reward=reward,
        )
        return StepResult(observation=parsed, reward=reward, done=is_done)

    def _parse_state(self, payload: Dict[str, Any]) -> Connect4State:
        """Turn a state response dictionary into a ``Connect4State``.

        Board geometry defaults to 6 rows by 7 columns when not reported.
        """
        episode_id = payload.get("episode_id")
        step_count = payload.get("step_count", 0)
        n_rows = payload.get("rows", 6)
        n_cols = payload.get("cols", 7)
        return Connect4State(
            episode_id=episode_id,
            step_count=step_count,
            rows=n_rows,
            cols=n_cols,
        )
34 changes: 34 additions & 0 deletions src/envs/connect4_env/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Data models for the Connect4 OpenEnv environment."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from core.env_server.types import Action, Observation, State


@dataclass(kw_only=True)
class Connect4Action(Action):
    """Action choosing the 0-indexed column that receives the agent's disc."""

    # Keyword-only because the dataclass is declared with kw_only=True.
    column: int


@dataclass(kw_only=True)
class Connect4Observation(Observation):
    """Snapshot of the game handed back by every reset and step call."""

    # 6x7 grid with 1 (agent), -1 (opponent), 0 (empty)
    board: List[List[int]]
    # Column indices that may be played next.
    legal_actions: List[int]
    # Side to move, expressed with the same +1 / -1 convention as the board.
    current_player: int
    # Optional fields default to "nothing reported".
    last_move: Optional[int] = None
    info: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Connect4State(State):
    """Episode bookkeeping extended with the dimensions of the board."""

    # Defaults describe the standard Connect Four grid.
    rows: int = 6
    cols: int = 7
23 changes: 23 additions & 0 deletions src/envs/connect4_env/server/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Build on top of the shared OpenEnv base image.
# BASE_IMAGE can be overridden with --build-arg to target a different base.
ARG BASE_IMAGE=openenv-base:latest
FROM ${BASE_IMAGE}

# Install git for pip VCS installs (needed by the git+https requirement below);
# the apt package lists are removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends git && \
rm -rf /var/lib/apt/lists/*

# Install environment-specific dependencies.
# gym is pinned to 0.25.2 and numpy is held below 2.0.
# NOTE(review): gym-connect4 is installed from the default branch without a
# commit or tag pin, so image builds are not reproducible - consider pinning.
RUN pip install --no-cache-dir "gym==0.25.2" "numpy<2.0" \
git+https://github.com/Danielhp95/gym-connect4

# Copy the framework core plus this environment.
# Paths are relative to the build context, which presumably is the repository
# root - confirm against the build command.
COPY src/core/ /app/src/core/
COPY src/envs/connect4_env/ /app/src/envs/connect4_env/
COPY src/envs/connect4_env/README.md /app/README.md

# Simple health check - the web UI reuses /health
# NOTE(review): assumes curl is available in the base image - confirm.
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1

# Run the FastAPI server on port 8000, listening on all interfaces.
# NOTE(review): the module path implies /app/src is on the Python import path;
# presumably the base image sets that up - confirm.
CMD ["uvicorn", "envs.connect4_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
5 changes: 5 additions & 0 deletions src/envs/connect4_env/server/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Server package for the Connect4 OpenEnv environment."""

from .connect4_environment import Connect4Environment

__all__ = ("Connect4Environment",)
14 changes: 14 additions & 0 deletions src/envs/connect4_env/server/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""FastAPI entrypoint for the Connect4 OpenEnv server."""

from core.env_server.http_server import create_app

from ..models import Connect4Action, Connect4Observation
from .connect4_environment import Connect4Environment

# Module-level objects: the server is launched as
# "envs.connect4_env.server.app:app", so `app` must be importable by name.
env = Connect4Environment()
app = create_app(
    env,
    Connect4Action,
    Connect4Observation,
    env_name="connect4_env",
)

if __name__ == "__main__":
    # Direct execution starts the same app without going through the uvicorn CLI.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
204 changes: 204 additions & 0 deletions src/envs/connect4_env/server/connect4_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Gym-based Connect4 environment wrapped for OpenEnv."""

from __future__ import annotations

import importlib
import os
from typing import Any, Dict, Tuple
from uuid import uuid4

import numpy as np

from core.env_server.interfaces import Environment

from ..models import Connect4Action, Connect4Observation, Connect4State

# Ensure the third-party Gym env registers itself if present.
try: # pragma: no cover - optional dependency is best-effort
importlib.import_module("gym_connect4")
except Exception: # noqa: BLE001
pass

try:
import gym
except ImportError as exc: # pragma: no cover
raise ImportError(
"The Connect4 environment requires gym>=0.25. "
"Install it inside your Docker image or development venv."
) from exc


def _scalarize_reward(reward: Any) -> float:
    """Collapse a Gym reward into a single float.

    Plain scalars are converted directly. Lists, tuples and ndarrays are
    treated as vectors: a two-element vector yields ``first - second`` and
    any other shape is summed.
    """
    if not isinstance(reward, (list, tuple, np.ndarray)):
        return float(reward)

    values = np.asarray(reward, dtype=float)
    combined = values[0] - values[1] if values.shape == (2,) else values.sum()
    return float(combined)


def _normalize_board(obs: Any) -> Tuple[np.ndarray, Dict[str, Any]]:
    """
    Convert arbitrary Connect4 observations into a canonical signed board.

    Supported inputs:

    * ``(obs, info)`` 2-tuples - the second element is returned as ``info``;
    * a plain 2-D numeric grid, returned as integers;
    * ``2 x rows x cols`` one-hot planes (first plane minus second plane);
    * ``rows x cols x 2`` one-hot tensors (channel 0 minus channel 1);
    * a 4-D stack of per-player 3-plane views, of which only the first
      player's view is used;
    * 2-D object arrays whose cells hold 1- or 2-element vectors.

    Returns:
        ``(board, info)`` where ``board`` is an integer ``np.ndarray``.
        Observations that match none of the layouts above yield an all-zero
        6x7 board rather than raising.
    """
    info: Dict[str, Any] = {}
    board = obs
    if isinstance(obs, tuple) and len(obs) == 2:
        board, info = obs

    # Let NumPy infer the dtype. Forcing dtype=object here would make every
    # numeric branch below unreachable and turn 4-D observations into the
    # all-zero fallback board. Object dtype is only needed for ragged input,
    # which recent NumPy versions reject with ValueError.
    try:
        arr = np.asarray(board)
    except ValueError:
        arr = np.asarray(board, dtype=object)

    if arr.dtype != object:
        if arr.ndim == 2:
            return arr.astype(int), info

        if arr.ndim == 3 and arr.shape[0] == 2:
            return (arr[0].astype(int) - arr[1].astype(int)), info

        if arr.ndim == 3 and arr.shape[2] == 2:
            return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info

        if arr.ndim == 4 and arr.shape[0] >= 1 and arr.shape[1] == 3:
            # gym-connect4 returns a list of per-player 3-plane tensors with
            # shape (players, channels=3, width, height). Convert the first
            # player's view (agent perspective) into a signed board matrix.
            player_view = arr[0]  # shape (3, width, height)
            pieces = player_view[1].astype(int) - player_view[2].astype(int)
            # Convert to (rows, cols) with row zero on top.
            return pieces.T, info

    if arr.ndim == 2 and arr.dtype == object:
        h, w = arr.shape
        out = np.zeros((h, w), dtype=int)
        for r in range(h):
            for c in range(w):
                val = np.asarray(arr[r, c], dtype=int).ravel()
                if val.size == 2:
                    out[r, c] = int(val[0] - val[1])
                elif val.size == 1:
                    out[r, c] = int(val[0])
                # Cells of any other size are left at 0.
        return out, info

    # Fallback: best effort for object-dtype arrays that still have a
    # recognisable one-hot layout.
    try:  # pragma: no cover - defensive branch
        if arr.ndim == 3 and arr.shape[0] == 2:
            return (arr[0].astype(int) - arr[1].astype(int)), info
        if arr.ndim == 3 and arr.shape[2] == 2:
            return (arr[:, :, 0].astype(int) - arr[:, :, 1].astype(int)), info
    except (TypeError, ValueError):
        # Cells that cannot be cast to int: fall through to the empty board.
        pass

    return np.zeros((6, 7), dtype=int), info


def _legal_actions(board: np.ndarray) -> list[int]:
    """List the columns whose cell in row 0 is still empty (equal to 0)."""
    first_row = board[0]
    return [col for col, cell in enumerate(first_row) if cell == 0]


def _current_player(info: Dict[str, Any], board: np.ndarray) -> int:
    """Work out which side moves next, as ``1`` or ``-1``.

    A usable ``info["current_player"]`` value (``1`` or ``-1`` after ``int``
    conversion) takes precedence. Otherwise the side is inferred from the
    board: equal counts of ``1`` and ``-1`` cells mean it is player ``1``'s
    turn, any other balance means player ``-1``.
    """
    try:
        reported = int(info.get("current_player", 0))
    except Exception:  # noqa: BLE001
        # Missing, malformed or non-dict info: fall back to counting discs.
        reported = 0
    if reported in (1, -1):
        return reported

    plus_count = int(np.count_nonzero(board == 1))
    minus_count = int(np.count_nonzero(board == -1))
    return -1 if plus_count != minus_count else 1


class Connect4Environment(Environment):
    """Wrap the gym-connect4 environment so it can be served over HTTP.

    The underlying Gym environment is created lazily on first use. Every
    ``reset`` starts a new ``Connect4State`` with a fresh episode id, and
    every ``step`` increments that state's step counter.
    """

    def __init__(self, gym_id: str | None = None):
        """Store configuration without creating the Gym environment yet.

        Args:
            gym_id: Gym registration id to instantiate. When omitted, the
                ``GYM_CONNECT4_ID`` environment variable is used, falling
                back to ``"Connect4-v0"``.
        """
        super().__init__()
        self._gym_id = gym_id or os.getenv("GYM_CONNECT4_ID", "Connect4-v0")
        self._env: gym.Env | None = None
        self._state = Connect4State()

    def _ensure_env(self) -> gym.Env:
        """Return the Gym environment, creating it on first call."""
        if self._env is None:
            self._env = gym.make(self._gym_id)
        return self._env

    def reset(self) -> Connect4Observation:
        """Start a new episode and return its initial observation.

        Returns:
            An observation with ``reward=0.0`` and ``done=False``. Legal
            actions come from ``info["legal_actions"]`` when Gym supplies
            them, otherwise they are derived from the board.
        """
        env = self._ensure_env()
        raw_obs = env.reset()
        board, raw_info = _normalize_board(raw_obs)
        # Work on a private dict: tolerates a missing/None info and avoids
        # handing Gym's own dictionary out through the observation.
        info: Dict[str, Any] = dict(raw_info or {})

        rows, cols = board.shape
        self._state = Connect4State(
            episode_id=str(uuid4()),
            step_count=0,
            rows=rows,
            cols=cols,
        )

        legal_actions = info.get("legal_actions")
        if legal_actions is None:
            legal_actions = _legal_actions(board)

        return Connect4Observation(
            board=board.tolist(),
            legal_actions=list(legal_actions),
            current_player=_current_player(info, board),
            last_move=info.get("last_move"),
            reward=0.0,
            done=False,
            info=info,
        )

    def step(self, action: Connect4Action) -> Connect4Observation:  # type: ignore[override]
        """Play ``action.column`` and return the resulting observation.

        Raises:
            RuntimeError: If the Gym environment's ``step`` does not return
                a 4- or 5-element tuple.
        """
        env = self._ensure_env()
        result = env.step(int(action.column))

        # Gym 0.25 returns 4-tuple, 0.26+ returns 5-tuple.
        if isinstance(result, tuple) and len(result) == 5:
            obs, reward, terminated, truncated, info = result
        elif isinstance(result, tuple) and len(result) == 4:
            obs, reward, done, info = result
            terminated, truncated = bool(done), False
        else:  # pragma: no cover - defensive branch
            raise RuntimeError(
                f"Unexpected Gym step return type for Connect4: {type(result)}"
            )

        done = bool(terminated or truncated)
        board, info2 = _normalize_board(obs)
        # Copy before merging so the dictionary owned by Gym is not mutated.
        merged_info: Dict[str, Any] = dict(info or {})
        merged_info.update(info2 or {})

        self._state.step_count += 1

        legal_actions = merged_info.get("legal_actions")
        if done:
            # A finished episode has no playable columns.
            legal_actions = []
        elif legal_actions is None:
            legal_actions = _legal_actions(board)

        return Connect4Observation(
            board=board.tolist(),
            legal_actions=list(legal_actions),
            current_player=_current_player(merged_info, board),
            last_move=merged_info.get("last_move"),
            done=done,
            reward=_scalarize_reward(reward),
            info=merged_info,
        )

    @property
    def state(self) -> Connect4State:
        """Current episode state (id, step count and board geometry)."""
        return self._state

    def close(self) -> None:
        """Close the Gym environment, if one was created, and forget it."""
        if self._env is not None and hasattr(self._env, "close"):
            self._env.close()
        self._env = None
Loading