Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions executors/PPO/.python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.13
90 changes: 90 additions & 0 deletions executors/PPO/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
[project]
name = "hypha-ppo-executor"
version = "0.0.0"
license = "Apache-2.0"
requires-python = ">=3.12"
dependencies = [
"httpx>=0.28.1",
"safetensors>=0.7.0",
]

[project.optional-dependencies]
mps_cu128 = [
"torch>=2.9.0",
]
cpu = [
"torch>=2.9.0",
]
cu128 = [
"torch>=2.9.0",
]
cu130 = [
"torch>=2.9.0",
]
rocm64 = [
"torch>=2.9.0",
"pytorch-triton-rocm",
]

[tool.uv]
conflicts = [
[
{ extra = "mps_cu128" },
{ extra = "cpu" },
{ extra = "cu128" },
{ extra = "cu130" },
{ extra = "rocm64" },
],
]

[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cu128", extra = "cu128" },
{ index = "pytorch-cu130", extra = "cu130" },
{ index = "pytorch-rocm64", extra = "rocm64"},
]
pytorch-triton-rocm = [
{ index = "pytorch-rocm64", extra = "rocm64"},
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[[tool.uv.index]]
name = "pytorch-cu130"
url = "https://download.pytorch.org/whl/cu130"
explicit = true

[[tool.uv.index]]
name = "pytorch-rocm64"
url = "https://download.pytorch.org/whl/rocm6.4"
explicit = true

[dependency-groups]
dev = [
"ruff>=0.14.9",
]

[build-system]
requires = ["uv_build>=0.9.8,<0.10.0"]
build-backend = "uv_build"

[tool.uv.build-backend]
module-name = "hypha.gym_executor"

[tool.ruff]
line-length = 120

[tool.ruff.format]
docstring-code-format = true

[tool.ruff.lint]
select = ["E", "F", "W", "I", "N", "UP", "SIM", "ARG", "PL"]
Empty file.
72 changes: 72 additions & 0 deletions executors/PPO/src/hypha/ppo_executor/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from types import TracebackType
from typing import Any, override

import httpx


class Session(AbstractContextManager["Session", None]):
def __init__(self, socket_path: str) -> None:
transport = httpx.HTTPTransport(uds=socket_path)
self._client: httpx.Client = httpx.Client(transport=transport)

@override
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self._client.close()

def send_resource(self, resource: Any, path: str, timeout: float | None = None) -> None:
timeout_ms = int(timeout * 1000) if timeout is not None else None
req = {"resource": resource, "path": path, "timeout_ms": timeout_ms}
# We must allow the client to wait at least as long as the requested timeout.
# If timeout is None, wait forever.
_ = self._client.post("http://hypha/resources/send", json=req, timeout=timeout).raise_for_status()

def send_action(self, payload: Any) -> Any:
resp = self._client.post("http://hypha/action/update", json=payload, timeout=None).raise_for_status()
return resp.json()

def fetch(self, resource: Any) -> Any:
resp = self._client.post("http://hypha/resources/fetch", json=resource, timeout=None).raise_for_status()
return resp.json()

@contextmanager
def receive(self, resource: Any, path: str, timeout: float | None = None) -> Iterator["EventSource"]:
req = {"resource": resource, "path": path}
# Use a short connect timeout to fail fast if the local side is unresponsive,
# but respect the provided timeout for the total duration/read.
# If timeout is None, we still enforce a connect timeout.
timeout_config = httpx.Timeout(timeout, connect=5.0)
with self._client.stream(
"POST",
"http://hypha/resources/receive",
json=req,
headers={"Accept": "text/event-stream"},
timeout=timeout_config,
) as resp:
yield EventSource(resp)


class EventSource:
def __init__(self, response: httpx.Response) -> None:
self._response: httpx.Response = response

@property
def response(self) -> httpx.Response:
return self._response

def __iter__(self) -> Iterator[Any]:
for line in self._response.iter_lines():
fieldname, _, value = line.rstrip("\n").partition(":")

if fieldname == "data":
result = json.loads(value)

yield result
# Ignore other SSE fields (e.g., event:, id:, retry:)
Loading
Loading