diff --git a/.gitignore b/.gitignore index 0bd5fdd..07acc6c 100644 --- a/.gitignore +++ b/.gitignore @@ -24,8 +24,9 @@ wheels/ .venv/ venv/ -# UV +# uv and ruff .uv/ +.ruff_cache/ # IDE .claude/ @@ -87,4 +88,3 @@ temp/ # Docker compose overrides (generated per-worktree) docker-compose.override.yml - diff --git a/environments/_template/env.py b/environments/_template/env.py index 6ee79eb..7d4e49c 100644 --- a/environments/_template/env.py +++ b/environments/_template/env.py @@ -23,7 +23,7 @@ import structlog from kinitro.environments import get_environment -from kinitro.environments.registry import get_all_environment_ids +from kinitro.environments.registry import get_environments_by_family from kinitro.rl_interface import Action logger = structlog.get_logger() @@ -69,8 +69,8 @@ async def _call_miner( async def list_environments(self) -> list[str]: """List available environments in this family.""" - # TODO: Change "myenv/" to your environment family prefix - return [e for e in get_all_environment_ids() if e.startswith("myenv/")] + # TODO: Change "myenv" to your environment family prefix + return get_environments_by_family("myenv") async def evaluate( self, diff --git a/environments/genesis/env.py b/environments/genesis/env.py index 06bee94..d07678f 100644 --- a/environments/genesis/env.py +++ b/environments/genesis/env.py @@ -30,7 +30,7 @@ # Import from kinitro package (installed in container via PYTHONPATH) from kinitro.environments import get_environment from kinitro.environments.base import RoboticsEnvironment -from kinitro.environments.registry import get_all_environment_ids +from kinitro.environments.registry import get_environments_by_family from kinitro.rl_interface import Action logger = structlog.get_logger() @@ -101,7 +101,7 @@ async def _call_miner( async def list_environments(self) -> list[str]: """List available Genesis environments.""" - return [env_id for env_id in get_all_environment_ids() if env_id.startswith("genesis/")] + return get_environments_by_family("genesis") async def evaluate( self, diff --git a/environments/metaworld/env.py b/environments/metaworld/env.py index 8d3a122..03543f3 100644 --- a/environments/metaworld/env.py +++ b/environments/metaworld/env.py @@ -37,7 +37,7 @@ # Import from kinitro package (installed in container via PYTHONPATH) from kinitro.environments import get_environment -from kinitro.environments.registry import get_all_environment_ids +from kinitro.environments.registry import get_environments_by_family from kinitro.rl_interface import Action logger = structlog.get_logger() @@ -93,7 +93,7 @@ async def _call_miner( async def list_environments(self) -> list[str]: """List available MetaWorld environments.""" - return [env_id for env_id in get_all_environment_ids() if env_id.startswith("metaworld/")] + return get_environments_by_family("metaworld") async def evaluate( self, diff --git a/kinitro/api/routes/scores.py b/kinitro/api/routes/scores.py index 41df6b5..b374b92 100644 --- a/kinitro/api/routes/scores.py +++ b/kinitro/api/routes/scores.py @@ -6,7 +6,9 @@ from kinitro.api.deps import get_session, get_storage from kinitro.backend.models import ( EvaluationCycle, + EvaluationCycleORM, MinerScore, + MinerScoreORM, ScoresResponse, ) from kinitro.backend.storage import Storage @@ -14,19 +16,10 @@ router = APIRouter(prefix="/v1/scores", tags=["Scores"]) -@router.get("/latest", response_model=ScoresResponse) -async def get_latest_scores( - session: AsyncSession = Depends(get_session), - storage: Storage = Depends(get_storage), -): - """Get scores from the most recent completed evaluation cycle.""" - cycle = await storage.get_latest_cycle(session, completed_only=True) - if cycle is None: - raise HTTPException(status_code=404, detail="No completed evaluations yet") - - scores_orm = await storage.get_scores_for_cycle(session, cycle.id) - - # Convert to response format +def _build_scores_response( + cycle: EvaluationCycleORM, scores_orm: list[MinerScoreORM] +) -> ScoresResponse: + """Build a ScoresResponse from a cycle ORM object and its scores.""" scores = [ MinerScore( uid=s.uid, @@ -40,7 +33,6 @@ async def get_latest_scores( for s in scores_orm ] - # Build miner summary miner_summary: dict[int, dict[str, float]] = {} for s in scores_orm: if s.uid not in miner_summary: @@ -54,6 +46,20 @@ async def get_latest_scores( ) +@router.get("/latest", response_model=ScoresResponse) +async def get_latest_scores( + session: AsyncSession = Depends(get_session), + storage: Storage = Depends(get_storage), +): + """Get scores from the most recent completed evaluation cycle.""" + cycle = await storage.get_latest_cycle(session, completed_only=True) + if cycle is None: + raise HTTPException(status_code=404, detail="No completed evaluations yet") + + scores_orm = await storage.get_scores_for_cycle(session, cycle.id) + return _build_scores_response(cycle, scores_orm) + + @router.get("/{cycle_id}", response_model=ScoresResponse) async def get_scores_for_cycle( cycle_id: int, @@ -66,28 +72,4 @@ async def get_scores_for_cycle( raise HTTPException(status_code=404, detail=f"Cycle {cycle_id} not found") scores_orm = await storage.get_scores_for_cycle(session, cycle_id) - - scores = [ - MinerScore( - uid=s.uid, - hotkey=s.hotkey, - env_id=s.env_id, - success_rate=s.success_rate, - mean_reward=s.mean_reward, - episodes_completed=s.episodes_completed, - episodes_failed=s.episodes_failed, - ) - for s in scores_orm - ] - - miner_summary: dict[int, dict[str, float]] = {} - for s in scores_orm: - if s.uid not in miner_summary: - miner_summary[s.uid] = {} - miner_summary[s.uid][s.env_id] = s.success_rate - - return ScoresResponse( - cycle=EvaluationCycle.model_validate(cycle), - scores=scores, - miner_summary=miner_summary, - ) + return _build_scores_response(cycle, scores_orm) diff --git a/kinitro/api/routes/weights.py b/kinitro/api/routes/weights.py index 1e4d419..989024b 100644 --- a/kinitro/api/routes/weights.py +++ b/kinitro/api/routes/weights.py @@ -4,29 +4,21 @@ from sqlalchemy.ext.asyncio import AsyncSession from kinitro.api.deps import get_session, get_storage -from kinitro.backend.models import WeightsResponse, WeightsU16 +from kinitro.backend.models import ( + ComputedWeightsORM, + EvaluationCycleORM, + WeightsResponse, + WeightsU16, +) from kinitro.backend.storage import Storage router = APIRouter(prefix="/v1/weights", tags=["Weights"]) -@router.get("/latest", response_model=WeightsResponse) -async def get_latest_weights( - session: AsyncSession = Depends(get_session), - storage: Storage = Depends(get_storage), -): - """ - Get the most recently computed weights. - - These weights are ready to be submitted to the chain by validators. - """ - weights_orm = await storage.get_latest_weights(session) - if weights_orm is None: - raise HTTPException(status_code=404, detail="No weights available yet") - - # Get the associated cycle for metadata - cycle = await storage.get_cycle(session, weights_orm.cycle_id) - +def _build_weights_response( + weights_orm: ComputedWeightsORM, cycle: EvaluationCycleORM | None +) -> WeightsResponse: + """Build a WeightsResponse from a weights ORM object and its cycle.""" return WeightsResponse( cycle_id=weights_orm.cycle_id, block_number=weights_orm.block_number, @@ -44,12 +36,30 @@ async def get_latest_weights( ) +@router.get("/latest", response_model=WeightsResponse) +async def get_latest_weights( + session: AsyncSession = Depends(get_session), + storage: Storage = Depends(get_storage), +) -> WeightsResponse: + """ + Get the most recently computed weights. + + These weights are ready to be submitted to the chain by validators. + """ + weights_orm = await storage.get_latest_weights(session) + if weights_orm is None: + raise HTTPException(status_code=404, detail="No weights available yet") + + cycle = await storage.get_cycle(session, weights_orm.cycle_id) + return _build_weights_response(weights_orm, cycle) + + @router.get("/{block_number}", response_model=WeightsResponse) async def get_weights_for_block( block_number: int, session: AsyncSession = Depends(get_session), storage: Storage = Depends(get_storage), -): +) -> WeightsResponse: """Get weights computed at a specific block.""" weights_orm = await storage.get_weights_for_block(session, block_number) if weights_orm is None: @@ -59,19 +69,4 @@ async def get_weights_for_block( ) cycle = await storage.get_cycle(session, weights_orm.cycle_id) - - return WeightsResponse( - cycle_id=weights_orm.cycle_id, - block_number=weights_orm.block_number, - timestamp=weights_orm.created_at, - weights={int(k): float(v) for k, v in weights_orm.weights_json.items()}, - weights_u16=WeightsU16( - uids=weights_orm.weights_u16_json["uids"], - values=weights_orm.weights_u16_json["values"], - ), - metadata={ - "n_miners_evaluated": cycle.n_miners if cycle else None, - "n_environments": cycle.n_environments if cycle else None, - "evaluation_duration_seconds": cycle.duration_seconds if cycle else None, - }, - ) + return _build_weights_response(weights_orm, cycle) diff --git a/kinitro/config.py b/kinitro/config.py index 54cec06..7d39ef8 100644 --- a/kinitro/config.py +++ b/kinitro/config.py @@ -15,17 +15,9 @@ class NetworkConfig(BaseSettings): hotkey_name: str = Field(default="default", description="Hotkey name") -class ValidatorConfig(BaseSettings): +class ValidatorConfig(NetworkConfig): """Validator-specific configuration.""" - model_config = SettingsConfigDict(env_prefix="KINITRO_") - - # Network settings (inherited concept) - network: str = Field(default="finney") - netuid: int = Field(default=1) - wallet_name: str = Field(default="default") - hotkey_name: str = Field(default="default") - # Evaluation settings episodes_per_env: int = Field( default=50, description="Number of episodes per environment per evaluation cycle" @@ -55,16 +47,9 @@ class ValidatorConfig(BaseSettings): log_level: str = Field(default="INFO", description="Logging level") -class MinerConfig(BaseSettings): +class MinerConfig(NetworkConfig): """Miner-specific configuration.""" - model_config = SettingsConfigDict(env_prefix="KINITRO_") - - network: str = Field(default="finney") - netuid: int = Field(default=1) - wallet_name: str = Field(default="default") - hotkey_name: str = Field(default="default") - # Model settings huggingface_repo: str | None = Field(default=None, description="HuggingFace model repo") model_revision: str | None = Field(default=None, description="Model revision/commit SHA") diff --git a/kinitro/environments/__init__.py b/kinitro/environments/__init__.py index 8497d05..627b268 100644 --- a/kinitro/environments/__init__.py +++ b/kinitro/environments/__init__.py @@ -5,7 +5,7 @@ robotics simulation environments (MetaWorld, Genesis, DM Control, ManiSkill). """ -from kinitro.environments.base import EpisodeResult, RoboticsEnvironment, TaskConfig +from kinitro.environments.base import RoboticsEnvironment, TaskConfig from kinitro.environments.registry import ( ENVIRONMENTS, get_all_environment_ids, @@ -16,7 +16,6 @@ __all__ = [ "RoboticsEnvironment", "TaskConfig", - "EpisodeResult", "ENVIRONMENTS", "get_environment", "get_all_environment_ids", diff --git a/kinitro/environments/base.py b/kinitro/environments/base.py index e2500c4..7f7a7d8 100644 --- a/kinitro/environments/base.py +++ b/kinitro/environments/base.py @@ -45,23 +45,6 @@ def to_dict(self) -> dict[str, Any]: } -@dataclass -class EpisodeResult: - """Result of a single episode evaluation.""" - - success: bool - total_reward: float - timesteps: int - info: dict[str, Any] = field(default_factory=dict) - - @property - def efficiency(self) -> float: - """Reward per timestep (higher is better).""" - if self.timesteps == 0: - return 0.0 - return self.total_reward / self.timesteps - - class RoboticsEnvironment(ABC): """ Abstract base class for all robotics environments. diff --git a/kinitro/environments/metaworld_env.py b/kinitro/environments/metaworld_env.py index f0af0a8..d76618c 100644 --- a/kinitro/environments/metaworld_env.py +++ b/kinitro/environments/metaworld_env.py @@ -206,12 +206,6 @@ def action_shape(self) -> tuple[int, ...]: return (8,) return (7,) - @property - def action_bounds(self) -> tuple[np.ndarray, np.ndarray]: - low = np.full(self.action_shape, -1.0, dtype=np.float32) - high = np.full(self.action_shape, 1.0, dtype=np.float32) - return (low, high) - def _warn_once(self, key: str, message: str, **kwargs: Any) -> None: if key in self._warned_keys: return @@ -275,26 +269,27 @@ def _extract_proprioceptive_obs(self, full_obs: np.ndarray) -> tuple[np.ndarray, proprio = full_obs[0:4].astype(np.float32) return proprio[0:3], float(proprio[3]) + def _sync_camera_state(self, cam_env: Any) -> None: + """Copy physics state from main env to a camera env and run forward kinematics.""" + if hasattr(self._env, "unwrapped") and hasattr(cam_env, "unwrapped"): + env = cast(Any, self._env) + main_data = env.unwrapped.data + cam_data = cam_env.unwrapped.data + + cam_data.qpos[:] = main_data.qpos[:] + cam_data.qvel[:] = main_data.qvel[:] + + mj_forward = getattr(mujoco, "mj_forward", None) + if callable(mj_forward): + mj_forward(cam_env.unwrapped.model, cam_data) + def _get_camera_images(self) -> dict[str, np.ndarray]: """Render images from all configured cameras.""" images = {} for cam_name, cam_env in self._camera_envs.items(): try: - # Copy the physics state from main env to camera env - if hasattr(self._env, "unwrapped") and hasattr(cam_env, "unwrapped"): - env = cast(Any, self._env) - main_data = env.unwrapped.data - cam_data = cam_env.unwrapped.data - - # Copy qpos and qvel - cam_data.qpos[:] = main_data.qpos[:] - cam_data.qvel[:] = main_data.qvel[:] - - # Forward kinematics to update derived quantities - mj_forward = getattr(mujoco, "mj_forward", None) - if callable(mj_forward): - mj_forward(cam_env.unwrapped.model, cam_data) + self._sync_camera_state(cam_env) # Render img = cam_env.render() @@ -635,17 +630,8 @@ def render(self, camera_name: str = "corner") -> np.ndarray | None: """ if camera_name in self._camera_envs: try: - # Sync state and render cam_env = self._camera_envs[camera_name] - if hasattr(self._env, "unwrapped") and hasattr(cam_env, "unwrapped"): - env = cast(Any, self._env) - main_data = env.unwrapped.data - cam_data = cam_env.unwrapped.data - cam_data.qpos[:] = main_data.qpos[:] - cam_data.qvel[:] = main_data.qvel[:] - mj_forward = getattr(mujoco, "mj_forward", None) - if callable(mj_forward): - mj_forward(cam_env.unwrapped.model, cam_data) + self._sync_camera_state(cam_env) return cam_env.render() except Exception as e: logger.debug("metaworld_render_failed", camera_name=camera_name, error=str(e)) diff --git a/kinitro/environments/registry.py b/kinitro/environments/registry.py index e5cca69..08e3470 100644 --- a/kinitro/environments/registry.py +++ b/kinitro/environments/registry.py @@ -173,21 +173,3 @@ def get_available_families() -> list[str]: def get_family_metadata(family: str) -> dict[str, str] | None: """Get display metadata for a family (name, description) from metadata.json.""" return _get_family_metadata_cache().get(family) - - -def is_family_available(family: str) -> bool: - """Check if an environment family is available.""" - return family in _get_family_metadata_cache() - - -def register_environment(env_id: str, factory: EnvFactory) -> None: - """ - Register a new environment. - - Args: - env_id: Unique environment identifier - factory: Callable that returns a RoboticsEnvironment - """ - if env_id in ENVIRONMENTS: - raise ValueError(f"Environment {env_id} is already registered") - ENVIRONMENTS[env_id] = factory diff --git a/kinitro/executor/env_loader.py b/kinitro/executor/env_loader.py new file mode 100644 index 0000000..80bec47 --- /dev/null +++ b/kinitro/executor/env_loader.py @@ -0,0 +1,162 @@ +"""Shared utilities for loading affinetes environments and running evaluations.""" + +import asyncio +import subprocess +from typing import Any + +import affinetes as af_env +import docker.types +import structlog + +from kinitro.backend.models import Task, TaskResult + +logger = structlog.get_logger() + + +def build_load_kwargs( + image: str, + eval_mode: str, + mem_limit: str, + executor_id: str, + family: str, + hosts: list[str], + eval_timeout: int, + gpu_enabled: bool = False, +) -> dict[str, Any]: + """Build kwargs dict for af_env.load_env(). + + Args: + image: Docker image tag for the environment. + eval_mode: 'docker' or 'basilica'. + mem_limit: Memory limit string (e.g. '8g'). + executor_id: Unique executor identifier. + family: Environment family name (e.g. 'metaworld'). + hosts: Docker hosts list. + eval_timeout: Evaluation timeout in seconds. + gpu_enabled: Whether to enable GPU passthrough. + + Returns: + Dict of keyword arguments for af_env.load_env(). + """ + load_kwargs: dict[str, Any] = { + "image": image, + "mode": eval_mode, + "mem_limit": mem_limit, + "pull": True, + } + + if eval_mode == "docker": + load_kwargs.update( + { + "hosts": hosts, + "container_name": f"kinitro-eval-{executor_id}-{family}", + "force_recreate": True, + } + ) + if gpu_enabled: + load_kwargs["device_requests"] = [ + docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]]) + ] + elif eval_mode == "basilica": + load_kwargs.update( + { + "cpu_limit": "2000m", + "ttl_buffer": eval_timeout + 60, + } + ) + + return load_kwargs + + +async def load_and_warmup_env(family: str, image: str, load_kwargs: dict[str, Any]) -> Any: + """Load an affinetes environment and perform a warmup call. + + Args: + family: Environment family name (for logging). + image: Docker image tag (for logging). + load_kwargs: Keyword arguments for af_env.load_env(). + + Returns: + The loaded affinetes environment instance. + """ + env = await asyncio.to_thread(af_env.load_env, **load_kwargs) + + # Warm-up call + logger.info("warmup_call_starting", family=family) + try: + await env.list_environments() + logger.info("warmup_call_succeeded", family=family) + except Exception as e: + logger.info( + "warmup_call_absorbed_expected_error", + family=family, + error=str(e)[:100], + ) + + logger.info("eval_environment_loaded", family=family, image=image) + return env + + +async def run_evaluation( + env: Any, + task: Task, + max_timesteps: int, + action_timeout: float, + use_images: bool, + eval_timeout: int, +) -> TaskResult: + """Run an evaluation and build a TaskResult from the response. + + Args: + env: The affinetes environment instance. + task: The task to evaluate. + max_timesteps: Maximum timesteps per episode. + action_timeout: Timeout for miner action responses. + use_images: Whether to include camera images in observations. + eval_timeout: Timeout for the evaluation call. + + Returns: + TaskResult with the evaluation outcome. + """ + result = await env.evaluate( + task_id=task.seed, + model=f"miner-{task.miner_uid}", + base_url=task.miner_endpoint, + env_id=task.env_id, + max_timesteps=max_timesteps, + action_timeout=action_timeout, + use_images=use_images, + _timeout=eval_timeout, + ) + + success = result.get("success", False) + score = result.get("score", 0.0) + extra = result.get("extra", {}) + error = result.get("error") + + return TaskResult( + task_uuid=task.task_uuid, + success=success, + score=score, + total_reward=extra.get("total_reward", 0.0), + timesteps=extra.get("timesteps", 0), + error=error, + ) + + +def force_remove_container(container_name: str) -> None: + """Force-remove a Docker container by name, ignoring errors. + + Args: + container_name: Name of the Docker container to remove. + """ + try: + subprocess.run( + ["docker", "rm", "-f", container_name], + capture_output=True, + text=True, + timeout=5, + check=False, + ) + except Exception as e: + logger.debug("force_remove_container_failed", container=container_name, error=str(e)) diff --git a/kinitro/executor/family_worker.py b/kinitro/executor/family_worker.py index 80d75b9..69a49a9 100644 --- a/kinitro/executor/family_worker.py +++ b/kinitro/executor/family_worker.py @@ -4,15 +4,19 @@ import logging import multiprocessing as mp import os -import subprocess from typing import Any -import affinetes as af_env -import aiohttp import structlog from kinitro.backend.models import Task, TaskResult from kinitro.environments import get_environments_by_family +from kinitro.executor.api_client import APIClient +from kinitro.executor.env_loader import ( + build_load_kwargs, + force_remove_container, + load_and_warmup_env, + run_evaluation, +) # Configure structlog for this subprocess structlog.configure( @@ -59,6 +63,7 @@ def __init__( poll_interval: int, stats_queue: mp.Queue, api_key: str | None = None, + gpu_enabled: bool = False, ): self.family = family self.max_concurrent = max_concurrent @@ -76,31 +81,21 @@ def __init__( self.poll_interval = poll_interval self.stats_queue = stats_queue self.api_key = api_key + self.gpu_enabled = gpu_enabled + + # API client for HTTP communication + self.api_client = APIClient(api_url, executor_id, api_key) # Async primitives (initialized in run()) self.task_queue: asyncio.Queue | None = None self.semaphore: asyncio.Semaphore | None = None self.env: Any = None self.running = False - self._session: aiohttp.ClientSession | None = None # Metrics self.tasks_succeeded = 0 self.tasks_failed = 0 - async def _get_session(self) -> aiohttp.ClientSession: - """Get or create HTTP session.""" - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=30) - self._session = aiohttp.ClientSession(timeout=timeout) - return self._session - - def _get_auth_headers(self) -> dict[str, str]: - """Get authentication headers if API key is configured.""" - if self.api_key: - return {"Authorization": f"Bearer {self.api_key}"} - return {} - async def initialize(self) -> None: """Load the environment container once.""" self.task_queue = asyncio.Queue() @@ -115,44 +110,18 @@ async def initialize(self) -> None: ) # Load environment via affinetes - load_kwargs = { - "image": self.image, - "mode": self.eval_mode, - "mem_limit": self.mem_limit, - "pull": True, - } - - if self.eval_mode == "docker": - load_kwargs.update( - { - "hosts": self.hosts, - "container_name": f"kinitro-eval-{self.executor_id}-{self.family}", - "force_recreate": True, - } - ) - elif self.eval_mode == "basilica": - load_kwargs.update( - { - "cpu_limit": "2000m", - "ttl_buffer": self.eval_timeout + 60, - } - ) - - self.env = await asyncio.to_thread(af_env.load_env, **load_kwargs) - - # Warm-up call - logger.info("warmup_call_starting", family=self.family) - try: - await self.env.list_environments() - logger.info("warmup_call_succeeded", family=self.family) - except Exception as e: - logger.info( - "warmup_call_absorbed_expected_error", - family=self.family, - error=str(e)[:100], - ) + load_kwargs = build_load_kwargs( + image=self.image, + eval_mode=self.eval_mode, + mem_limit=self.mem_limit, + executor_id=self.executor_id, + family=self.family, + hosts=self.hosts, + eval_timeout=self.eval_timeout, + gpu_enabled=self.gpu_enabled, + ) - logger.info("eval_environment_loaded", family=self.family, image=self.image) + self.env = await load_and_warmup_env(self.family, self.image, load_kwargs) async def _fetch_loop(self) -> None: """Producer: fetch tasks from API, push to queue.""" @@ -191,47 +160,8 @@ async def _fetch_loop(self) -> None: async def _fetch_tasks_batch(self, batch_size: int) -> list[Task]: """Fetch a batch of tasks from the API, filtered to this family.""" - session = await self._get_session() - - # Get env_ids for this family and filter at the API level env_ids = get_environments_by_family(self.family) - payload = { - "executor_id": self.executor_id, - "batch_size": batch_size, - "env_ids": env_ids, - } - - try: - async with session.post( - f"{self.api_url}/v1/tasks/fetch", - json=payload, - headers=self._get_auth_headers(), - ) as resp: - if resp.status != 200: - error = await resp.text() - logger.error( - "fetch_tasks_error", - family=self.family, - status=resp.status, - error=error, - ) - return [] - - data = await resp.json() - tasks = [Task(**t) for t in data["tasks"]] - - if tasks: - logger.info( - "tasks_fetched", - family=self.family, - count=len(tasks), - total_pending=data.get("total_pending", 0), - ) - return tasks - - except Exception as e: - logger.error("fetch_tasks_exception", family=self.family, error=str(e)) - return [] + return await self.api_client.fetch_tasks(batch_size=batch_size, env_ids=env_ids) async def _execution_worker(self, worker_id: int) -> None: """Consumer: pull tasks from queue, execute, submit results.""" @@ -292,78 +222,28 @@ async def _execute_task(self, task: Task) -> TaskResult: env_id=task.env_id, ) - result = await self.env.evaluate( - task_id=task.seed, - model=f"miner-{task.miner_uid}", - base_url=task.miner_endpoint, - env_id=task.env_id, + task_result = await run_evaluation( + env=self.env, + task=task, max_timesteps=self.max_timesteps, action_timeout=self.action_timeout, use_images=self.use_images, - _timeout=self.eval_timeout, + eval_timeout=self.eval_timeout, ) - success = result.get("success", False) - score = result.get("score", 0.0) - extra = result.get("extra", {}) - logger.info( "task_executed", family=self.family, task_uuid=task.task_uuid, - success=success, - score=score, + success=task_result.success, + score=task_result.score, ) - return TaskResult( - task_uuid=task.task_uuid, - success=success, - score=score, - total_reward=extra.get("total_reward", 0.0), - timesteps=extra.get("timesteps", 0), - error=None, - ) + return task_result async def _submit_result(self, result: TaskResult) -> None: """Submit a task result to the API.""" - session = await self._get_session() - - payload = { - "executor_id": self.executor_id, - "results": [result.model_dump()], - } - - try: - async with session.post( - f"{self.api_url}/v1/tasks/submit", - json=payload, - headers=self._get_auth_headers(), - ) as resp: - if resp.status != 200: - error = await resp.text() - logger.error( - "submit_result_error", - family=self.family, - task_uuid=result.task_uuid, - status=resp.status, - error=error, - ) - else: - data = await resp.json() - logger.debug( - "result_submitted", - family=self.family, - task_uuid=result.task_uuid, - accepted=data.get("accepted", 0), - ) - - except Exception as e: - logger.error( - "submit_result_exception", - family=self.family, - task_uuid=result.task_uuid, - error=str(e), - ) + await self.api_client.submit_results([result]) async def _report_metrics(self) -> None: """Periodically report metrics to main process.""" @@ -419,8 +299,7 @@ async def _cleanup(self) -> None: logger.info("family_worker_cleaning_up", family=self.family) # Close HTTP session - if self._session and not self._session.closed: - await self._session.close() + await self.api_client.close() # Cleanup environment if self.env: @@ -431,16 +310,7 @@ async def _cleanup(self) -> None: # Force cleanup docker container container_name = f"kinitro-eval-{self.executor_id}-{self.family}" - try: - subprocess.run( - ["docker", "rm", "-f", container_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - except Exception: - pass + force_remove_container(container_name) logger.info("family_worker_stopped", family=self.family) @@ -463,6 +333,7 @@ def run_family_worker( stats_queue: mp.Queue, log_level: str, api_key: str | None = None, + gpu_enabled: bool = False, ) -> None: """Entry point for subprocess (called by multiprocessing).""" # Configure logging for subprocess @@ -485,6 +356,7 @@ def run_family_worker( poll_interval=poll_interval, stats_queue=stats_queue, api_key=api_key, + gpu_enabled=gpu_enabled, ) try: diff --git a/kinitro/executor/worker.py b/kinitro/executor/worker.py index 64c8d1b..2944d0d 100644 --- a/kinitro/executor/worker.py +++ b/kinitro/executor/worker.py @@ -1,35 +1,23 @@ """Worker that executes evaluation tasks using affinetes.""" import asyncio -import subprocess -from dataclasses import dataclass from typing import Any -import affinetes as af_env -import docker.types import structlog from kinitro.backend.models import Task, TaskResult from kinitro.executor.config import ExecutorConfig +from kinitro.executor.env_loader import ( + build_load_kwargs, + force_remove_container, + load_and_warmup_env, + run_evaluation, +) from kinitro.executor.verification import PolicyVerifier, VerificationResult logger = structlog.get_logger() -@dataclass -class EvalEnvConfig: - """Configuration for the evaluation environment.""" - - image: str - mode: str - mem_limit: str - hosts: list[str] - max_timesteps: int - action_timeout: float - eval_timeout: int - use_images: bool - - class Worker: """ Worker that executes evaluation tasks using affinetes. @@ -107,49 +95,18 @@ async def _get_eval_environment(self, env_id: str): ) # Load eval environment via affinetes - load_kwargs = { - "image": image, - "mode": self.config.eval_mode, - "mem_limit": self.config.eval_mem_limit, - "pull": True, - } - - if self.config.eval_mode == "docker": - load_kwargs.update( - { - "hosts": self.config.eval_hosts, - # Unique container name per family - "container_name": f"kinitro-eval-{self.config.executor_id}-{family}", - "force_recreate": True, - } - ) - if self.config.eval_gpu: - load_kwargs["device_requests"] = [ - docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]]) - ] - elif self.config.eval_mode == "basilica": - load_kwargs.update( - { - "cpu_limit": "2000m", - "ttl_buffer": self.config.eval_timeout + 60, - } - ) - - env = await asyncio.to_thread(af_env.load_env, **load_kwargs) - - # Warm-up call - logger.info("warmup_call_starting", family=family) - try: - await env.list_environments() - logger.info("warmup_call_succeeded", family=family) - except Exception as e: - logger.info( - "warmup_call_absorbed_expected_error", - family=family, - error=str(e)[:100], - ) + load_kwargs = build_load_kwargs( + image=image, + eval_mode=self.config.eval_mode, + mem_limit=self.config.eval_mem_limit, + executor_id=self.config.executor_id, + family=family, + hosts=self.config.eval_hosts, + eval_timeout=self.config.eval_timeout, + gpu_enabled=self.config.eval_gpu, + ) - logger.info("eval_environment_loaded", family=family, image=image) + env = await load_and_warmup_env(family, image, load_kwargs) self._envs[family] = env return env @@ -177,36 +134,23 @@ async def execute_task(self, task: Task) -> TaskResult: try: env = await self._get_eval_environment(task.env_id) - result = await env.evaluate( - task_id=task.seed, # Use seed for environment reproducibility - model=f"miner-{task.miner_uid}", # Identifier for logging - base_url=task.miner_endpoint, - env_id=task.env_id, + task_result = await run_evaluation( + env=env, + task=task, max_timesteps=self.config.max_timesteps, action_timeout=self.config.action_timeout, use_images=self.config.use_images, - _timeout=self.config.eval_timeout, + eval_timeout=self.config.eval_timeout, ) - success = result.get("success", False) - score = result.get("score", 0.0) - extra = result.get("extra", {}) - logger.info( "task_executed", task_uuid=task.task_uuid, - success=success, - score=score, + success=task_result.success, + score=task_result.score, ) - return TaskResult( - task_uuid=task.task_uuid, - success=success, - score=score, - total_reward=extra.get("total_reward", 0.0), - timesteps=extra.get("timesteps", 0), - error=None, - ) + return task_result except TimeoutError: logger.warning("task_timeout", task_uuid=task.task_uuid) @@ -348,31 +292,12 @@ def force_cleanup(self) -> None: for family in list(self._envs.keys()): container_name = f"kinitro-eval-{self.config.executor_id}-{family}" logger.info("force_cleanup_container", container=container_name, family=family) - - try: - subprocess.run( - ["docker", "rm", "-f", container_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - except Exception as e: - logger.warning("docker_cleanup_failed", family=family, error=str(e)) + force_remove_container(container_name) # Also try to clean up any containers matching the executor pattern # in case there are orphaned containers from previous runs for family in self.config.eval_images.keys(): container_name = f"kinitro-eval-{self.config.executor_id}-{family}" - try: - subprocess.run( - ["docker", "rm", "-f", container_name], - capture_output=True, - text=True, - timeout=5, - check=False, - ) - except Exception: - pass # Ignore errors for containers that don't exist + force_remove_container(container_name) self._envs.clear() diff --git a/kinitro/executor/worker_process.py b/kinitro/executor/worker_process.py index 2807a33..b6d371e 100644 --- a/kinitro/executor/worker_process.py +++ b/kinitro/executor/worker_process.py @@ -58,6 +58,7 @@ def start(self) -> None: self.stats_queue, self.config.log_level, self.config.api_key, + self.config.eval_gpu, ), name=f"Worker-{self.family}", ) diff --git a/tests/unit/test_crypto.py b/tests/unit/test_crypto.py index 708bd0d..cf6929f 100644 --- a/tests/unit/test_crypto.py +++ b/tests/unit/test_crypto.py @@ -16,20 +16,23 @@ ) -class TestUUIDConversion: - """Tests for UUID <-> bytes conversion.""" +@pytest.fixture() +def keypair(): + return BackendKeypair.generate() - def test_uuid_to_bytes_standard_format(self): - """Standard UUID format with dashes.""" - uuid_str = "95edf2b6-e18b-400a-8398-5573df10e5e4" - result = uuid_to_bytes(uuid_str) - assert len(result) == 16 - assert result.hex() == "95edf2b6e18b400a83985573df10e5e4" +class TestUUIDConversion: + """Tests for UUID <-> bytes conversion.""" - def test_uuid_to_bytes_no_dashes(self): - """UUID format without dashes (already hex).""" - uuid_str = "95edf2b6e18b400a83985573df10e5e4" + @pytest.mark.parametrize( + "uuid_str", + [ + pytest.param("95edf2b6-e18b-400a-8398-5573df10e5e4", id="with_dashes"), + pytest.param("95edf2b6e18b400a83985573df10e5e4", id="no_dashes"), + ], + ) + def test_uuid_to_bytes(self, uuid_str): + """Both UUID formats should produce the same 16-byte result.""" result = uuid_to_bytes(uuid_str) assert len(result) == 16 @@ -57,44 +60,38 @@ def test_bytes_to_uuid_invalid_length(self): class TestBackendKeypair: """Tests for BackendKeypair class.""" - def test_generate_creates_valid_keypair(self): + def test_generate_creates_valid_keypair(self, keypair): """Generate should create a valid keypair.""" - keypair = BackendKeypair.generate() - assert keypair.private_key is not None assert keypair.public_key is not None - def test_public_key_hex_length(self): + def test_public_key_hex_length(self, keypair): """Public key hex should be 64 characters (32 bytes).""" - keypair = BackendKeypair.generate() pub_hex = keypair.public_key_hex() assert len(pub_hex) == 64 # Should be valid hex bytes.fromhex(pub_hex) - def test_private_key_hex_length(self): + def test_private_key_hex_length(self, keypair): """Private key hex should be 64 characters (32 bytes).""" - keypair = BackendKeypair.generate() priv_hex = keypair.private_key_hex() assert len(priv_hex) == 64 # Should be valid hex bytes.fromhex(priv_hex) - def test_from_private_key_hex_roundtrip(self): + def test_from_private_key_hex_roundtrip(self, keypair): """Load keypair from hex should preserve keys.""" - original = BackendKeypair.generate() - priv_hex = original.private_key_hex() + priv_hex = keypair.private_key_hex() restored = BackendKeypair.from_private_key_hex(priv_hex) - assert restored.public_key_hex() == original.public_key_hex() - assert restored.private_key_hex() == original.private_key_hex() + assert restored.public_key_hex() == keypair.public_key_hex() + assert restored.private_key_hex() == keypair.private_key_hex() - def test_from_private_key_file(self, tmp_path): + def test_from_private_key_file(self, keypair, tmp_path): """Load keypair from file.""" - keypair = BackendKeypair.generate() key_file = tmp_path / "test.key" keypair.save_private_key(key_file) @@ -102,18 +99,16 @@ def test_from_private_key_file(self, tmp_path): assert restored.public_key_hex() == keypair.public_key_hex() - def test_save_private_key_permissions(self, tmp_path): + def test_save_private_key_permissions(self, keypair, tmp_path): """Private key file should have restricted permissions (0600).""" - keypair = BackendKeypair.generate() key_file = tmp_path / "test.key" keypair.save_private_key(key_file) mode = os.stat(key_file).st_mode & 0o777 assert mode == 0o600 - def test_save_public_key(self, tmp_path): + def test_save_public_key(self, keypair, tmp_path): """Save and read public key.""" - keypair = BackendKeypair.generate() pub_file = tmp_path / "test.pub" keypair.save_public_key(pub_file) @@ -126,8 +121,7 @@ class TestLoadPublicKey: def test_load_valid_public_key(self): """Load a valid public key from hex.""" - keypair = BackendKeypair.generate() - pub_hex = keypair.public_key_hex() + pub_hex = BackendKeypair.generate().public_key_hex() loaded = load_public_key(pub_hex) @@ -143,64 +137,51 @@ def test_load_invalid_length(self): class TestEncryptDecrypt: """Tests for encrypt/decrypt deployment ID.""" - def test_encrypt_decrypt_roundtrip(self): - """Encrypt and decrypt should return original value.""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - - encrypted = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) - decrypted = decrypt_deployment_id(encrypted, keypair.private_key) - - assert decrypted == deployment_id + DEPLOYMENT_ID = "95edf2b6-e18b-400a-8398-5573df10e5e4" - def test_encrypt_with_key_object(self): - """Encrypt should accept key object directly.""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" + @pytest.mark.parametrize( + "use_key_object", + [ + pytest.param(False, id="hex_string"), + pytest.param(True, id="key_object"), + ], + ) + def test_encrypt_decrypt_roundtrip(self, keypair, use_key_object): + """Encrypt and decrypt should return original value (hex string or key object).""" + pub_key = keypair.public_key if use_key_object else keypair.public_key_hex() - encrypted = encrypt_deployment_id(deployment_id, keypair.public_key) + encrypted = encrypt_deployment_id(self.DEPLOYMENT_ID, pub_key) decrypted = decrypt_deployment_id(encrypted, keypair.private_key) - assert decrypted == deployment_id + assert decrypted == self.DEPLOYMENT_ID - def test_encrypted_blob_is_base85(self): + def test_encrypted_blob_is_base85(self, keypair): """Encrypted blob should be base85 encoded.""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - - encrypted = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) + encrypted = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) # Should be decodable as base85 decoded = base64.b85decode(encrypted.encode("ascii")) assert len(decoded) == 64 # 32 + 16 + 16 (pubkey + ciphertext + tag, nonce derived) - def test_encrypted_blob_length(self): + def test_encrypted_blob_length(self, keypair): """Encrypted blob should be ~95 characters (base85 of 76 bytes).""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - - encrypted = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) + encrypted = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) # Base85: 64 bytes -> ceil(64 * 5 / 4) = 80 characters assert len(encrypted) == 80 - def test_decrypt_with_wrong_key_fails(self): + def test_decrypt_with_wrong_key_fails(self, keypair): """Decryption with wrong key should fail.""" - keypair1 = BackendKeypair.generate() keypair2 = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - encrypted = encrypt_deployment_id(deployment_id, keypair1.public_key_hex()) + encrypted = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) with pytest.raises(ValueError, match="Decryption failed"): decrypt_deployment_id(encrypted, keypair2.private_key) - def test_decrypt_tampered_data_fails(self): + def test_decrypt_tampered_data_fails(self, keypair): """Decryption of tampered data should fail.""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - - encrypted = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) + encrypted = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) # Tamper with the encrypted blob tampered = encrypted[:-5] + "XXXXX" @@ -208,33 +189,27 @@ def test_decrypt_tampered_data_fails(self): with pytest.raises(ValueError): decrypt_deployment_id(tampered, keypair.private_key) - def test_each_encryption_is_unique(self): + def test_each_encryption_is_unique(self, keypair): """Each encryption should produce different output (fresh ephemeral key).""" - keypair = BackendKeypair.generate() - deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - - encrypted1 = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) - encrypted2 = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) + encrypted1 = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) + encrypted2 = encrypt_deployment_id(self.DEPLOYMENT_ID, keypair.public_key_hex()) # Same plaintext should produce different ciphertext (different ephemeral keys) assert encrypted1 != encrypted2 # Both should decrypt to the same value - assert decrypt_deployment_id(encrypted1, keypair.private_key) == deployment_id - assert decrypt_deployment_id(encrypted2, keypair.private_key) == deployment_id + assert decrypt_deployment_id(encrypted1, keypair.private_key) == self.DEPLOYMENT_ID + assert decrypt_deployment_id(encrypted2, keypair.private_key) == self.DEPLOYMENT_ID class TestIntegration: """Integration tests for the full encryption flow.""" - def test_full_commitment_flow(self): + def test_full_commitment_flow(self, keypair): """Test full flow: generate keys, encrypt, parse, decrypt.""" - # Backend generates keypair - backend_keypair = BackendKeypair.generate() - # Miner encrypts their deployment ID deployment_id = "95edf2b6-e18b-400a-8398-5573df10e5e4" - encrypted_blob = encrypt_deployment_id(deployment_id, backend_keypair.public_key_hex()) + encrypted_blob = encrypt_deployment_id(deployment_id, keypair.public_key_hex()) # Miner creates commitment (colon-separated format) commitment = f"user/policy:abc123def456:e:{encrypted_blob}" @@ -251,9 +226,7 @@ def test_full_commitment_flow(self): assert parsed["encrypted_deployment"] == encrypted_blob # Backend decrypts the endpoint - decrypted = decrypt_deployment_id( - parsed["encrypted_deployment"], backend_keypair.private_key - ) + decrypted = decrypt_deployment_id(parsed["encrypted_deployment"], keypair.private_key) assert decrypted == deployment_id diff --git a/tests/unit/test_cycle_isolation.py b/tests/unit/test_cycle_isolation.py index ed91b91..c0e40a3 100644 --- a/tests/unit/test_cycle_isolation.py +++ b/tests/unit/test_cycle_isolation.py @@ -13,6 +13,34 @@ from kinitro.backend.storage import Storage +def _make_mock_cycle( + cycle_id: int = 1, + status: str = EvaluationCycleStatus.RUNNING.value, +) -> MagicMock: + """Create a mock EvaluationCycleORM.""" + mock = MagicMock(spec=EvaluationCycleORM) + mock.id = cycle_id + mock.status = status + return mock + + +def _make_mock_task(status: str = TaskStatus.PENDING.value) -> MagicMock: + """Create a mock TaskPoolORM.""" + mock = MagicMock(spec=TaskPoolORM) + mock.status = status + return mock + + +def _mock_execute_results(*results_lists: list) -> AsyncMock: + """Build an AsyncMock side_effect from lists of ORM objects per query.""" + side_effects = [] + for items in results_lists: + result = MagicMock() + result.scalars.return_value.all.return_value = items + side_effects.append(result) + return AsyncMock(side_effect=side_effects) + + class TestCancelIncompleteCycles: """Tests for Storage.cancel_incomplete_cycles().""" @@ -25,10 +53,7 @@ def mock_session(self): @pytest.mark.asyncio async def test_no_incomplete_cycles(self, mock_session): """When no incomplete cycles exist, nothing is cancelled.""" - # Mock execute to return empty result for cycles query - mock_result = MagicMock() - mock_result.scalars.return_value.all.return_value = [] - mock_session.execute = AsyncMock(return_value=mock_result) + mock_session.execute = _mock_execute_results([]) storage = Storage("postgresql+asyncpg://test:test@localhost/test") @@ -40,25 +65,11 @@ async def test_no_incomplete_cycles(self, mock_session): @pytest.mark.asyncio async def test_cancels_running_cycle_and_tasks(self, mock_session): """Running cycles and their pending/assigned tasks are cancelled.""" - # Create mock cycle - mock_cycle = MagicMock(spec=EvaluationCycleORM) - mock_cycle.id = 1 - mock_cycle.status = EvaluationCycleStatus.RUNNING.value - - # Create mock tasks - mock_task1 = MagicMock(spec=TaskPoolORM) - mock_task1.status = TaskStatus.PENDING.value - mock_task2 = MagicMock(spec=TaskPoolORM) - mock_task2.status = TaskStatus.ASSIGNED.value - - # Mock execute - first call returns cycles, second returns tasks - cycles_result = MagicMock() - cycles_result.scalars.return_value.all.return_value = [mock_cycle] + mock_cycle = _make_mock_cycle() + mock_task1 = _make_mock_task(TaskStatus.PENDING.value) + mock_task2 = _make_mock_task(TaskStatus.ASSIGNED.value) - tasks_result = MagicMock() - tasks_result.scalars.return_value.all.return_value = [mock_task1, mock_task2] - - mock_session.execute = AsyncMock(side_effect=[cycles_result, tasks_result]) + mock_session.execute = _mock_execute_results([mock_cycle], [mock_task1, mock_task2]) storage = Storage("postgresql+asyncpg://test:test@localhost/test") @@ -80,23 +91,10 @@ async def test_cancels_running_cycle_and_tasks(self, mock_session): @pytest.mark.asyncio async def test_leaves_completed_tasks_unchanged(self, mock_session): """Completed/failed tasks from incomplete cycles are not modified.""" - # Create mock cycle - mock_cycle = MagicMock(spec=EvaluationCycleORM) - mock_cycle.id = 1 - mock_cycle.status = EvaluationCycleStatus.RUNNING.value - - # Only pending task (completed tasks not returned by query) - mock_task = MagicMock(spec=TaskPoolORM) - mock_task.status = TaskStatus.PENDING.value - - cycles_result = MagicMock() - cycles_result.scalars.return_value.all.return_value = [mock_cycle] - - # Query only returns pending/assigned, not completed - tasks_result = MagicMock() - tasks_result.scalars.return_value.all.return_value = [mock_task] + mock_cycle = _make_mock_cycle() + mock_task = _make_mock_task(TaskStatus.PENDING.value) - mock_session.execute = AsyncMock(side_effect=[cycles_result, tasks_result]) + mock_session.execute = _mock_execute_results([mock_cycle], [mock_task]) storage = Storage("postgresql+asyncpg://test:test@localhost/test") diff --git a/tests/unit/test_genesis.py b/tests/unit/test_genesis.py index 14eeffc..db2726b 100644 --- a/tests/unit/test_genesis.py +++ b/tests/unit/test_genesis.py @@ -75,6 +75,21 @@ def _make_task_spec( ) +def _make_g1_env() -> G1Environment: + """Create a G1Environment without full __init__ (for unit-testing reward/success). + + Invariant: _compute_reward and _check_success must only use their explicit + arguments and never access self.* attributes. If that changes, these tests + will need a proper mock or fixture with instance state. + """ + return object.__new__(G1Environment) + + +@pytest.fixture() +def g1_env() -> G1Environment: + return _make_g1_env() + + def _make_test_objects() -> list[SceneObject]: """Create a mixed set of pickupable + landmark objects for task generation tests.""" return [ @@ -256,66 +271,69 @@ def test_color_palette_distinct(self): class TestCheckTaskFeasibility: """Tests for check_task_feasibility function.""" - def test_navigate_always_feasible(self): - """NAVIGATE should be feasible for any object.""" - obj = _make_scene_object(pickupable=False) - feasible, _ = check_task_feasibility(TaskType.NAVIGATE, obj) - assert feasible is True - - def test_pickup_feasible(self): - """PICKUP should be feasible for pickupable objects.""" - obj = _make_scene_object(pickupable=True) - feasible, _ = check_task_feasibility(TaskType.PICKUP, obj) - assert feasible is True - - def test_pickup_infeasible_not_pickupable(self): - """PICKUP should be infeasible for non-pickupable objects.""" - obj = _make_scene_object(pickupable=False) - feasible, reason = check_task_feasibility(TaskType.PICKUP, obj) - assert feasible is False - assert "not pickupable" in reason - - def test_pickup_infeasible_already_picked_up(self): - """PICKUP should be infeasible when object is already picked up.""" - obj = _make_scene_object(pickupable=True, is_picked_up=True) - feasible, reason = check_task_feasibility(TaskType.PICKUP, obj) - assert feasible is False - assert "already picked up" in reason - - def test_place_feasible(self): - """PLACE should be feasible with pickupable target + destination.""" - target = _make_scene_object(pickupable=True) - dest = _make_scene_object(object_id="dest", pickupable=False) - feasible, _ = check_task_feasibility(TaskType.PLACE, target, destination=dest) - assert feasible is True - - def test_place_infeasible_no_destination(self): - """PLACE should be infeasible without a destination.""" - target = _make_scene_object(pickupable=True) - feasible, reason = check_task_feasibility(TaskType.PLACE, target, destination=None) - assert feasible is False - assert "destination" in reason.lower() - - def test_push_feasible(self): - """PUSH should be feasible with target + different destination.""" - target = _make_scene_object(object_id="a") - dest = _make_scene_object(object_id="b") - feasible, _ = check_task_feasibility(TaskType.PUSH, target, destination=dest) - assert feasible is True - - def test_push_infeasible_same_object(self): - """PUSH should be infeasible when target == destination.""" - obj = _make_scene_object(object_id="same") - feasible, reason = check_task_feasibility(TaskType.PUSH, obj, destination=obj) - assert feasible is False - assert "itself" in reason.lower() - - def test_push_infeasible_no_destination(self): - """PUSH should be infeasible without a destination.""" - target = _make_scene_object() - feasible, reason = check_task_feasibility(TaskType.PUSH, target, destination=None) + @pytest.mark.parametrize( + "task_type, obj_kwargs, dest_factory, expected", + [ + pytest.param(TaskType.NAVIGATE, {"pickupable": False}, None, True, id="navigate_any"), + pytest.param(TaskType.PICKUP, {"pickupable": True}, None, True, id="pickup_ok"), + pytest.param( + TaskType.PLACE, + {"pickupable": True}, + lambda: _make_scene_object(object_id="dest", pickupable=False), + True, + id="place_ok", + ), + pytest.param( + TaskType.PUSH, + {"object_id": "a"}, + lambda: _make_scene_object(object_id="b"), + True, + id="push_ok", + ), + ], + ) + def test_feasible_cases(self, task_type, obj_kwargs, dest_factory, expected): + obj = _make_scene_object(**obj_kwargs) + dest = dest_factory() if dest_factory else None + feasible, _ = check_task_feasibility(task_type, obj, destination=dest) + assert feasible is expected + + @pytest.mark.parametrize( + "task_type, obj_kwargs, dest_factory, reason_substr", + [ + pytest.param( + TaskType.PICKUP, + {"pickupable": False}, + None, + "not pickupable", + id="pickup_not_pickupable", + ), + pytest.param( + TaskType.PICKUP, + {"pickupable": True, "is_picked_up": True}, + None, + "already picked up", + id="pickup_already_picked", + ), + pytest.param( + TaskType.PLACE, {"pickupable": True}, None, "destination", id="place_no_dest" + ), + pytest.param( + TaskType.PUSH, + {"object_id": "same"}, + lambda: _make_scene_object(object_id="same"), + "itself", + id="push_same_object", + ), + pytest.param(TaskType.PUSH, {}, None, "destination", id="push_no_dest"), + ], + ) + def test_infeasible_cases(self, task_type, obj_kwargs, dest_factory, reason_substr): + obj = _make_scene_object(**obj_kwargs) + dest = dest_factory() if dest_factory else None + feasible, reason = check_task_feasibility(task_type, obj, destination=dest) assert feasible is False - assert "destination" in reason.lower() + assert reason_substr in reason.lower() def test_robot_capability_filtering_unsupported(self): """Task type not in robot_supported_tasks should be infeasible.""" @@ -609,44 +627,37 @@ def test_generate_task_with_specific_type(self): class TestG1Reward: """Tests for G1Environment._compute_reward using object.__new__() bypass.""" - def _make_env(self) -> G1Environment: - return object.__new__(G1Environment) - - def test_navigate_alive_bonus(self): + def test_navigate_alive_bonus(self, g1_env): """NAVIGATE reward should always include alive bonus.""" - env = self._make_env() robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} spec = _make_task_spec(task_type=TaskType.NAVIGATE, target_position=[5.0, 5.0, 0.0]) - reward = env._compute_reward(robot_state, {}, spec) + reward = g1_env._compute_reward(robot_state, {}, spec) assert reward >= 0.01 # alive bonus - def test_navigate_reward_increases_closer(self): + def test_navigate_reward_increases_closer(self, g1_env): """NAVIGATE reward should increase as distance decreases.""" - env = self._make_env() spec = _make_task_spec(task_type=TaskType.NAVIGATE, target_position=[3.0, 0.0, 0.0]) far_state = {"base_pos": np.array([0.0, 0.0, 0.75])} near_state = {"base_pos": np.array([2.5, 0.0, 0.75])} - r_far = env._compute_reward(far_state, {}, spec) - r_near = env._compute_reward(near_state, {}, spec) + r_far = g1_env._compute_reward(far_state, {}, spec) + r_near = g1_env._compute_reward(near_state, {}, spec) assert r_near > r_far - def test_navigate_high_reward_near_target(self): + def test_navigate_high_reward_near_target(self, g1_env): """NAVIGATE reward should be higher when very close to target.""" - env = self._make_env() spec = _make_task_spec(task_type=TaskType.NAVIGATE, target_position=[0.1, 0.0, 0.0]) robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - reward = env._compute_reward(robot_state, {}, spec) + reward = g1_env._compute_reward(robot_state, {}, spec) assert reward > 0.01 # more than just alive bonus - def test_pickup_approach_reward(self): + def test_pickup_approach_reward(self, g1_env): """PICKUP should give approach reward when near the object.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PICKUP, target_object_id="obj_00", @@ -656,12 +667,11 @@ def test_pickup_approach_reward(self): robot_state = {"base_pos": np.array([0.8, 0.0, 0.75])} obj_states = {"obj_00": np.array([1.0, 0.0, 0.05])} - reward = env._compute_reward(robot_state, obj_states, spec) + reward = g1_env._compute_reward(robot_state, obj_states, spec) assert reward > 0.01 # alive + approach - def test_pickup_lift_bonus(self): + def test_pickup_lift_bonus(self, g1_env): """PICKUP should give large bonus when object is lifted above threshold.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PICKUP, target_object_id="obj_00", @@ -671,12 +681,11 @@ def test_pickup_lift_bonus(self): robot_state = {"base_pos": np.array([1.0, 0.0, 0.75])} obj_states = {"obj_00": np.array([1.0, 0.0, 0.5])} # lifted well above 0.05 + 0.15 - reward = env._compute_reward(robot_state, obj_states, spec) + reward = g1_env._compute_reward(robot_state, obj_states, spec) assert reward >= 1.0 # lift bonus of 1.0 - def test_pickup_no_lift_bonus_below_threshold(self): + def test_pickup_no_lift_bonus_below_threshold(self, g1_env): """PICKUP should not give lift bonus when height below threshold.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PICKUP, target_object_id="obj_00", @@ -686,12 +695,11 @@ def test_pickup_no_lift_bonus_below_threshold(self): robot_state = {"base_pos": np.array([1.0, 0.0, 0.75])} obj_states = {"obj_00": np.array([1.0, 0.0, 0.1])} # only 0.05 above, < 0.15 - reward = env._compute_reward(robot_state, obj_states, spec) + reward = g1_env._compute_reward(robot_state, obj_states, spec) assert reward < 1.0 # no lift bonus - def test_pickup_missing_object_only_alive_bonus(self): + def test_pickup_missing_object_only_alive_bonus(self, g1_env): """PICKUP should return only alive bonus when object not in states.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PICKUP, target_object_id="obj_missing", @@ -699,12 +707,11 @@ def test_pickup_missing_object_only_alive_bonus(self): ) robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - reward = env._compute_reward(robot_state, {}, spec) + reward = g1_env._compute_reward(robot_state, {}, spec) assert abs(reward - 0.01) < 1e-6 - def test_place_reward_based_on_distance(self): + def test_place_reward_based_on_distance(self, g1_env): """PLACE reward should depend on object-to-destination distance.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PLACE, target_object_id="obj_00", @@ -714,17 +721,16 @@ def test_place_reward_based_on_distance(self): # Object far from destination far = {"obj_00": np.array([0.0, 0.0, 0.05])} - r_far = env._compute_reward(robot_state, far, spec) + r_far = g1_env._compute_reward(robot_state, far, spec) # Object close to destination close = {"obj_00": np.array([2.9, 2.9, 0.1])} - r_close = env._compute_reward(robot_state, close, spec) + r_close = g1_env._compute_reward(robot_state, close, spec) assert r_close > r_far - def test_push_reward_based_on_xy_distance(self): + def test_push_reward_based_on_xy_distance(self, g1_env): """PUSH reward should depend on XY distance to destination.""" - env = self._make_env() spec = _make_task_spec( task_type=TaskType.PUSH, target_object_id="obj_00", @@ -735,14 +741,13 @@ def test_push_reward_based_on_xy_distance(self): far = {"obj_00": np.array([0.0, 0.0, 0.05])} close = {"obj_00": np.array([2.8, 0.0, 0.05])} - r_far = env._compute_reward(robot_state, far, spec) - r_close = env._compute_reward(robot_state, close, spec) + r_far = g1_env._compute_reward(robot_state, far, spec) + r_close = g1_env._compute_reward(robot_state, close, spec) assert r_close > r_far - def test_all_rewards_non_negative(self): + def test_all_rewards_non_negative(self, g1_env): """All task types should produce non-negative rewards (alive bonus).""" - env = self._make_env() robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} obj_states = {"obj_00": np.array([2.0, 2.0, 0.05])} @@ -753,7 +758,7 @@ def test_all_rewards_non_negative(self): destination_position=[3.0, 3.0, 0.1], initial_state={"initial_height": 0.05}, ) - reward = env._compute_reward(robot_state, obj_states, spec) + reward = g1_env._compute_reward(robot_state, obj_states, spec) assert reward >= 0.0, f"{task_type} produced negative reward: {reward}" @@ -765,127 +770,95 @@ def test_all_rewards_non_negative(self): class TestG1Success: """Tests for G1Environment._check_success using object.__new__() bypass.""" - def _make_env(self) -> G1Environment: - return object.__new__(G1Environment) - - def test_navigate_success_close(self): - """NAVIGATE should succeed when distance < 0.5.""" - env = self._make_env() - robot_state = {"base_pos": np.array([2.8, 0.0, 0.75])} - spec = _make_task_spec(task_type=TaskType.NAVIGATE, target_position=[3.0, 0.0, 0.0]) - - assert env._check_success(robot_state, {}, spec) is True - - def test_navigate_failure_far(self): - """NAVIGATE should fail when distance >= 0.5.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec(task_type=TaskType.NAVIGATE, target_position=[3.0, 0.0, 0.0]) - - assert env._check_success(robot_state, {}, spec) is False - - def test_pickup_success_lifted(self): - """PICKUP should succeed when lifted > 0.15 above initial height.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PICKUP, - target_object_id="obj_00", - initial_state={"initial_height": 0.05}, - ) - obj_states = {"obj_00": np.array([1.0, 0.0, 0.25])} # 0.2 above initial - - assert env._check_success(robot_state, obj_states, spec) is True - - def test_pickup_failure_not_lifted(self): - """PICKUP should fail when not lifted enough.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PICKUP, - target_object_id="obj_00", - initial_state={"initial_height": 0.05}, - ) - obj_states = {"obj_00": np.array([1.0, 0.0, 0.1])} # only 0.05 above - - assert env._check_success(robot_state, obj_states, spec) is False - - def test_pickup_failure_object_missing(self): - """PICKUP should fail when object not in states dict.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PICKUP, - target_object_id="obj_missing", - initial_state={"initial_height": 0.05}, - ) - - assert env._check_success(robot_state, {}, spec) is False - - def test_place_success_within_threshold(self): - """PLACE should succeed when object within 0.3 of destination.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PLACE, - target_object_id="obj_00", - destination_position=[3.0, 3.0, 0.1], - ) - obj_states = {"obj_00": np.array([3.1, 3.1, 0.1])} # ~0.14 away - - assert env._check_success(robot_state, obj_states, spec) is True - - def test_place_failure_too_far(self): - """PLACE should fail when object too far from destination.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PLACE, - target_object_id="obj_00", - destination_position=[3.0, 3.0, 0.1], - ) - obj_states = {"obj_00": np.array([0.0, 0.0, 0.05])} - - assert env._check_success(robot_state, obj_states, spec) is False - - def test_place_failure_destination_none(self): - """PLACE should fail when destination is None.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PLACE, - target_object_id="obj_00", - destination_position=None, - ) - obj_states = {"obj_00": np.array([0.0, 0.0, 0.05])} - - assert env._check_success(robot_state, obj_states, spec) is False - - def test_push_success_close(self): - """PUSH should succeed when XY distance < 0.5.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PUSH, - target_object_id="obj_00", - destination_position=[3.0, 0.0, 0.05], - ) - obj_states = {"obj_00": np.array([2.8, 0.0, 0.05])} # 0.2 away in XY - - assert env._check_success(robot_state, obj_states, spec) is True - - def test_push_failure_too_far(self): - """PUSH should fail when XY distance >= 0.5.""" - env = self._make_env() - robot_state = {"base_pos": np.array([0.0, 0.0, 0.75])} - spec = _make_task_spec( - task_type=TaskType.PUSH, - target_object_id="obj_00", - destination_position=[3.0, 0.0, 0.05], - ) - obj_states = {"obj_00": np.array([0.0, 0.0, 0.05])} - - assert env._check_success(robot_state, obj_states, spec) is False + @pytest.mark.parametrize( + "task_type, spec_kwargs, robot_pos, obj_states, expected", + [ + pytest.param( + TaskType.NAVIGATE, + {"target_position": [3.0, 0.0, 0.0]}, + [2.8, 0.0, 0.75], + {}, + True, + id="navigate_success_close", + ), + pytest.param( + TaskType.NAVIGATE, + {"target_position": [3.0, 0.0, 0.0]}, + [0.0, 0.0, 0.75], + {}, + False, + id="navigate_failure_far", + ), + pytest.param( + TaskType.PICKUP, + {"target_object_id": "obj_00", "initial_state": {"initial_height": 0.05}}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([1.0, 0.0, 0.25])}, + True, + id="pickup_success_lifted", + ), + pytest.param( + TaskType.PICKUP, + {"target_object_id": "obj_00", "initial_state": {"initial_height": 0.05}}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([1.0, 0.0, 0.1])}, + False, + id="pickup_failure_not_lifted", + ), + pytest.param( + TaskType.PICKUP, + {"target_object_id": "obj_missing", "initial_state": {"initial_height": 0.05}}, + [0.0, 0.0, 0.75], + {}, + False, + id="pickup_failure_missing", + ), + pytest.param( + TaskType.PLACE, + {"target_object_id": "obj_00", "destination_position": [3.0, 3.0, 0.1]}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([3.1, 3.1, 0.1])}, + True, + id="place_success_within_threshold", + ), + pytest.param( + TaskType.PLACE, + {"target_object_id": "obj_00", "destination_position": [3.0, 3.0, 0.1]}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([0.0, 0.0, 0.05])}, + False, + id="place_failure_too_far", + ), + pytest.param( + TaskType.PLACE, + {"target_object_id": "obj_00", "destination_position": None}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([0.0, 0.0, 0.05])}, + False, + id="place_failure_dest_none", + ), + pytest.param( + TaskType.PUSH, + {"target_object_id": "obj_00", "destination_position": [3.0, 0.0, 0.05]}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([2.8, 0.0, 0.05])}, + True, + id="push_success_close", + ), + pytest.param( + TaskType.PUSH, + {"target_object_id": "obj_00", "destination_position": [3.0, 0.0, 0.05]}, + [0.0, 0.0, 0.75], + {"obj_00": np.array([0.0, 0.0, 0.05])}, + False, + id="push_failure_too_far", + ), + ], + ) + def test_check_success(self, g1_env, task_type, spec_kwargs, robot_pos, obj_states, expected): + robot_state = {"base_pos": np.array(robot_pos)} + spec = _make_task_spec(task_type=task_type, **spec_kwargs) + assert g1_env._check_success(robot_state, obj_states, spec) is expected # =============================================================================