diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..9bce72b --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,398 @@ +# Kinitro Development Guide + +Kinitro is a Bittensor subnet for reinforcement learning evaluation. Miners submit trained agents, validators evaluate them against standardized benchmarks, and the network rewards top performers. + +## Quick Reference + +```bash +# Activate virtual environment +cd /root/dev/kinitro && source .venv/bin/activate + +# Run tests +python3 -m pytest src/backend/tests/ src/backend/scoring/tests/ src/evaluator/executors/tests/ src/evaluator/providers/tests/ -v + +# Run specific component +python3 -m backend # Start backend service +python3 -m validator # Start validator node +python3 -m evaluator # Start evaluator (requires Ray + Kubernetes) +``` + +## Architecture Overview + +``` + ┌─────────────────────┐ + │ Bittensor Chain │ + │ (Commitments + │ + │ Weights) │ + └──────────┬──────────┘ + │ + ┌──────────────────────────┼──────────────────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ + │ Miner │ │ Backend │ │ Validator │ + │ (Submission) │ │ (FastAPI) │ │(Polls /weights)│ + └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ + │ │ │ + │ upload artifact │ GET /weights │ + └─────────────────────────►│◄─────────────────────────┘ + │ + │ WebSocket (direct) + │ + ▼ + ┌─────────────────┐ + │ Evaluator │ + │ (Ray + K8s) │ + └───────┬─────────┘ + │ + │ RPC + ▼ + ┌────────────────┐ + │ Miner Pod │ + │ (Agent Code) │ + └────────────────┘ +``` + +**Key Points:** +- **Evaluators connect directly** to backend via WebSocket (no pgqueuer relay) +- **Validators poll** the `/weights` endpoint and set weights on chain (no WebSocket) +- **Simplified architecture** with fewer moving parts + +## Directory Structure + +``` +src/ +├── backend/ # Central orchestration service +│ ├── service.py # Main BackendService (composition pattern) +│ ├── scoring_engine.py # Eligibility + scoring + weights +│ ├── scoring/ # Pluggable scoring strategies +│ │ ├── strategies.py # ScoringStrategy protocol + RLRolloutScoringStrategy +│ │ └── registry.py # ScoringStrategyRegistry +│ ├── evaluator_hub.py # Direct evaluator WebSocket connections +│ ├── chain_monitor.py # Bittensor commitment scanning +│ ├── job_scheduler.py # Job creation + broadcasting to evaluators +│ ├── endpoints.py # FastAPI routes +│ └── models.py # SQLModel database models +│ +├── validator/ # Weight-setting service (polls backend) +│ ├── websocket_validator.py # Polls /weights, sets on chain +│ └── db/ # Local state cache +│ +├── evaluator/ # Job execution engine (connects to backend) +│ ├── orchestrator.py # Main job lifecycle management +│ ├── backend_client.py # WebSocket client for backend connection +│ ├── executors/ # Task type implementations +│ │ ├── registry.py # ExecutorRegistry +│ │ └── rl_rollout.py # RLRolloutExecutor +│ ├── providers/ # Environment providers +│ │ ├── registry.py # ProviderRegistry +│ │ ├── metaworld_provider.py +│ │ └── swarm_provider.py +│ ├── rollout/ # Ray actors for RL execution +│ ├── containers/ # Kubernetes pod management +│ └── rpc/ # Cap'n Proto agent communication +│ +├── core/ # Shared modules +│ ├── tasks.py # TaskSpec, TaskResult, TaskExecutor protocol +│ ├── messages.py # WebSocket message types +│ ├── db/models.py # SnowflakeId, EvaluationStatus +│ └── chain.py # Bittensor helpers +│ +└── miner/ # Submission CLI + └── commands/ # upload, commit, local-eval +``` + +## Key Abstractions + +### Task Abstraction Layer (`src/core/tasks.py`) + +Universal interface for different task types (RL rollouts, training, browser tasks, etc.): + +```python +class TaskType(StrEnum): + RL_ROLLOUT = "rl_rollout" + # Future: TRAINING_RUN, BROWSER_TASK, DATASET_EVAL + +@dataclass +class TaskSpec: + task_type: TaskType + task_id: str + config: dict[str, Any] + timeout: timedelta + resources: ResourceSpec + submission_id: int + competition_id: str + artifact_url: str + +@dataclass +class TaskResult: + task_id: str + success: bool + metrics: dict[str, float] # Task-specific metrics + artifacts: dict[str, str] # name -> S3 key + +class TaskExecutor(Protocol): + @property + def task_type(self) -> TaskType: ... + async def validate_spec(self, spec: TaskSpec) -> list[str]: ... + async def setup(self, spec: TaskSpec) -> TaskContext: ... + async def execute(self, context: TaskContext) -> TaskResult: ... + async def teardown(self, context: TaskContext) -> None: ... +``` + +### Executor Registry (`src/evaluator/executors/registry.py`) + +```python +# Register executors +ExecutorRegistry.register(RLRolloutExecutor(config)) + +# Dispatch by task type +executor = ExecutorRegistry.get(TaskType.RL_ROLLOUT) +context = await executor.setup(spec) +result = await executor.execute(context) +``` + +### Scoring Strategies (`src/backend/scoring/`) + +```python +class ScoringStrategy(Protocol): + @property + def task_type(self) -> TaskType: ... + def extract_metrics(self, result) -> ScoringMetrics: ... + def check_eligibility(self, metrics, competition) -> EligibilityResult: ... + def compute_score(self, metrics, competition) -> float: ... + def compare(self, a, b) -> int: ... # For ranking + +# Usage in ScoringEngine +strategy = self.strategy_registry.get(competition.task_type) +metrics = strategy.extract_metrics(result) +if strategy.check_eligibility(metrics, competition).eligible: + score = strategy.compute_score(metrics, competition) +``` + +### Provider Registry (`src/evaluator/providers/registry.py`) + +```python +class EnvironmentProvider(Protocol): + @property + def name(self) -> str: ... + def get_benchmark_specs(self, config) -> list[BenchmarkSpec]: ... + def get_env_specs(self, benchmark_spec) -> list[EnvSpec]: ... + def make_env(self, spec) -> gym.Env: ... + +# Built-in providers +ProviderRegistry.register(MetaWorldProvider()) +ProviderRegistry.register(SwarmProvider()) +``` + +## Database Models + +### Backend (`src/backend/models.py`) + +``` +Competition +├── id, name, description +├── benchmarks: JSON (list of benchmark specs) +├── points: int (weight allocation) +├── task_type: str (default "rl_rollout") +├── min_success_rate, min_avg_reward (thresholds) +├── current_leader_hotkey, current_leader_reward +└── Relationships: submissions, evaluation_jobs, leader_candidates + +MinerSubmission +├── miner_hotkey, competition_id +├── hf_repo_id, version, commitment_block +├── artifact_object_key, artifact_sha256 +├── holdout_release_at, released_at +└── Relationships: evaluation_jobs + +BackendEvaluationJob +├── submission_id, competition_id +├── env_provider, benchmark_name, config +└── Relationships: results, status_updates + +BackendEvaluationResult +├── job_id, validator_hotkey +├── score, success_rate, avg_reward, total_episodes +└── Relationships: leader_candidates + +CompetitionLeaderCandidate +├── competition_id, miner_hotkey, evaluation_result_id +├── status: pending | approved | rejected +└── reviewed_by_api_key_id, reviewed_at +``` + +## Message Types (`src/core/messages.py`) + +### Backend → Evaluator (via WebSocket) +- `EvalJobMessage`: New evaluation job (includes task_type, task_spec) +- `HeartbeatAckMessage`, `RegistrationAckMessage`, `ResultAckMessage` + +### Evaluator → Backend (via WebSocket) +- `EvaluatorRegisterMessage`: Registration with API key and capabilities +- `HeartbeatMessage`: Keepalive +- `EvalResultMessage`: Completed evaluation with metrics +- `JobStatusUpdateMessage`: Status transitions + +### Backend → Validator (via REST API) +- `GET /weights`: Returns `WeightsSnapshot` with UID→weight mapping + +## Communication Patterns + +### WebSocket (Backend ↔ Evaluator) +- Direct connection for job dispatch and result collection +- Evaluators register with supported task types +- Jobs broadcast to all connected evaluators + +### HTTP Polling (Validator → Backend) +- Validator polls `GET /weights` periodically (default: 5 min) +- Sets weights on Bittensor chain when changed +- Simple polling avoids WebSocket complexity + +### RPC (Evaluator ↔ Miner Pod) +- Cap'n Proto schema (`agent.capnp`) +- Methods: `reset()`, `act(observation) -> action` +- Avoids pickle for security + +## Scoring Logic + +1. **Eligibility**: Result must meet `min_success_rate` and `min_avg_reward` +2. **Leader Candidates**: Eligible results that beat current leader are queued for admin review +3. **Admin Approval**: Candidates require manual approval to become leader +4. **Winner-Takes-All**: One winner per competition gets points +5. **Burn Mechanism**: Configurable burn_pct (default 98%), remainder to owner UID +6. **Weight Calculation**: Normalized scores across all competitions + +## Adding New Task Types + +### 1. Define TaskType +```python +# src/core/tasks.py +class TaskType(StrEnum): + RL_ROLLOUT = "rl_rollout" + TRAINING_RUN = "training_run" # Add new type +``` + +### 2. Implement TaskExecutor +```python +# src/evaluator/executors/training_executor.py +class TrainingExecutor: + task_type = TaskType.TRAINING_RUN + + async def validate_spec(self, spec: TaskSpec) -> list[str]: ... + async def setup(self, spec: TaskSpec) -> TaskContext: ... + async def execute(self, context: TaskContext) -> TaskResult: ... + async def teardown(self, context: TaskContext) -> None: ... +``` + +### 3. Register Executor +```python +# src/evaluator/executors/registry.py +ExecutorRegistry.register(TrainingExecutor(config)) +``` + +### 4. Implement ScoringStrategy +```python +# src/backend/scoring/strategies.py +class TrainingScoringStrategy: + task_type = TaskType.TRAINING_RUN + + def extract_metrics(self, result) -> ScoringMetrics: ... + def check_eligibility(self, metrics, competition) -> EligibilityResult: ... + def compute_score(self, metrics, competition) -> float: ... + def compare(self, a, b) -> int: ... +``` + +### 5. Register Strategy +```python +# src/backend/scoring/registry.py +ScoringStrategyRegistry.register(TrainingScoringStrategy()) +``` + +## Testing + +```bash +# All tests +python3 -m pytest src/backend/tests/ src/backend/scoring/tests/ \ + src/evaluator/executors/tests/ src/evaluator/providers/tests/ -v + +# Specific test files +python3 -m pytest src/backend/scoring/tests/test_strategies.py -v # Scoring strategies +python3 -m pytest src/backend/tests/test_scoring.py -v # Scoring engine +python3 -m pytest src/evaluator/executors/tests/test_registry.py -v # Executor registry +``` + +## Configuration + +### Backend (`config/backend.toml.example`) +- Database URL, WebSocket settings +- Chain sync intervals, commitment scanning + +### Validator (`config/validator.toml.example`) +- Backend URL, API key +- Wallet/hotkey configuration +- Reconnect intervals + +### Evaluator (`config/evaluator.toml.example`) +- Database, Ray cluster settings +- Worker resources, timeouts +- Kubernetes namespace + +## Deployment + +### Docker Compose (`deploy/docker/`) +```bash +docker compose up -d postgres migrator validator evaluator +``` + +### Key Environment Variables +- `KINITRO_API_KEY`: Validator authentication +- `KUBECONFIG`: Kubernetes config for evaluator +- `DB_HOST`, `DB_USER`, `DB_PASSWORD`: Database credentials + +## Development Workflow + +1. **Make changes** in appropriate module +2. **Run tests** to verify nothing broke +3. **Update migrations** if models changed (`alembic revision --autogenerate`) +4. **Test locally** with `python3 -m ` +5. **Commit** with conventional commits (feat:, fix:, refactor:, etc.) + +## Recent Refactoring (Phases 1-6) + +### Phase 1: Decomposed BackendService +- Extracted `ScoringEngine`, `ChainMonitor`, `JobScheduler` +- Created `ProviderRegistry` for environment providers + +### Phase 2: Task Abstraction Layer +- Added `TaskSpec`, `TaskResult`, `TaskContext`, `TaskExecutor` protocol +- Created `ExecutorRegistry` for task type dispatch +- Implemented `RLRolloutExecutor` wrapping existing rollout infrastructure + +### Phase 3: Scoring Strategies +- Added `ScoringStrategy` protocol for pluggable scoring +- Implemented `RLRolloutScoringStrategy` +- Created `ScoringStrategyRegistry` for strategy dispatch +- Updated `ScoringEngine` to use strategies + +### Phase 4-6: Architecture Simplification +- **Direct evaluator connection**: Evaluators connect to backend via WebSocket +- **Removed pgqueuer relay**: No more validator-as-middleman for jobs +- **Polling-based validator**: Validators poll `/weights` endpoint instead of WebSocket +- **Consolidated orchestrator**: Single `orchestrator.py` (removed legacy code) + +## Important Files Reference + +| Purpose | File | +|---------|------| +| Task interfaces | `src/core/tasks.py` | +| Executor registry | `src/evaluator/executors/registry.py` | +| RL executor | `src/evaluator/executors/rl_rollout.py` | +| Scoring engine | `src/backend/scoring_engine.py` | +| Scoring strategies | `src/backend/scoring/strategies.py` | +| Strategy registry | `src/backend/scoring/registry.py` | +| WebSocket messages | `src/core/messages.py` | +| Database models | `src/backend/models.py` | +| Backend service | `src/backend/service.py` | +| Evaluator hub | `src/backend/evaluator_hub.py` | +| Orchestrator | `src/evaluator/orchestrator.py` | diff --git a/README.md b/README.md index 44fe3f2..3c62381 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,41 @@ # Kinitro: Incentivized Embodied Intelligence -Kinitro incentivizes the emergence of agents that can conquer various tasks across different environments. Miners publish agents to compete, validators peform rollouts and evaluate the agents, and reward miners based on the results. All this happens in real-time and can easily be viewed by anyone through our [dashboard](https://kinitro.ai/dashboard). +Kinitro incentivizes the emergence of agents that can conquer various tasks across different environments. Miners publish agents to compete, evaluators perform rollouts and evaluate the agents, validators set weights on the Bittensor chain, and miners are rewarded based on the results. All this happens in real-time and can easily be viewed by anyone through our [dashboard](https://kinitro.ai/dashboard). For a visual overview of how these pieces interact, see the [architecture introduction](docs/architecture/introduction.md). +## Architecture Overview + +``` + ┌─────────────────────┐ + │ Bittensor Chain │ + │ (Commitments + │ + │ Weights) │ + └──────────┬──────────┘ + │ + ┌──────────────────────────┼──────────────────────────┐ + │ │ │ + ▼ ▼ ▼ + ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ + │ Miner │ │ Backend │ │ Validator │ + │ (Submission) │ │ (FastAPI) │ │(Polls /weights)│ + └───────┬────────┘ └───────┬────────┘ └───────┬────────┘ + │ │ │ + │ upload artifact │ GET /weights │ + └─────────────────────────►│◄─────────────────────────┘ + │ + │ WebSocket (direct) + ▼ + ┌─────────────────┐ + │ Evaluator │ + │ (Ray + K8s) │ + └─────────────────┘ +``` + +- **Backend**: Central orchestration service that manages competitions, submissions, and job scheduling +- **Evaluators**: Connect directly to the backend via WebSocket, receive jobs, and stream results back +- **Validators**: Poll the backend's `/weights` endpoint and set weights on the Bittensor chain + ## Repository Setup 1. **Clone the repository** @@ -31,15 +63,49 @@ For a visual overview of how these pieces interact, see the [architecture introd ## Project Layout -- **Backend service** (`src/backend/`): FastAPI backend with a realtime broadcaster, chain monitor, job scheduler, and scoring engine backed by PostgreSQL. -- **Validator node** (`src/validator/`): WebSocket client that authenticates with the backend, relays evaluation jobs into a persistent `pgqueuer` queue, and streams results and episode logs back. -- **Evaluator cluster** (`src/evaluator/`): Ray-powered orchestrator that spins submission pods in Minikube (Kubernetes), runs rollout workers, logs per-step data, and pushes metrics/results into the validator queue. +- **Backend service** (`src/backend/`): FastAPI backend with a realtime broadcaster, chain monitor, job scheduler, evaluator hub, and scoring engine backed by PostgreSQL. +- **Evaluator cluster** (`src/evaluator/`): Ray-powered orchestrator that connects directly to the backend via WebSocket, spins up submission pods in Kubernetes, runs rollout workers, and streams metrics/results back. +- **Validator node** (`src/validator/`): Lightweight service that polls the backend's `/weights` endpoint and sets weights on the Bittensor chain. - **Miner tooling** (`src/miner/`): CLI helpers that package models, upload artifacts to Hugging Face, and notarize submissions on the Bittensor chain. - **Shared core** (`src/core/`): Message formats, chain helpers, database models, and logging utilities that keep every component speaking the same language. - **Docs** (`docs/`): User guides and deep dives (architecture, miner, validator, overview). - **Scripts** (`scripts/`): Database migration/reset helpers and operational scripts. - **Deploy artifacts** (`deploy/`): Dockerfiles and Docker Compose stack (with Minikube integration) for running validators/evaluators with CPU or GPU workloads. +## Running Components + +### Backend +```bash +python -m backend +``` + +### Evaluator +```bash +python -m evaluator +``` + +### Validator +```bash +python -m validator +``` + +## Database Migrations + +```bash +# Backend database +./scripts/migrate_backend_db.sh + +# Evaluator database +./scripts/migrate_evaluator_db.sh +``` + +## Testing + +```bash +python -m pytest src/backend/tests/ src/backend/scoring/tests/ \ + src/evaluator/executors/tests/ src/evaluator/providers/tests/ -v +``` + Questions or ideas? Open an issue or reach out to us on [our channel in the Bittensor Discord server](https://discord.gg/96SdmpeMqG). Contributions via pull requests are welcome. ## License diff --git a/docs/architecture/evaluator.md b/docs/architecture/evaluator.md index bc26558..c2c7c88 100644 --- a/docs/architecture/evaluator.md +++ b/docs/architecture/evaluator.md @@ -1,33 +1,40 @@ # Evaluator -The evaluator executes benchmarks for each miner submission. It is responsible for spawning rollout workers, connecting them to containers which run agents, and collecting metrics that flow back to the backend through the validator. +The evaluator executes benchmarks for each miner submission. It is responsible for spawning rollout workers, connecting them to containers which run agents, and collecting metrics that stream directly back to the backend. ## Runtime Components -- **Rollout cluster** – `RolloutCluster` manages Ray actors and creates workers that can process multiple benchmark specs. -- **Benchmark specs** – Each evaluation job declares an environment provider, benchmark name, and config that turn into `BenchmarkSpec` objects. -- **Rollout worker** – Performs the actual environment interaction, calls the miner agent over RPC, tracks rewards, and records key statistics. -- **RPC bridge** – A lightweight process that proxies actions and observations between the worker and the submission container over TCP. -- **Episode logger** – Streams step data, episode summaries, and uploaded artifacts to pgqueuer while handling retries and back-pressure. +- **Rollout cluster** - `RolloutCluster` manages Ray actors and creates workers that can process multiple benchmark specs. +- **Benchmark specs** - Each evaluation job declares an environment provider, benchmark name, and config that turn into `BenchmarkSpec` objects. +- **Rollout worker** - Performs the actual environment interaction, calls the miner agent over RPC, tracks rewards, and records key statistics. +- **RPC bridge** - A lightweight process that proxies actions and observations between the worker and the submission container over TCP. +- **Episode logger** - Streams step data, episode summaries, and uploaded artifacts back to the backend via WebSocket. ## Episode Execution Flow -1. The orchestrator builds a `BenchmarkSpec` from the evaluation job and creates a rollout worker. -2. When the worker requests an action, it forwards the observation to the submission container via the RPC bridge. -3. The container invokes the miner-provided policy, returns actions, and the worker steps the environment. -4. The episode logger records rewards, success flags, uploaded observation images, and any extra metrics and throttles writes according to `episode_log_interval` and `step_log_interval`. -5. Episode completions trigger an `EpisodeDataMessage`; step-level logs trigger `EpisodeStepDataMessage`s. Both messages are persisted through pgqueuer for the validator to publish. -6. Once all benchmarks finish, the worker assembles aggregate metrics (success rate, average reward, total episodes) that feed into the orchestrator’s result payload. +1. The orchestrator receives an `EvalJobMessage` from the backend via WebSocket and builds a `BenchmarkSpec`. +2. A rollout worker is created to execute the benchmark. +3. When the worker requests an action, it forwards the observation to the submission container via the RPC bridge. +4. The container invokes the miner-provided policy, returns actions, and the worker steps the environment. +5. The episode logger records rewards, success flags, uploaded observation images, and any extra metrics. Logging frequency is controlled by `episode_log_interval` and `step_log_interval`. +6. Episode completions trigger telemetry that streams directly to the backend over the WebSocket connection. +7. Once all benchmarks finish, the worker assembles aggregate metrics (success rate, average reward, total episodes) that feed into the orchestrator's `EvalResultMessage`. ## Storage and Artifacts -- **S3 Storage** – Image observations and other heavy artifacts upload through the logger’s executor so the backend can serve signed URLs later. -- **Validator database** – Results and telemetry are queued to the validator Postgres instance so that validator and backend connectivity issues do not drop data. +- **S3 Storage** - Image observations and other heavy artifacts upload through the logger's executor so the backend can serve signed URLs later. +- **Direct streaming** - Results and telemetry stream directly to the backend over WebSocket, eliminating intermediate queues. ## Configuration Highlights -- `episode_log_interval` and `step_log_interval` control how frequently detailed telemetry is enqueued (`config/evaluator.toml.example`). -- `max_concurrent_jobs` guards resource usage when multiple evaluations are queued concurrently. -- `s3_config` includes bucket credentials used for artifact uploads. +Key settings in `evaluator.toml`: + +- `backend_url` - WebSocket URL for the backend (e.g., `wss://api.kinitro.ai/ws/evaluator`). +- `api_key` - Authentication key for the evaluator. +- `episode_log_interval` and `step_log_interval` - Control how frequently detailed telemetry is streamed. +- `max_concurrent_jobs` - Guards resource usage when multiple evaluations run concurrently. +- `s3_config` - Bucket credentials used for artifact uploads. +- `ray_num_cpus`, `ray_num_gpus`, `ray_memory_gb` - Ray head resources. +- `worker_num_cpus`, `worker_num_gpus`, `worker_memory_gb` - Per-worker resource requests. See the [orchestrator guide](orchestrator.md) for how jobs arrive and are scheduled. diff --git a/docs/architecture/incentive.md b/docs/architecture/incentive.md index bef0895..ceb1bf4 100644 --- a/docs/architecture/incentive.md +++ b/docs/architecture/incentive.md @@ -39,11 +39,12 @@ and stamp approval metadata"] Administrators can annotate approvals or rejections, and the backend records the reviewer, reason, and timestamp. Pending candidates remain in the queue until a decision is made, making it easy to compare multiple challengers side-by-side. -## Weight Broadcasting +## Weight Distribution -1. **WebSocket broadcast** – The backend pushes weight messages to every connected validator. -2. **Chain update** – Validators commit weights on the Bittensor chain. -3. **Cache warmup** – The backend keeps a copy of the latest weights in memory so reconnecting validators receive immediate updates even before the next scoring cycle. Competitions without an approved leader simply contribute zero weight until an approval occurs. +1. **HTTP endpoint** - The backend serves weight snapshots via `GET /weights`, which returns a `WeightsSnapshot` containing UID-to-weight mappings. +2. **Validator polling** - Validators periodically poll this endpoint (default: every 5 minutes) to fetch the latest weights. +3. **Chain update** - When weights change, validators call `set_weights` on the Bittensor chain. +4. **Cache warmup** - The backend keeps a copy of the latest weights in memory for fast retrieval. Competitions without an approved leader simply contribute zero weight until an approval occurs. ## Configuration @@ -56,4 +57,4 @@ Administrators can annotate approvals or rejections, and the backend records the - Scores depend on trusted evaluation results. Validators should ensure evaluators are running the same container images and configuration to avoid inconsistent outcomes. - When no miner satisfies a competition’s success criteria, the backend skips weight updates for that competition and previously broadcast weights decay to zero. -- Validators without an active API key will fail to receive weight updates. +- Validators poll the public `/weights` endpoint, which does not require authentication. diff --git a/docs/architecture/introduction.md b/docs/architecture/introduction.md index ee9ae33..5762bd5 100644 --- a/docs/architecture/introduction.md +++ b/docs/architecture/introduction.md @@ -10,9 +10,9 @@ Kinitro incentivizes the emergence of agents that can conquer various tasks acro 1. **Submission** – The miner CLI packages the agent, requests a presigned upload slot from the backend, pushes the archive directly to the private vault, and commits the submission ID on Bittensor. 2. **Ingestion** – The backend ties the chain commitment to the uploaded artifact, records the submission with its hold-out window, and schedules evaluation jobs. -3. **Distribution** – Validators connect via WebSocket, receive jobs (including signed artifact URLs) and queue them for execution. +3. **Distribution** – Evaluators connect directly to the backend via WebSocket, receive jobs (including signed artifact URLs) and execute them. 4. **Evaluation** – The orchestrator launches a Kubernetes pod per submission, Ray rollout workers evaluate the agent via RPC, and telemetry is logged. -5. **Results & Incentives** – Validators stream results back to the backend, which stores metrics, queues leader candidates for admin approval, emits realtime updates, and recalculates scores/weights once an approved leader exists. After hold-out expiry, the backend issues time-limited release URLs for public access. +5. **Results & Incentives** – Evaluators stream results back to the backend, which stores metrics, queues leader candidates for admin approval, emits realtime updates, and recalculates scores/weights once an approved leader exists. Validators poll the `/weights` endpoint and set weights on the Bittensor chain. After hold-out expiry, the backend issues time-limited release URLs for public access. ## System Architecture @@ -48,17 +48,18 @@ flowchart TD MIN -- "upload artifact" --> S3 MIN -- "commit submission id" --> BT BEC -- "monitor commitments" --> BT - BEC -- "link upload & create jobs" --> VALN - %% Evaluation loop - VALN -- "dispatch eval" --> EVN + %% Evaluation loop (direct backend-evaluator connection) + BEC -- "WebSocket: dispatch jobs" --> EVN EVN -- "download via presigned GET" --> S3 - EVN -- "results" --> VALN - VALN -- "report results" --> BEC + EVN -- "WebSocket: stream results" --> BEC + + %% Weight setting (validator polls backend) + VALN -- "HTTP: poll /weights" --> BEC + VALN -- "set weights" --> BT %% Outputs BEC -- "broadcast updates" --> CLN - BEC -- "weight updates" --> BT BEC -- "release presigned URL on hold-out expiry" --> CLN %% Styles @@ -78,8 +79,9 @@ sequenceDiagram participant BackendAPI participant SubmissionStorage as S3 Vault participant Chain - participant ValidatorOrchestrator as Validator Orchestrator + participant EvaluatorCluster as Evaluator participant K8sPod as Evaluation Pod + participant ValidatorNode as Validator participant AdminConsole as Admin Console participant ReleaseTask as Hold-out Release Task @@ -89,17 +91,23 @@ sequenceDiagram MinerCLI->>Chain: Commit (provider=S3, submission_id, comp_id) BackendAPI->>BackendAPI: Match commitment with upload
create MinerSubmission + jobs - BackendAPI->>ValidatorOrchestrator: Broadcast EvalJobMessage (artifact URL, hash, holdout info) + BackendAPI->>EvaluatorCluster: WebSocket: Broadcast EvalJobMessage (artifact URL, hash, holdout info) - ValidatorOrchestrator->>K8sPod: Launch pod (init + runner) + EvaluatorCluster->>K8sPod: Launch pod (init + runner) K8sPod->>SubmissionStorage: GET submission.tar.gz (presigned) - K8sPod->>ValidatorOrchestrator: RPC evaluation results - ValidatorOrchestrator->>BackendAPI: EvalResultMessage & status updates + K8sPod->>EvaluatorCluster: RPC evaluation results + EvaluatorCluster->>BackendAPI: WebSocket: EvalResultMessage & status updates BackendAPI->>BackendAPI: Store metrics, update job status BackendAPI->>AdminConsole: Surface pending leader candidates AdminConsole->>BackendAPI: Approve / reject candidate (optional note) BackendAPI->>BackendAPI: Apply approved leader + cache scores + loop periodic (every 5 min) + ValidatorNode->>BackendAPI: GET /weights + BackendAPI-->>ValidatorNode: WeightsSnapshot + ValidatorNode->>Chain: set_weights (if changed) + end + loop periodic ReleaseTask->>BackendAPI: Scan for expired hold-outs BackendAPI->>SubmissionStorage: Presign release GET URL @@ -112,25 +120,26 @@ sequenceDiagram ### Backend Service -- **FastAPI REST / Admin**: Hosts competition CRUD, submission uploads, stats, validator management, and WebSocket endpoints. +- **FastAPI REST / Admin**: Hosts competition CRUD, submission uploads, stats, and admin endpoints. +- **Evaluator Hub**: Manages WebSocket connections from evaluators via `/ws/evaluator`, dispatches jobs, and receives results. - **Chain Monitor & Scheduler**: Tracks Bittensor commitments, ties them to uploaded artifacts, creates `BackendEvaluationJob` records, and watches for stale work. - **Hold-out & Vault Manager**: Issues presigned URLs for uploads and releases, enforces per-competition hold-out windows, and keeps artifacts private until expiry. - **Realtime Broadcaster**: Manages client subscriptions and pushes structured events such as job updates, episode completions, and live stats. -- **Scoring & Weight Engine**: Periodically recalculates miner scores and pushes weight updates back to validators for on-chain emission. -- **Backend PostgreSQL**: Source of truth for competitions, submissions, jobs, job status, results, stats, and validator connections. +- **Scoring & Weight Engine**: Periodically recalculates miner scores and serves weight snapshots via the `/weights` endpoint for validators to poll. +- **Backend PostgreSQL**: Source of truth for competitions, submissions, jobs, job status, results, stats, and evaluator connections. ### Validator Node -- **WebSocket Client**: Authenticates with the backend, receives `EvalJobMessage` payloads, and streams results back. -- **pgqueuer Runner**: Persists jobs/results/episode logs in PostgreSQL so work survives restarts and can be retried. -- **Validator PostgreSQL**: Stores pgq queues plus normalized tables for jobs, results, and metrics consumed by the evaluator. +- **HTTP Poller**: Periodically fetches `GET /weights` from the backend to retrieve the latest weight snapshot. +- **Weight Setter**: Compares fetched weights against the last committed values and calls `set_weights` on the Bittensor chain when changes occur. +- **Lightweight Design**: No database, no evaluator, no WebSocket - just polls and sets weights. ### Evaluator Cluster -- **Evaluator Orchestrator**: Listens to the pgqueuer queue, enforces concurrency caps, and coordinates job lifecycles. +- **Evaluator Orchestrator**: Connects directly to the backend via WebSocket, receives jobs, enforces concurrency caps, and coordinates job lifecycles. - **Submission Pods**: Kubernetes pods created per submission to run miner containers in isolation. - **Ray Rollout Workers**: Execute benchmark episodes, communicate with submission pods via RPC, and track success metrics. -- **Episode Logger**: Captures per-episode and per-step data, uploads media to S3-compatible storage, and enqueues telemetry for validator forwarding. +- **Episode Logger**: Captures per-episode and per-step data, uploads media to S3-compatible storage, and streams telemetry back to the backend. ### Miner Tooling @@ -138,10 +147,10 @@ sequenceDiagram ### Real-time Clients -- Subscribe to the backend’s public WebSocket endpoint to monitor competitions, validator connectivity, and evaluation progress live. +- Subscribe to the backend's public WebSocket endpoint to monitor competitions, evaluator connectivity, and evaluation progress live. ## Next Steps -- Dive into the [Validator architecture notes](orchestrator.md) to see how the queue, database, and message formats interact. +- Dive into the [Evaluator architecture notes](orchestrator.md) to see how jobs flow from backend to evaluators. - Review the [Evaluator internals](evaluator.md) for details on Ray workers, RPC bridges, and logging pipelines. - Check the [Incentive mechanism](incentive.md) to understand how scores flow into weight updates. diff --git a/docs/architecture/orchestrator.md b/docs/architecture/orchestrator.md index 779630e..769f5ba 100644 --- a/docs/architecture/orchestrator.md +++ b/docs/architecture/orchestrator.md @@ -1,31 +1,75 @@ # Orchestrator -The orchestrator is the control plane that turns queued evaluation jobs into running rollouts. It listens to pgqueuer events, provisions isolated containers which run agents, wires up Ray workers, and makes sure results flow back to the backend through the validator. +The orchestrator is the control plane that turns evaluation jobs into running rollouts. It connects directly to the backend via WebSocket, receives job assignments, provisions isolated containers which run agents, wires up Ray workers, and streams results back to the backend. ## Responsibilities -- **Queue consumption** – `PgQueuer` watches the validator database and invokes the orchestrator whenever a new `add_job` event appears. -- **Concurrency control** – Track active jobs and defer new work until there is capacity. -- **Environment provisioning** – Spin up Kubernetes pods that host the miner submission and expose an RPC endpoint for workers. -- **Worker orchestration** – Create Ray rollout workers, attach benchmark specs, and stream observations and other metrics back. -- **Result collation** – Collect evaluation summaries, persist them via `EvalResultMessage`, and queue them for validator delivery. -- **Cleanup & recovery** – Tear down pods, close queues, and reclaim Ray resources even on failure. +- **Backend connection** – Maintains a persistent WebSocket connection to the backend's `/ws/evaluator` endpoint. Authenticates with an API key and receives job broadcasts. +- **Concurrency control** – Tracks active jobs and defers new work until there is capacity. +- **Environment provisioning** – Spins up Kubernetes pods that host the miner submission and expose an RPC endpoint for workers. +- **Worker orchestration** – Creates Ray rollout workers, attaches benchmark specs, and streams observations and other metrics back. +- **Result streaming** – Sends `EvalResultMessage` and `JobStatusUpdateMessage` payloads directly to the backend as work completes. +- **Cleanup & recovery** – Tears down pods, closes connections, and reclaims Ray resources even on failure. ## Job Lifecycle -1. The validator enqueues a job from the backend. -2. `PgQueuer` triggers the orchestrator’s `process` handler with the job payload. -3. The job is recorded in the validator database with `EvaluationStatus.STARTING`, guaranteeing visibility and retries. -4. A submission container is created via the `Containers` helper, exposing a service the worker can reach over TCP. -5. A `RolloutCluster` worker runs the benchmark episodes, talking to the submission container through an RPC bridge. -6. Episode-level and step-level telemetry is queued by the `EpisodeLogger`, which uses the same pgqueuer channel so the validator can forward data to the backend. -7. When the rollout completes, scores and aggregates queued for delivery. -8. Cleanup routines ensure pods are removed, Ray actors stopped, and lingering queues closed. +1. The backend broadcasts an `EvalJobMessage` to all connected evaluators via WebSocket. +2. The orchestrator receives the job and records it locally with `EvaluationStatus.STARTING`. +3. A submission container is created via the `Containers` helper, exposing a service the worker can reach over TCP. +4. A `RolloutCluster` worker runs the benchmark episodes, talking to the submission container through an RPC bridge. +5. Episode-level and step-level telemetry is captured by the `EpisodeLogger` and streamed back to the backend. +6. When the rollout completes, the orchestrator sends an `EvalResultMessage` with scores and aggregates. +7. Cleanup routines ensure pods are removed, Ray actors stopped, and resources reclaimed. + +## Communication Flow + +```mermaid +sequenceDiagram + participant Backend + participant Orchestrator + participant K8sPod as Submission Pod + participant RayWorker as Ray Worker + + Orchestrator->>Backend: WebSocket: Register (API key, capabilities) + Backend-->>Orchestrator: RegistrationAck + + Backend->>Orchestrator: EvalJobMessage (artifact URL, benchmark spec) + Orchestrator->>Backend: JobStatusUpdate (STARTING) + + Orchestrator->>K8sPod: Create pod + K8sPod-->>Orchestrator: Pod ready + + Orchestrator->>RayWorker: Start rollout + RayWorker->>K8sPod: RPC: act(observation) + K8sPod-->>RayWorker: action + + loop episodes + RayWorker->>Orchestrator: Episode metrics + Orchestrator->>Backend: Telemetry updates + end + + RayWorker-->>Orchestrator: Rollout complete + Orchestrator->>Backend: EvalResultMessage (scores) + Orchestrator->>Backend: JobStatusUpdate (COMPLETED) + + Orchestrator->>K8sPod: Delete pod +``` ## Resilience Features -- **Durable queues** – All commands and results flow through pgq tables, so reconnecting services pick up exactly where they left off. +- **Automatic reconnection** – If the WebSocket connection drops, the orchestrator reconnects with exponential backoff. +- **Job recovery** – Jobs in progress when a disconnect occurs can be resumed or re-queued by the backend. - **Timeout handling** – The orchestrator checks elapsed time per job and can mark stale work for cleanup. - **Health monitoring** – Background tasks watch running jobs for completion or timeout signals and remove them when necessary. +## Configuration Highlights + +Key settings in `evaluator.toml`: + +- `backend_url` – WebSocket URL for the backend (e.g., `wss://api.kinitro.ai/ws/evaluator`). +- `api_key` – Authentication key for the evaluator. +- `max_concurrent_jobs` – Maximum number of parallel evaluations. +- `ray_num_cpus`, `ray_num_gpus`, `ray_memory_gb` – Ray head resources. +- `worker_num_cpus`, `worker_num_gpus`, `worker_memory_gb` – Per-worker resource requests. + Refer to the [Evaluator internals](evaluator.md) for details on the worker side of this pipeline. diff --git a/docs/overview.mdx b/docs/overview.mdx index 4808e26..b8c6bec 100644 --- a/docs/overview.mdx +++ b/docs/overview.mdx @@ -31,8 +31,8 @@ After setting up the repository locally, the next steps will depend on what role ## Platform Roles - **Miners** upload agent artifacts to our backend and submit a signed commitment to the Bittensor chain. Use the [miner guide](./miner.md) to configure packaging, uploading, and on-chain commits. -- **Validators** maintain a websocket connection to the backend, persist jobs in pgqueuer, and relay results and episode logs. Follow the [validator guide](./validator/validator.md) for environment setup and operations. -- **Evaluator operators** run the Ray orchestrator that spins up submission containers, executes benchmarks, and logs telemetry. See the [evaluator architecture notes](./architecture/evaluator.md) for details. +- **Validators** poll the backend's `/weights` endpoint and set weights on the Bittensor chain. Follow the [validator guide](./validator/validator.md) for environment setup and operations. +- **Evaluator operators** run the Ray orchestrator that connects directly to the backend, spins up submission containers, executes benchmarks, and streams telemetry. See the [evaluator architecture notes](./architecture/evaluator.md) for details. - **Backend operators** deploy the FastAPI service that monitors the chain, schedules jobs, emits realtime updates, and approve pending leader candidates before new rewards go live. ## Learn the Architecture diff --git a/docs/validator/docker.md b/docs/validator/docker.md index c079dc6..fa985cc 100644 --- a/docs/validator/docker.md +++ b/docs/validator/docker.md @@ -6,12 +6,12 @@ section: 'Start Validating' ## Images -- `validator.dockerfile` – WebSocket validator service (`ghcr.io/threetau/kinitro-validator`). +- `validator.dockerfile` – Lightweight weight-setting validator (`ghcr.io/threetau/kinitro-validator`). - `evaluator.dockerfile` – CPU evaluator orchestrator (`ghcr.io/threetau/kinitro-evaluator`). - `evaluator-cuda.dockerfile` – CUDA-enabled evaluator (`ghcr.io/threetau/kinitro-evaluator:*-gpu`). - `miner-agent.dockerfile` – Submission runtime image used by evaluator-created pods. - `miner-agent-cuda.dockerfile` – GPU variant of the submission runtime. -- `migrator.dockerfile` – Alembic/pgq migration job image (locally tagged as `kinitro-migrator`). +- `migrator.dockerfile` – Alembic migration job image for the backend database (locally tagged as `kinitro-migrator`). Shared entrypoints live beside the Dockerfiles: `entrypoint-validator.sh`, `entrypoint-evaluator.sh`. diff --git a/docs/validator/validator.md b/docs/validator/validator.md index fe8d37f..8e6386d 100644 --- a/docs/validator/validator.md +++ b/docs/validator/validator.md @@ -4,17 +4,19 @@ section: 'Start Validating' # Validator -Validators are responsible for evaluating the performance of miner-submitted agents on a variety of tasks. +Validators are responsible for setting weights on the Bittensor chain based on miner performance. The validator periodically polls the backend's `/weights` endpoint and commits any changes on-chain. -**Choose your validator type first:** +## Architecture Overview -1. [Full Validator (full pipeline)](#full-validator-full-pipeline) – runs evaluations, streams logs, and requires the evaluator + Postgres stack. -2. [Lite Validator (HTTP weight setter)](#lite-validator-http-weight-setter) – polls the public weight endpoint and only sets weights on-chain. +The validator is intentionally lightweight: -After picking the implementation, choose how you want to deploy it: +``` +Backend ─────GET /weights────► Validator ─────set_weights────► Bittensor Chain +``` -1. [Bare Metal](#setup---bare-metal) -2. [Containerized deployment](#setup---containerized-deployment) +- **No WebSocket connection** - The validator uses simple HTTP polling +- **No database required** - Weights are fetched fresh each cycle +- **No evaluator needed** - Evaluators connect directly to the backend (see [Evaluator docs](../architecture/evaluator.md)) ## Setup - Bare Metal @@ -26,9 +28,7 @@ Copy the `.env.validator.example` file to `.env` and fill in the required enviro cp .env.validator.example .env ``` -You will need to create an R2 bucket and set the relevant environment variables. This is required for storing some evaluation data. For more information please refer to Cloudflare's [R2 documentation](https://developers.cloudflare.com/r2/buckets/). - -If you are running a Full Validator (*not* a lite validator), You will need to set `KINITRO_API_KEY` to obtain access to the Kinitro backend. Please contact us on our [discord channel](https://discord.gg/96SdmpeMqG) for access. +The only required environment variables are for your Bittensor wallet configuration. ### Configuration @@ -38,162 +38,149 @@ To configure a validator, start by copying the example configuration file: cp config/validator.toml.example validator.toml ``` -The example config now includes both the Full validator and the lite HTTP-based weight setter. Core knobs look like: +Key configuration options: ```toml -validator_mode = "full" # switch to "lite" to run the HTTP weight setter +# Backend endpoint to poll for weights weights_url = "https://api.kinitro.ai/weights" -weights_poll_interval = 30.0 -weights_request_timeout = 10.0 -weights_stale_threshold = 180.0 -``` -Use the default `weights_url` unless you operate your own backend; the lite mode polls this endpoint and pushes updates on-chain. +# How often to poll for weight updates (seconds) +weights_poll_interval = 300.0 # 5 minutes -#### Full validator (full pipeline) - -Set `validator_mode = "full"` to run the full evaluator pipeline. This mode: - -- maintains a WebSocket connection to the backend for job distribution and telemetry, -- requires the PostgreSQL queue (`pg_database`) and evaluator service, -- forwards evaluation results back to the backend. +# HTTP request timeout (seconds) +weights_request_timeout = 10.0 -You will also need `evaluator.toml` for the orchestrator that executes jobs: +# How old weights can be before considered stale (seconds) +weights_stale_threshold = 900.0 # 15 minutes -```bash -cp config/evaluator.toml.example evaluator.toml +# Bittensor configuration +netuid = 123 +network = "finney" # or "test" for testnet +wallet_name = "default" +hotkey_name = "default" ``` -Edit `evaluator.toml` to set your desired parameters, such as the PostgreSQL database connection string, R2 credentials, logging intervals, and the `log_file` path (defaults to `logs/evaluator.log`) where the orchestrator persists its stdout/stderr stream. The backend FastAPI service reads the same `log_file` key from `backend.toml`, so backend requests and background worker logs land in `logs/backend.log` by default. +Use the default `weights_url` unless you operate your own backend. -Key resource knobs in `evaluator.toml`: +### Running the validator -- `log_file` – path on disk where evaluator logs are mirrored in addition to stdout. -- `ray_num_cpus`, `ray_num_gpus`, `ray_memory_gb`, `ray_object_store_memory_gb` – tune the Ray head resources the orchestrator reserves when it boots. -- `worker_num_cpus`, `worker_num_gpus`, `worker_memory_gb`, `worker_max_restarts`, `worker_max_task_retries` – control how much CPU/GPU/memory each rollout worker actor requests from Ray. +Launch the validator with: -#### Lite validator (HTTP weight setter) +```bash +python -m validator --config validator.toml +``` -Set `validator_mode = "lite"` when you only need to mirror backend weight decisions on-chain. This mode: +The validator will: +1. Poll `GET /weights` at the configured interval +2. Compare fetched weights against the last committed values +3. Call `set_weights` on the Bittensor chain when changes occur +4. Log all weight-setting activity -- polls `weights_url` over HTTPS for the latest snapshot, -- uses your Bittensor wallet/hotkey to submit weights, -- does **not** require the evaluator service or Postgres queue. +### Wallet Setup -You still need valid wallet credentials and chain connectivity, but no backend API key is required because the `/weights` endpoint is public. +Ensure your Bittensor wallet is properly configured: -### Setting up database +```bash +# Check wallet exists +btcli wallet list -The Full validator requires a PostgreSQL database for queuing evaluation jobs and results. The lite validator can skip this section. +# Register on subnet if needed +btcli subnet register --netuid --wallet.name --wallet.hotkey +``` -To set up the database, you can either: +The validator needs sufficient stake to set weights on the subnet. -1. **Reset the database** (drops and recreates the database with all migrations): +## Setup - Containerized Deployment - ```bash - chmod +x ./scripts/reset_validator_db.sh - ./scripts/reset_validator_db.sh - ``` +We ship Docker recipes for the validator in `deploy/docker/`. The validator container is lightweight and only requires network access to the backend and Bittensor chain. -2. **Run migrations only** (on an existing database): +### 1. Prerequisites - ```bash - chmod +x ./scripts/migrate_validator_db.sh - ./scripts/migrate_validator_db.sh - ``` +- **Docker Compose v2** (bundled with modern Docker releases) +- **Bittensor wallets** - Point `BITTENSOR_HOME` at your wallet directory (defaults to `$HOME/.bittensor`) -The migration script will check if the database exists and run Alembic migrations to bring it up to date. It will also ensure the pgq extension is installed if needed. +```bash +export BITTENSOR_HOME="$HOME/.bittensor" +``` -### Running the validator +### 2. Prepare configuration files -Regardless of mode, launch the process with: +Copy your configuration into the Compose folder: ```bash -python -m validator --config validator.toml +mkdir -p deploy/docker/validator-config +cp validator.toml deploy/docker/validator-config/ ``` -- With `validator_mode = "full"` the service opens the backend WebSocket and requires the evaluator plus database to be running. -- With `validator_mode = "lite"` the service polls `/weights` and immediately applies updates on-chain. No evaluator or database is needed. - -### Running the Evaluator - -Only required for the Full validator. Start it once your validator is up: +### 3. Run the validator ```bash -python -m evaluator.orchestrator --config evaluator.toml +docker compose -f deploy/docker/compose.yaml up -d validator ``` -## Setup - Containerized deployment +The validator container mounts your wallet directory read-only and uses the configuration from `validator-config/`. -We ship Docker recipes for the validator stack in `deploy/docker/`. The workflow below covers both CPU-only and GPU-enabled setups. The provided Compose profiles target the full validator; the lite validator can be run as a lightweight bare-metal process alongside the stack if desired. - -### 1. Prerequisites +## Running an Evaluator (Optional) -- **Docker Compose v2** (bundled with modern Docker releases). -- **Environment variables** – export `KINITRO_API_KEY` in your shell or place it in `deploy/docker/validator-config/.env` before you launch the stack. -- **Bittensor wallets** – point `BITTENSOR_HOME` at your wallet directory (defaults to `$HOME/.bittensor`). For example: +If you want to contribute to the evaluation network, you can run an evaluator separately. Evaluators connect directly to the backend via WebSocket and do not require the validator. - ```bash - export KINITRO_API_KEY=xxxxxxxx - export BITTENSOR_HOME="$HOME/.bittensor" - ``` +See the [Evaluator documentation](../architecture/evaluator.md) for setup instructions. -- **GPU hosts only** – install the NVIDIA Container Toolkit (see the [official guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)). -- **Minikube (optional)** – required when you plan to run GPU evaluations; install it and start with GPU support as described in the [Minikube start documentation](https://minikube.sigs.k8s.io/docs/start/?arch=%2Fmacos%2Farm64%2Fstable%2Fbinary+download). +### Evaluator Prerequisites -### 2. Prepare configuration files +Running an evaluator requires: +- **Kubernetes cluster** (Minikube for local development, or a managed K8s service) +- **Ray cluster** for distributed rollouts +- **API key** from the Kinitro team (contact us on [Discord](https://discord.gg/96SdmpeMqG)) -Copy your bare-metal configs into the Compose folder so the containers mount them read-only: +### Evaluator Configuration ```bash -mkdir -p deploy/docker/validator-config deploy/docker/evaluator-config -cp validator.toml deploy/docker/validator-config/ -cp evaluator.toml deploy/docker/evaluator-config/ -cp .env deploy/docker/.env # optional, keeps secrets out of the compose file +cp config/evaluator.toml.example evaluator.toml ``` -Keep `evaluator.toml` in sync with the resource hints you need (CPU/GPU counts, worker memory, etc.). The container reads these files from `/etc/kinitro` at runtime. +Key settings in `evaluator.toml`: -### 3. Run the CPU evaluator stack +- `backend_url` - WebSocket URL for the backend +- `api_key` - Your evaluator API key +- `max_concurrent_jobs` - Number of parallel evaluations +- Ray and worker resource settings -The CPU evaluator lives in the `cpu` profile, so it will only start if you ask for it. Bring up the base services (Postgres, validator, watchtower) and the CPU evaluator: +### Running the Evaluator ```bash -docker compose -f deploy/docker/compose.yaml up -d postgres validator watchtower -docker compose -f deploy/docker/compose.yaml --profile cpu up -d evaluator +python -m evaluator --config evaluator.toml ``` -Use `scripts/update_validator.sh` to pull new images, apply migrations with the `migrator` profile, and restart the services automatically. Set `USE_GPU_EVALUATOR=1` when you want the helper script to restart the GPU profile instead of the CPU evaluator: +## Image Matrix -```bash -./scripts/update_validator.sh # CPU stack -USE_GPU_EVALUATOR=1 ./scripts/update_validator.sh # GPU stack -``` - -### 4. Run the GPU evaluator (Minikube + CUDA) +| Image | Purpose | +| --- | --- | +| `ghcr.io/threetau/kinitro-validator` | Weight-setting validator service | +| `ghcr.io/threetau/kinitro-evaluator` | Evaluation orchestrator (CPU / `-gpu`) | +| `ghcr.io/threetau/kinitro-miner-agent` | Submission runtime for evaluation pods (CPU / `-gpu`) | -1. Start Minikube with GPU support so the evaluator can create submission pods that request GPUs: +## Troubleshooting - ```bash - minikube start --driver=docker --gpu - ``` +### Validator not setting weights - This also creates the external Docker network named `minikube`, which the evaluator containers join for API access. -2. Launch the GPU evaluator profile (CPU evaluator stays off unless you start the `cpu` profile): +1. Check wallet registration: `btcli subnet list --netuid ` +2. Verify sufficient stake for weight setting +3. Check network connectivity to `weights_url` +4. Review logs for HTTP errors - ```bash - docker compose -f deploy/docker/compose.yaml --profile gpu --compatibility up -d evaluator-gpu - ``` +### Stale weights warning -If you prefer to keep both profiles running, bring up each profile explicitly (`--profile cpu up -d evaluator` and `--profile gpu up -d evaluator-gpu`). +If weights are older than `weights_stale_threshold`, the validator logs a warning. This typically means: +- The backend is not receiving evaluation results +- No approved leaders exist for competitions +- Network issues between validator and backend -### Image matrix +### Connection errors -| Image | Purpose | Variant | -| --- | --- | --- | -| `ghcr.io/threetau/kinitro-validator` | Full validator service | CPU | -| `ghcr.io/threetau/kinitro-evaluator` | Orchestrator & Ray rollout workers | CPU / `-gpu` | -| `ghcr.io/threetau/kinitro-miner-agent` | Submission runtime for evaluator-launched pods (Minikube) | CPU / `-gpu` | -| `kinitro-migrator` (local build) | Alembic + pgq migrations | CPU | +```bash +# Test connectivity to backend +curl -s https://api.kinitro.ai/weights | jq . +``` -For local development or private registries, use the Docker Compose `build` targets to push images to your infrastructure. Update `deploy/docker/compose.yaml` to point at your registry/tag naming scheme. +If the endpoint returns a valid JSON response with weights, the issue is likely with your local configuration. diff --git a/pyproject.toml b/pyproject.toml index 120595a..6d6cf38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "typer>=0.9.0", "uvicorn>=0.24.0", "websockets>=12.0", + "pytest-asyncio>=1.3.0", ] [tool.setuptools.dynamic] diff --git a/scripts/migrate_validator_db.sh b/scripts/migrate_evaluator_db.sh similarity index 62% rename from scripts/migrate_validator_db.sh rename to scripts/migrate_evaluator_db.sh index 0458a60..1fbfe53 100755 --- a/scripts/migrate_validator_db.sh +++ b/scripts/migrate_evaluator_db.sh @@ -6,7 +6,7 @@ DB_USER="${DB_USER:-myuser}" DB_PASSWORD="${DB_PASSWORD:-}" # OK if blank when using peer/.pgpass auth DB_HOST="${DB_HOST:-localhost}" DB_PORT="${DB_PORT:-5432}" -DB_NAME="${DB_NAME:-validatordb}" +DB_NAME="${DB_NAME:-evaluatordb}" PGHOST="${DB_HOST}" PGPORT="${DB_PORT}" @@ -23,12 +23,12 @@ cd "$REPO_ROOT" # --- Check if database exists --- echo "Checking if database '${DB_NAME}' exists..." if ! psql -h "${DB_HOST}" -p "${DB_PORT}" -U "${DB_USER}" -lqt | cut -d \| -f 1 | grep -qw "${DB_NAME}"; then - echo "❌ Database '${DB_NAME}' does not exist. Please create it first or run reset_validator_db.sh" + echo "❌ Database '${DB_NAME}' does not exist. Please create it first." exit 1 fi -# --- Run migrations in validator --- -cd src/validator +# --- Run migrations in evaluator --- +cd src/evaluator # Alembic usually reads from env; set it explicitly to be safe. export DATABASE_URL="postgresql://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME}" @@ -36,13 +36,4 @@ export DATABASE_URL="postgresql://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT echo "Running Alembic migrations with uv…" uv run alembic upgrade head -# --- Check and install pgq extension if needed --- -echo "Checking pgq extension…" -if ! psql -h "${DB_HOST}" -p "${DB_PORT}" -U "${DB_USER}" -d "${DB_NAME}" -tc "SELECT 1 FROM pg_extension WHERE extname='pgq'" | grep -q 1; then - echo "Installing pgq extension…" - uv run pgq install --dry-run | psql -h "${DB_HOST}" -p "${DB_PORT}" -U "${DB_USER}" -d "${DB_NAME}" -else - echo "pgq extension already installed." -fi - -echo "✅ Validator database migrations completed successfully." \ No newline at end of file +echo "✅ Evaluator database migrations completed successfully." diff --git a/src/backend/auth.py b/src/backend/auth.py index f2635ce..fc98719 100644 --- a/src/backend/auth.py +++ b/src/backend/auth.py @@ -18,6 +18,7 @@ class UserRole(str, Enum): ADMIN = "admin" VALIDATOR = "validator" + EVALUATOR = "evaluator" # For direct evaluator connections VIEWER = "viewer" diff --git a/src/backend/chain_monitor.py b/src/backend/chain_monitor.py new file mode 100644 index 0000000..ed1f2ae --- /dev/null +++ b/src/backend/chain_monitor.py @@ -0,0 +1,273 @@ +""" +Chain monitor for Kinitro backend. + +Monitors Bittensor chain for miner commitments and processes them. +Extracted from BackendService for better separation of concerns. +""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timezone +from typing import Any, Callable, Coroutine, Dict, List, Optional + +from fiber.chain.fetch_nodes import _get_nodes_for_uid +from fiber.chain.models import Node +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from core.chain import query_commitments_from_substrate +from core.log import get_logger +from core.schemas import ChainCommitmentResponse + +from .config import BackendConfig +from .models import BackendState, Competition, SS58Address + +logger = get_logger(__name__) + + +class ChainConfig: + """Configuration for the chain monitor.""" + + def __init__( + self, + max_commitment_lookback: int = 360, + chain_sync_interval: float = 30.0, + chain_scan_yield_interval: int = 2, + ): + self.max_commitment_lookback = max_commitment_lookback + self.chain_sync_interval = chain_sync_interval + self.chain_scan_yield_interval = chain_scan_yield_interval + + +CommitmentCallback = Callable[ + [ChainCommitmentResponse, int, Dict[str, Competition]], + Coroutine[Any, Any, None], +] + + +class ChainMonitor: + """ + Monitors Bittensor chain for commitments. + + This class is responsible for: + - Scanning blocks for miner commitments + - Syncing metagraph nodes + - Calling registered callbacks when commitments are found + """ + + def __init__( + self, + substrate: Any, + backend_config: BackendConfig, + session_factory: async_sessionmaker[AsyncSession], + config: ChainConfig, + thread_pool: ThreadPoolExecutor, + on_commitment: Optional[CommitmentCallback] = None, + ): + self.substrate = substrate + self.backend_config = backend_config + self.session_factory = session_factory + self.config = config + self.thread_pool = thread_pool + self.on_commitment = on_commitment + + # Chain state + self.nodes: Optional[Dict[SS58Address, Node]] = None + + # Task handle + self._task: Optional[asyncio.Task] = None + self._running = False + + async def start(self) -> None: + """Start the chain monitor background task.""" + if self._running: + logger.warning("ChainMonitor already running") + return + + self._running = True + self._task = asyncio.create_task(self._monitor_loop()) + logger.info("ChainMonitor started") + + async def stop(self) -> None: + """Stop the chain monitor background task.""" + self._running = False + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + logger.info("ChainMonitor task cancelled") + self._task = None + logger.info("ChainMonitor stopped") + + async def _monitor_loop(self) -> None: + """Background task to monitor blockchain for commitments.""" + while self._running: + try: + await self._scan_once() + await asyncio.sleep(self.config.chain_sync_interval) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error monitoring chain: {e}") + await asyncio.sleep(self.config.chain_sync_interval) + + async def _scan_once(self) -> None: + """Perform a single scan of the blockchain.""" + if not self.substrate or not self.session_factory: + return + + # Sync metagraph first + await self._sync_metagraph() + + if not self.nodes: + return + + async with self.session_factory() as session: + # Get backend state + state_result = await session.execute( + select(BackendState).where(BackendState.id == 1) + ) + state = state_result.scalar_one_or_none() + if not state: + logger.warning("Backend state not found") + return + + # Get latest block + latest_block = await self._get_latest_block() + if latest_block < 0: + return + + start_block = max( + state.last_seen_block + 1, + latest_block - self.config.max_commitment_lookback + 1, + ) + + logger.info(f"Checking blocks {start_block} to {latest_block}") + + # Get active competitions + comp_result = await session.execute( + select(Competition).where(Competition.active) + ) + active_competitions = {c.id: c for c in comp_result.scalars()} + logger.debug( + f"Preview of active competitions: {list(active_competitions.keys())[:5]}" + ) + + # Scan blocks and process commitments + await self.scan_blocks(start_block, latest_block, active_competitions) + + # Update state + state.last_seen_block = latest_block + state.last_chain_scan = datetime.now(timezone.utc) + await session.commit() + + async def scan_blocks( + self, + start: int, + end: int, + active_competitions: Dict[str, Competition], + ) -> List[ChainCommitmentResponse]: + """ + Scan a range of blocks for commitments. + + Args: + start: Starting block number + end: Ending block number + active_competitions: Currently active competitions + + Returns: + List of commitments found + """ + all_commitments: List[ChainCommitmentResponse] = [] + + for i, block_num in enumerate(range(start, end + 1)): + commitments = await self._query_block_commitments(block_num) + all_commitments.extend(commitments) + + for commitment in commitments: + if self.on_commitment: + await self.on_commitment(commitment, block_num, active_competitions) + + # Yield control periodically to prevent blocking WebSocket connections + if i % self.config.chain_scan_yield_interval == 0: + await asyncio.sleep(0) + + return all_commitments + + async def _get_latest_block(self) -> int: + """Get latest block from chain. + + Returns: + int: The latest block number, or -1 if an error occurred. + """ + try: + if not self.substrate: + return -1 + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.thread_pool, self.substrate.get_block_number + ) + except Exception as e: + logger.error(f"Failed to get latest block: {e}") + return -1 + + def _sync_nodes_sync(self) -> None: + """Synchronous version of node syncing for thread pool.""" + node_list = _get_nodes_for_uid( + self.substrate, self.backend_config.settings["subtensor"]["netuid"] + ) + self.nodes = {node.hotkey: node for node in node_list} + + async def _sync_metagraph(self) -> None: + """Sync metagraph nodes with memory leak prevention.""" + try: + loop = asyncio.get_event_loop() + await loop.run_in_executor(self.thread_pool, self._sync_nodes_sync) + logger.debug("Nodes synced") + except Exception as e: + logger.error(f"Failed to sync metagraph: {e}") + + def _query_commitments_sync( + self, block_num: int, nodes: list + ) -> List[ChainCommitmentResponse]: + """Synchronous version of commitment querying for thread pool.""" + commitments = [] + + for node in nodes: + try: + miner_commitments = query_commitments_from_substrate( + self.backend_config, self.substrate, node.hotkey, block=block_num + ) + if miner_commitments: + commitments.extend(miner_commitments) + except Exception as e: + logger.debug(f"Failed to query {node.hotkey}: {e}") + continue + + return commitments + + async def _query_block_commitments( + self, block_num: int + ) -> List[ChainCommitmentResponse]: + """Query commitments for a block.""" + try: + if not self.nodes: + return [] + + node_list = list(self.nodes.values()) + + loop = asyncio.get_event_loop() + commitments = await loop.run_in_executor( + self.thread_pool, self._query_commitments_sync, block_num, node_list + ) + + return commitments + + except Exception as e: + logger.error(f"Failed to query block {block_num}: {e}") + return [] + + def get_nodes(self) -> Optional[Dict[SS58Address, Node]]: + """Get the current node mapping.""" + return self.nodes diff --git a/src/backend/endpoints.py b/src/backend/endpoints.py index f71e736..81bb240 100644 --- a/src/backend/endpoints.py +++ b/src/backend/endpoints.py @@ -93,6 +93,7 @@ EvaluationLogDownloadResponse, EvaluationResultLogResponse, EvaluationResultResponse, + EvaluatorInfoResponse, JobResponse, JobStatusResponse, LeaderCandidateReviewRequest, @@ -2398,35 +2399,41 @@ async def delete_api_key(key_id: int): # ============================================================================ -# WebSocket Endpoint for Validators +# WebSocket Endpoint for Evaluators (Direct Connection) # ============================================================================ -@app.websocket("/ws/validator") -async def validator_websocket(websocket: WebSocket): - """WebSocket endpoint for validator connections.""" +@app.websocket("/ws/evaluator") +async def evaluator_websocket(websocket: WebSocket): + """WebSocket endpoint for direct evaluator connections. + + This enables evaluators to connect directly to the backend without + going through a validator relay. Evaluators register with their + capabilities and receive jobs directly. + """ await websocket.accept() - # generate a connection id connection_id = uuid.uuid4().hex + evaluator_id: Optional[str] = None try: - # Wait for registration + # Wait for registration message data = await websocket.receive_text() message = json.loads(data) - if message.get("message_type") != MessageType.REGISTER: - await websocket.send_text(json.dumps({"error": "Must register first"})) + if message.get("message_type") != MessageType.EVALUATOR_REGISTER: + await websocket.send_text( + json.dumps({"error": "Must send EVALUATOR_REGISTER first"}) + ) await websocket.close() return - # Check for API key in registration message + # Validate API key api_key = message.get("api_key") if not api_key: await websocket.send_text(json.dumps({"error": "Missing API key"})) await websocket.close() return - # Validate API key api_key_obj = await get_api_key_from_db(api_key, backend_service) if not api_key_obj: await websocket.send_text( @@ -2435,94 +2442,83 @@ async def validator_websocket(websocket: WebSocket): await websocket.close() return - # Check if API key has validator role - if ( - api_key_obj.role != UserRole.VALIDATOR - and api_key_obj.role != UserRole.ADMIN - ): + # Check for evaluator or admin role + if api_key_obj.role not in (UserRole.EVALUATOR, UserRole.ADMIN): await websocket.send_text( - json.dumps( - {"error": "API key does not have access to validator endpoints"} - ) + json.dumps({"error": "API key does not have evaluator access"}) ) await websocket.close() return - validator_hotkey = message.get("hotkey") - if not validator_hotkey: - await websocket.send_text(json.dumps({"error": "Missing hotkey"})) + evaluator_id = message.get("evaluator_id") + if not evaluator_id: + await websocket.send_text(json.dumps({"error": "Missing evaluator_id"})) await websocket.close() return - # If API key has an associated hotkey, verify it matches - if ( - api_key_obj.associated_hotkey - and api_key_obj.associated_hotkey != validator_hotkey - ): - await websocket.send_text( - json.dumps({"error": "Hotkey does not match API key association"}) - ) - await websocket.close() - return + supported_task_types = message.get("supported_task_types", ["rl_rollout"]) + max_concurrent_jobs = message.get("max_concurrent_jobs", 1) + capabilities = message.get("capabilities") - # Register validator - if not backend_service.async_session: - await websocket.send_text(json.dumps({"error": "Database not initialized"})) - await websocket.close() - return - async with backend_service.async_session() as session: - result = await session.execute( - select(ValidatorConnection).where( - ValidatorConnection.validator_hotkey == validator_hotkey - ) - ) - validator_conn = result.scalar_one_or_none() - - if not validator_conn: - validator_conn = ValidatorConnection( - id=next(backend_service.id_generator), - validator_hotkey=validator_hotkey, - connection_id=connection_id, - api_key_id=api_key_obj.id, - is_connected=True, + # Register in evaluator hub + await backend_service.evaluator_hub.register( + connection_id=connection_id, + evaluator_id=evaluator_id, + websocket=websocket, + api_key_id=api_key_obj.id, + supported_task_types=supported_task_types, + max_concurrent_jobs=max_concurrent_jobs, + capabilities=capabilities, + ) + + # Persist to database + if backend_service.async_session: + async with backend_service.async_session() as session: + from .models import EvaluatorConnection + + result = await session.execute( + select(EvaluatorConnection).where( + EvaluatorConnection.evaluator_id == evaluator_id + ) ) - session.add(validator_conn) - else: - validator_conn.connection_id = connection_id - validator_conn.api_key_id = api_key_obj.id - validator_conn.last_connected_at = datetime.now(timezone.utc) - validator_conn.last_heartbeat = datetime.now(timezone.utc) - validator_conn.is_connected = True - - await session.commit() + evaluator_conn = result.scalar_one_or_none() + + if not evaluator_conn: + evaluator_conn = EvaluatorConnection( + id=next(backend_service.id_generator), + evaluator_id=evaluator_id, + api_key_id=api_key_obj.id, + supported_task_types=supported_task_types, + max_concurrent_jobs=max_concurrent_jobs, + is_connected=True, + capabilities=capabilities, + ) + session.add(evaluator_conn) + else: + evaluator_conn.api_key_id = api_key_obj.id + evaluator_conn.supported_task_types = supported_task_types + evaluator_conn.max_concurrent_jobs = max_concurrent_jobs + evaluator_conn.last_heartbeat = datetime.now(timezone.utc) + evaluator_conn.is_connected = True + evaluator_conn.capabilities = capabilities - # Store connection - backend_service.active_connections[connection_id] = websocket - backend_service.validator_connections[connection_id] = validator_hotkey + await session.commit() # Send acknowledgment await websocket.send_text( json.dumps( { - "message_type": MessageType.REGISTRATION_ACK, - "status": "registered", + "message_type": MessageType.EVALUATOR_REGISTRATION_ACK, + "success": True, "timestamp": datetime.now(timezone.utc).isoformat(), } ) ) - logger.info(f"Validator registered: {validator_hotkey} ({connection_id})") - - # Broadcast validator connected event - event = ValidatorConnectedEvent( - validator_hotkey=validator_hotkey, - connection_id=connection_id, - connected_at=datetime.now(timezone.utc), + logger.info( + f"Evaluator registered: {evaluator_id} ({connection_id}) " + f"with {max_concurrent_jobs} max concurrent jobs" ) - await event_broadcaster.broadcast_event(EventType.VALIDATOR_CONNECTED, event) - - # Broadcast updated stats (validator count changed) - await backend_service._broadcast_stats_update() # Handle messages while True: @@ -2531,9 +2527,23 @@ async def validator_websocket(websocket: WebSocket): message_type = message.get("message_type") if message_type == MessageType.HEARTBEAT: - await backend_service.queue_validator_heartbeat( - validator_hotkey, datetime.now(timezone.utc) - ) + # Update heartbeat in hub + backend_service.evaluator_hub.update_heartbeat(evaluator_id) + + # Update database + if backend_service.async_session: + async with backend_service.async_session() as session: + from .models import EvaluatorConnection + + result = await session.execute( + select(EvaluatorConnection).where( + EvaluatorConnection.evaluator_id == evaluator_id + ) + ) + evaluator_conn = result.scalar_one_or_none() + if evaluator_conn: + evaluator_conn.last_heartbeat = datetime.now(timezone.utc) + await session.commit() await websocket.send_text( json.dumps( @@ -2545,201 +2555,262 @@ async def validator_websocket(websocket: WebSocket): ) elif message_type == MessageType.EVAL_RESULT: - # Handle evaluation result + # Handle evaluation result from evaluator result_msg = EvalResultMessage(**message) - if not backend_service.async_session: - logger.error("Database not initialized") - continue - async with backend_service.async_session() as session: - # Find job - job_result = await session.execute( - select(BackendEvaluationJob).where( - BackendEvaluationJob.id == result_msg.job_id - ) - ) - backend_job = job_result.scalar_one_or_none() - - if backend_job: - # Create result - eval_result = BackendEvaluationResult( - id=next(backend_service.id_generator), - job_id=result_msg.job_id, - validator_hotkey=validator_hotkey, - miner_hotkey=result_msg.miner_hotkey, - competition_id=result_msg.competition_id, - env_provider=result_msg.env_provider, - benchmark=result_msg.benchmark_name, - score=result_msg.score, - success_rate=result_msg.success_rate, - avg_reward=result_msg.avg_reward, - total_episodes=result_msg.total_episodes, - logs=result_msg.logs, - error=result_msg.error, - extra_data=result_msg.extra_data, - env_specs=result_msg.env_specs, - ) + if backend_service.async_session: + async with backend_service.async_session() as session: + from .models import EvaluatorConnection - session.add(eval_result) - - # Update validator stats - val_result = await session.execute( - select(ValidatorConnection).where( - ValidatorConnection.validator_hotkey == validator_hotkey + # Find job + job_result = await session.execute( + select(BackendEvaluationJob).where( + BackendEvaluationJob.id == result_msg.job_id ) ) - validator_conn = val_result.scalar_one_or_none() - if validator_conn: - validator_conn.total_results_received += 1 - if result_msg.error: - validator_conn.total_errors += 1 + backend_job = job_result.scalar_one_or_none() + + if backend_job: + # Create result - use evaluator_id as validator_hotkey + # for compatibility with existing schema + eval_result = BackendEvaluationResult( + id=next(backend_service.id_generator), + job_id=result_msg.job_id, + validator_hotkey=evaluator_id, # Using evaluator_id + miner_hotkey=result_msg.miner_hotkey, + competition_id=result_msg.competition_id, + env_provider=result_msg.env_provider, + benchmark=result_msg.benchmark_name, + score=result_msg.score, + success_rate=result_msg.success_rate, + avg_reward=result_msg.avg_reward, + total_episodes=result_msg.total_episodes, + logs=result_msg.logs, + error=result_msg.error, + extra_data=result_msg.extra_data, + env_specs=result_msg.env_specs, + ) - await session.commit() + session.add(eval_result) - # Update job status based on result - result_status = getattr(result_msg, "status", None) - if result_status is None: - result_status = ( - EvaluationStatus.FAILED - if result_msg.error - else EvaluationStatus.COMPLETED + # Update evaluator stats + eval_conn_result = await session.execute( + select(EvaluatorConnection).where( + EvaluatorConnection.evaluator_id == evaluator_id + ) ) + evaluator_conn = eval_conn_result.scalar_one_or_none() + if evaluator_conn: + if result_msg.error: + evaluator_conn.total_jobs_failed += 1 + else: + evaluator_conn.total_jobs_completed += 1 + # Decrement current job count + if evaluator_conn.current_job_count > 0: + evaluator_conn.current_job_count -= 1 + + await session.commit() + + # Update job status + result_status = getattr(result_msg, "status", None) + if result_status is None: + result_status = ( + EvaluationStatus.FAILED + if result_msg.error + else EvaluationStatus.COMPLETED + ) - if result_status == EvaluationStatus.COMPLETED: detail = ( f"Evaluation completed with score {result_msg.score}" + if result_status == EvaluationStatus.COMPLETED + else result_msg.error or result_status.value ) - else: - detail = result_msg.error or result_status.value - - await backend_service._update_job_status( - result_msg.job_id, - validator_hotkey, - result_status, - detail, - ) - logger.info( - f"Stored result from {validator_hotkey} for job {result_msg.job_id}" - ) + await backend_service._update_job_status( + result_msg.job_id, + evaluator_id, + result_status, + detail, + ) - # Create evaluation completed event - # Pydantic will automatically handle datetime to ISO conversion - eval_event = EvaluationCompletedEvent( - job_id=eval_result.job_id, - validator_hotkey=eval_result.validator_hotkey, - miner_hotkey=eval_result.miner_hotkey, - competition_id=eval_result.competition_id, - benchmark_name=eval_result.benchmark, - score=eval_result.score, - success_rate=eval_result.success_rate, - avg_reward=eval_result.avg_reward, - total_episodes=eval_result.total_episodes, - result_time=eval_result.result_time, - created_at=eval_result.created_at, - ) - await event_broadcaster.broadcast_event( - EventType.EVALUATION_COMPLETED, eval_event - ) + logger.info( + f"Stored result from evaluator {evaluator_id} " + f"for job {result_msg.job_id}" + ) - # Send acknowledgment - await websocket.send_text( - json.dumps( - { - "message_type": MessageType.RESULT_ACK, - "job_id": str(result_msg.job_id), - "status": "received", - } + # Broadcast evaluation completed event + eval_event = EvaluationCompletedEvent( + job_id=eval_result.job_id, + validator_hotkey=evaluator_id, + miner_hotkey=eval_result.miner_hotkey, + competition_id=eval_result.competition_id, + benchmark_name=eval_result.benchmark, + score=eval_result.score, + success_rate=eval_result.success_rate, + avg_reward=eval_result.avg_reward, + total_episodes=eval_result.total_episodes, + result_time=eval_result.result_time, + created_at=eval_result.created_at, + ) + await event_broadcaster.broadcast_event( + EventType.EVALUATION_COMPLETED, eval_event ) - ) - elif message_type == MessageType.JOB_STATUS_UPDATE: - status_msg = JobStatusUpdateMessage(**message) + # Decrement job count in hub + backend_service.evaluator_hub.job_completed(evaluator_id) - if status_msg.validator_hotkey != validator_hotkey: - logger.warning( - "Validator %s attempted to update job %s with mismatched hotkey %s", - validator_hotkey, - status_msg.job_id, - status_msg.validator_hotkey, + # Send acknowledgment + await websocket.send_text( + json.dumps( + { + "message_type": MessageType.RESULT_ACK, + "job_id": str(result_msg.job_id), + "status": "received", + } ) - continue + ) + + elif message_type == MessageType.JOB_STATUS_UPDATE: + status_msg = JobStatusUpdateMessage(**message) logger.info( - "Received job status update for job %s from %s: %s", + "Received job status update from evaluator %s for job %s: %s", + evaluator_id, status_msg.job_id, - validator_hotkey, status_msg.status, ) await backend_service._update_job_status( status_msg.job_id, - validator_hotkey, + evaluator_id, status_msg.status, status_msg.detail, ) elif message_type == MessageType.EPISODE_DATA: episode_msg = EpisodeDataMessage(**message) - await backend_service.queue_episode_data(validator_hotkey, episode_msg) + await backend_service.queue_episode_data(evaluator_id, episode_msg) logger.debug( - "Queued episode data from %s for episode %s for submission %s", - validator_hotkey, + "Queued episode data from evaluator %s for episode %s", + evaluator_id, episode_msg.episode_id, - episode_msg.submission_id, ) elif message_type == MessageType.EPISODE_STEP_DATA: step_msg = EpisodeStepDataMessage(**message) - await backend_service.queue_episode_step_data( - validator_hotkey, step_msg - ) + await backend_service.queue_episode_step_data(evaluator_id, step_msg) logger.debug( - "Queued step data from %s for episode %s step %s", - validator_hotkey, + "Queued step data from evaluator %s for episode %s step %s", + evaluator_id, step_msg.episode_id, step_msg.step, ) + elif message_type == MessageType.JOB_ACK: + # Evaluator acknowledging job receipt + job_id = message.get("job_id") + accepted = message.get("accepted", True) + reason = message.get("reason") + + if not accepted: + logger.warning( + "Evaluator %s rejected job %s: %s", + evaluator_id, + job_id, + reason, + ) + # Could re-route job to another evaluator here + else: + logger.debug( + "Evaluator %s accepted job %s", + evaluator_id, + job_id, + ) + except WebSocketDisconnect: - logger.info(f"Validator disconnected: {connection_id}") + logger.info(f"Evaluator disconnected: {evaluator_id} ({connection_id})") except Exception as e: - logger.error(f"Error in validator WebSocket: {e}") + logger.error(f"Error in evaluator WebSocket: {e}") finally: # Cleanup - if connection_id in backend_service.active_connections: - del backend_service.active_connections[connection_id] - if connection_id in backend_service.validator_connections: - hotkey = backend_service.validator_connections[connection_id] - del backend_service.validator_connections[connection_id] + if evaluator_id: + await backend_service.evaluator_hub.unregister(connection_id) # Update database if backend_service.async_session: async with backend_service.async_session() as session: + from .models import EvaluatorConnection + result = await session.execute( - select(ValidatorConnection).where( - ValidatorConnection.validator_hotkey == hotkey + select(EvaluatorConnection).where( + EvaluatorConnection.evaluator_id == evaluator_id ) ) - validator_conn = result.scalar_one_or_none() - if validator_conn: - validator_conn.is_connected = False + evaluator_conn = result.scalar_one_or_none() + if evaluator_conn: + evaluator_conn.is_connected = False await session.commit() - # Broadcast validator disconnected event - disconnected_event = ValidatorDisconnectedEvent( - validator_hotkey=hotkey, - connection_id=connection_id, - disconnected_at=datetime.now(timezone.utc), + +# ============================================================================ +# Evaluator REST Endpoints +# ============================================================================ + + +@app.get("/evaluators", response_model=List[EvaluatorInfoResponse]) +async def list_evaluators( + connected_only: bool = Query( + False, description="Filter for connected evaluators only" + ), + skip: int = Query(0, ge=0), + limit: int = Query(DEFAULT_PAGE_LIMIT, ge=MIN_PAGE_LIMIT, le=MAX_PAGE_LIMIT), +): + """List all evaluators.""" + if not backend_service.async_session: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Database not initialized", + ) + async with backend_service.async_session() as session: + from .models import EvaluatorConnection + + query = select(EvaluatorConnection) + if connected_only: + query = query.where(EvaluatorConnection.is_connected) + query = query.offset(skip).limit(limit) + + result = await session.execute(query) + evaluators = result.scalars().all() + + return [EvaluatorInfoResponse.model_validate(e) for e in evaluators] + + +@app.get("/evaluators/{evaluator_id}", response_model=EvaluatorInfoResponse) +async def get_evaluator(evaluator_id: str): + """Get a specific evaluator by ID.""" + if not backend_service.async_session: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Database not initialized", + ) + async with backend_service.async_session() as session: + from .models import EvaluatorConnection + + result = await session.execute( + select(EvaluatorConnection).where( + EvaluatorConnection.evaluator_id == evaluator_id ) - await event_broadcaster.broadcast_event( - EventType.VALIDATOR_DISCONNECTED, disconnected_event + ) + evaluator = result.scalar_one_or_none() + + if not evaluator: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Evaluator not found" ) - # Broadcast updated stats (validator count changed) - await backend_service._broadcast_stats_update() + return EvaluatorInfoResponse.model_validate(evaluator) # ============================================================================ diff --git a/src/backend/evaluator_hub.py b/src/backend/evaluator_hub.py new file mode 100644 index 0000000..6591197 --- /dev/null +++ b/src/backend/evaluator_hub.py @@ -0,0 +1,410 @@ +""" +Evaluator hub for Kinitro backend. + +Manages direct WebSocket connections from evaluators. +Handles job routing and evaluator lifecycle. +""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Awaitable, Callable, Dict, List, Optional + +from fastapi import WebSocket +from sqlmodel import SQLModel + +from core.log import get_logger +from core.messages import EvalJobMessage + +logger = get_logger(__name__) + +ConnectionId = str # Unique ID for each WebSocket connection +EvaluatorId = str # Unique evaluator instance ID + + +@dataclass +class EvaluatorState: + """In-memory state for a connected evaluator.""" + + connection_id: ConnectionId + evaluator_id: EvaluatorId + websocket: WebSocket + api_key_id: Optional[int] + supported_task_types: List[str] + max_concurrent_jobs: int + current_job_count: int + last_heartbeat: datetime + capabilities: Optional[dict] + + +class EvaluatorHub: + """ + Manages WebSocket connections from evaluators. + + This class is responsible for: + - Registering and unregistering evaluator connections + - Routing jobs to evaluators based on capacity and task type + - Broadcasting messages to evaluators + - Tracking evaluator state and health + """ + + def __init__( + self, + on_result_received: Optional[Callable[[dict], Awaitable[None]]] = None, + on_status_update: Optional[Callable[[dict], Awaitable[None]]] = None, + ): + # WebSocket connections by connection_id + self._connections: Dict[ConnectionId, EvaluatorState] = {} + # Map evaluator_id -> connection_id for lookup + self._evaluator_to_connection: Dict[EvaluatorId, ConnectionId] = {} + # Round-robin index for job assignment + self._round_robin_index: int = 0 + # Callbacks + self._on_result_received = on_result_received + self._on_status_update = on_status_update + + async def register( + self, + connection_id: ConnectionId, + evaluator_id: EvaluatorId, + websocket: WebSocket, + api_key_id: Optional[int] = None, + supported_task_types: Optional[List[str]] = None, + max_concurrent_jobs: int = 1, + capabilities: Optional[dict] = None, + ) -> EvaluatorState: + """ + Register a new evaluator connection. + + If an evaluator with the same evaluator_id is already connected, + the old connection is closed and replaced. + + Args: + connection_id: Unique identifier for this WebSocket connection + evaluator_id: Unique identifier for the evaluator instance + websocket: The WebSocket connection + api_key_id: ID of the API key used for authentication + supported_task_types: List of task types this evaluator can handle + max_concurrent_jobs: Maximum concurrent jobs this evaluator can run + capabilities: Additional evaluator capabilities (GPU info, etc.) + + Returns: + The EvaluatorState for the registered connection + """ + # Check if this evaluator_id is already connected + if evaluator_id in self._evaluator_to_connection: + old_conn_id = self._evaluator_to_connection[evaluator_id] + logger.warning( + f"Evaluator {evaluator_id} reconnecting, closing old connection {old_conn_id}" + ) + await self._close_connection(old_conn_id) + + state = EvaluatorState( + connection_id=connection_id, + evaluator_id=evaluator_id, + websocket=websocket, + api_key_id=api_key_id, + supported_task_types=supported_task_types or ["rl_rollout"], + max_concurrent_jobs=max_concurrent_jobs, + current_job_count=0, + last_heartbeat=datetime.now(timezone.utc), + capabilities=capabilities, + ) + + self._connections[connection_id] = state + self._evaluator_to_connection[evaluator_id] = connection_id + + logger.info( + f"Registered evaluator {evaluator_id} (connection {connection_id}) " + f"with {max_concurrent_jobs} max concurrent jobs, " + f"supporting task types: {state.supported_task_types}" + ) + + return state + + async def unregister(self, connection_id: ConnectionId) -> Optional[EvaluatorId]: + """ + Unregister an evaluator connection. + + Args: + connection_id: The connection ID to unregister + + Returns: + The evaluator_id of the unregistered evaluator, or None if not found + """ + state = self._connections.pop(connection_id, None) + if not state: + return None + + self._evaluator_to_connection.pop(state.evaluator_id, None) + + try: + await state.websocket.close() + except Exception as e: + logger.debug(f"Error closing WebSocket for {connection_id}: {e}") + + logger.info( + f"Unregistered evaluator {state.evaluator_id} (connection {connection_id})" + ) + + return state.evaluator_id + + async def _close_connection(self, connection_id: ConnectionId) -> None: + """Close a connection without full unregistration.""" + state = self._connections.get(connection_id) + if state: + try: + await state.websocket.close() + except Exception as e: + logger.debug(f"Error closing WebSocket for {connection_id}: {e}") + + async def send_job(self, evaluator_id: EvaluatorId, job: EvalJobMessage) -> bool: + """ + Send a job to a specific evaluator. + + Args: + evaluator_id: The evaluator to send the job to + job: The job message + + Returns: + True if the job was sent successfully + """ + conn_id = self._evaluator_to_connection.get(evaluator_id) + if not conn_id: + logger.warning(f"No connection found for evaluator {evaluator_id}") + return False + + state = self._connections.get(conn_id) + if not state: + logger.warning(f"No state found for connection {conn_id}") + return False + + try: + await state.websocket.send_text(job.model_dump_json()) + state.current_job_count += 1 + logger.debug( + f"Sent job {job.job_id} to evaluator {evaluator_id} " + f"(current jobs: {state.current_job_count})" + ) + return True + except Exception as e: + logger.error(f"Failed to send job to evaluator {evaluator_id}: {e}") + # Don't increment job count on failure + return False + + async def route_job( + self, job: EvalJobMessage, task_type: str = "rl_rollout" + ) -> Optional[EvaluatorId]: + """ + Route a job to an available evaluator using round-robin with capacity check. + + Args: + job: The job message to route + task_type: The task type required for this job + + Returns: + The evaluator_id that accepted the job, or None if no evaluator available + """ + # Get evaluators that support this task type and have capacity + available = self._get_available_evaluators(task_type) + if not available: + logger.warning( + f"No available evaluators for task type {task_type}, " + f"total connected: {len(self._connections)}" + ) + return None + + # Round-robin selection + self._round_robin_index = self._round_robin_index % len(available) + selected_evaluator_id = available[self._round_robin_index] + self._round_robin_index += 1 + + # Send the job + if await self.send_job(selected_evaluator_id, job): + return selected_evaluator_id + + return None + + async def broadcast_job(self, job: EvalJobMessage) -> int: + """ + Broadcast a job to all connected evaluators. + + Args: + job: The job message to broadcast + + Returns: + Number of evaluators that received the job + """ + sent_count = 0 + failed_connections: List[ConnectionId] = [] + + for conn_id, state in list(self._connections.items()): + try: + await state.websocket.send_text(job.model_dump_json()) + sent_count += 1 + except Exception as e: + logger.error(f"Failed to broadcast to {state.evaluator_id}: {e}") + failed_connections.append(conn_id) + + # Clean up failed connections + for conn_id in failed_connections: + await self.unregister(conn_id) + + return sent_count + + async def send_message(self, evaluator_id: EvaluatorId, message: SQLModel) -> bool: + """ + Send an arbitrary message to a specific evaluator. + + Args: + evaluator_id: The evaluator to send to + message: The message to send + + Returns: + True if sent successfully + """ + conn_id = self._evaluator_to_connection.get(evaluator_id) + if not conn_id: + return False + + state = self._connections.get(conn_id) + if not state: + return False + + try: + await state.websocket.send_text(message.model_dump_json()) + return True + except Exception as e: + logger.error(f"Failed to send message to {evaluator_id}: {e}") + return False + + async def broadcast_message(self, message: SQLModel) -> int: + """ + Broadcast a message to all connected evaluators. + + Args: + message: The message to broadcast + + Returns: + Number of evaluators that received the message + """ + sent_count = 0 + failed_connections: List[ConnectionId] = [] + + for conn_id, state in list(self._connections.items()): + try: + await state.websocket.send_text(message.model_dump_json()) + sent_count += 1 + except Exception as e: + logger.error(f"Failed to broadcast to {state.evaluator_id}: {e}") + failed_connections.append(conn_id) + + for conn_id in failed_connections: + await self.unregister(conn_id) + + return sent_count + + def update_heartbeat(self, evaluator_id: EvaluatorId) -> bool: + """ + Update the last heartbeat time for an evaluator. + + Args: + evaluator_id: The evaluator ID + + Returns: + True if the evaluator was found and updated + """ + conn_id = self._evaluator_to_connection.get(evaluator_id) + if not conn_id: + return False + + state = self._connections.get(conn_id) + if not state: + return False + + state.last_heartbeat = datetime.now(timezone.utc) + return True + + def job_completed(self, evaluator_id: EvaluatorId) -> bool: + """ + Mark a job as completed for an evaluator (decrement job count). + + Args: + evaluator_id: The evaluator ID + + Returns: + True if the evaluator was found and updated + """ + conn_id = self._evaluator_to_connection.get(evaluator_id) + if not conn_id: + return False + + state = self._connections.get(conn_id) + if not state: + return False + + if state.current_job_count > 0: + state.current_job_count -= 1 + + return True + + def _get_available_evaluators(self, task_type: str) -> List[EvaluatorId]: + """ + Get list of evaluator IDs that support a task type and have capacity. + + Args: + task_type: The task type to check for + + Returns: + List of evaluator IDs with capacity + """ + available = [] + for state in self._connections.values(): + if task_type in state.supported_task_types: + if state.current_job_count < state.max_concurrent_jobs: + available.append(state.evaluator_id) + return available + + def get_evaluator_ids(self) -> List[EvaluatorId]: + """Get list of all connected evaluator IDs.""" + return list(self._evaluator_to_connection.keys()) + + def get_state(self, evaluator_id: EvaluatorId) -> Optional[EvaluatorState]: + """Get the state for an evaluator.""" + conn_id = self._evaluator_to_connection.get(evaluator_id) + if not conn_id: + return None + return self._connections.get(conn_id) + + def get_connection_for_evaluator( + self, evaluator_id: EvaluatorId + ) -> Optional[ConnectionId]: + """Get the connection ID for an evaluator.""" + return self._evaluator_to_connection.get(evaluator_id) + + def has_connections(self) -> bool: + """Check if there are any active connections.""" + return bool(self._connections) + + def connection_count(self) -> int: + """Get the number of active connections.""" + return len(self._connections) + + def total_capacity(self) -> int: + """Get the total job capacity across all connected evaluators.""" + return sum(state.max_concurrent_jobs for state in self._connections.values()) + + def available_capacity(self) -> int: + """Get the available job capacity across all connected evaluators.""" + return sum( + state.max_concurrent_jobs - state.current_job_count + for state in self._connections.values() + ) + + async def close_all(self) -> None: + """Close all WebSocket connections.""" + for conn_id in list(self._connections.keys()): + await self.unregister(conn_id) + + self._connections.clear() + self._evaluator_to_connection.clear() + logger.info("All evaluator connections closed") diff --git a/src/backend/job_scheduler.py b/src/backend/job_scheduler.py new file mode 100644 index 0000000..180df2e --- /dev/null +++ b/src/backend/job_scheduler.py @@ -0,0 +1,460 @@ +""" +Job scheduler for Kinitro backend. + +Creates and routes evaluation jobs to evaluators via direct WebSocket connection. +Extracted from BackendService for better separation of concerns. +""" + +import copy +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Sequence + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from core.db.models import EvaluationStatus +from core.log import get_logger +from core.messages import EvalJobMessage + +from .constants import EVAL_JOB_TIMEOUT +from .events import JobCreatedEvent +from .models import ( + BackendEvaluationJob, + Competition, + MinerSubmission, +) +from .realtime import EventType, event_broadcaster + +if TYPE_CHECKING: + from .evaluator_hub import EvaluatorHub + +logger = get_logger(__name__) + + +class JobConfig: + """Configuration for job scheduling.""" + + def __init__( + self, + default_job_timeout_seconds: int = int(EVAL_JOB_TIMEOUT.total_seconds()), + submission_download_url_ttl: int = 21600, # 6 hours + ): + self.default_job_timeout_seconds = default_job_timeout_seconds + self.submission_download_url_ttl = submission_download_url_ttl + + +def _extract_benchmark_spec_payload( + config: Mapping[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Split a stored benchmark configuration into the full benchmark spec and the + underlying execution config used by the evaluator. + """ + spec_copy = copy.deepcopy(dict(config)) + try: + base_config_source = config["config"] + except KeyError as exc: + raise ValueError("Benchmark spec is missing 'config' payload") from exc + try: + base_config = copy.deepcopy(dict(base_config_source)) + except TypeError as exc: + raise ValueError("'config' payload must be a mapping") from exc + return spec_copy, base_config + + +def _normalize_benchmark_spec_payload( + provider: str, + benchmark_name: str, + payload: Mapping[str, Any] | None, +) -> dict[str, Any]: + """ + Ensure a benchmark specification payload includes top-level metadata and a nested config mapping. + + Accepts either the new-style payload (with a `config` key) or a bare config mapping and + returns a copy that always matches the new-style structure. + """ + if payload is None: + base_config: dict[str, Any] = {} + return { + "provider": provider, + "benchmark_name": benchmark_name, + "config": base_config, + } + + payload_dict = dict(payload) + if "config" in payload_dict: + return copy.deepcopy(payload_dict) + + base_config = copy.deepcopy(payload_dict) + return { + "provider": provider, + "benchmark_name": benchmark_name, + "config": base_config, + } + + +class SubmissionNotFoundError(Exception): + """Raised when a submission cannot be located.""" + + +class EvaluationJobNotFoundError(Exception): + """Raised when an evaluation job cannot be located.""" + + +class NoBenchmarksAvailableError(Exception): + """Raised when no benchmarks are available for an evaluation rerun.""" + + +class JobScheduler: + """ + Creates and routes evaluation jobs to evaluators. + + This class is responsible for: + - Creating evaluation jobs for submissions + - Routing jobs to connected evaluators via WebSocket + - Handling job reruns + - Monitoring for stale jobs + """ + + def __init__( + self, + session_factory: async_sessionmaker[AsyncSession], + evaluator_hub: "EvaluatorHub", + config: JobConfig, + id_generator, + submission_storage=None, + ): + self.session_factory = session_factory + self.evaluator_hub = evaluator_hub + self.config = config + self.id_generator = id_generator + self.submission_storage = submission_storage + + def _job_timeout_seconds(self, competition: Optional[Competition]) -> int: + """Return the timeout for a competition, falling back to the default.""" + if competition and competition.job_timeout_seconds: + try: + value = int(competition.job_timeout_seconds) + if value > 0: + return value + except (TypeError, ValueError): + logger.warning( + "Invalid job_timeout_seconds for competition %s: %r", + getattr(competition, "id", "unknown"), + getattr(competition, "job_timeout_seconds", None), + ) + return self.config.default_job_timeout_seconds + + async def create_jobs_for_submission( + self, + submission: MinerSubmission, + competition: Competition, + ) -> List[BackendEvaluationJob]: + """Create evaluation jobs for a submission based on competition benchmarks.""" + jobs: List[BackendEvaluationJob] = [] + + for benchmark in competition.benchmarks: + if "provider" not in benchmark or "benchmark_name" not in benchmark: + logger.error( + "Benchmark missing provider or benchmark_name: %s", benchmark + ) + continue + + if isinstance(benchmark, dict): + benchmark_spec = copy.deepcopy(benchmark) + else: + logger.warning( + "Benchmark specification for competition %s is not a dict (%r); " + "wrapping in config field", + competition.id, + type(benchmark), + ) + benchmark_spec = {"config": benchmark} + + job = BackendEvaluationJob( + id=next(self.id_generator), + submission_id=submission.id, + competition_id=competition.id, + miner_hotkey=submission.miner_hotkey, + hf_repo_id=submission.hf_repo_id, + env_provider=benchmark_spec.get("provider", benchmark["provider"]), + benchmark_name=benchmark_spec.get( + "benchmark_name", benchmark["benchmark_name"] + ), + config=benchmark_spec, + timeout_seconds=self._job_timeout_seconds(competition), + artifact_object_key=submission.artifact_object_key, + artifact_sha256=submission.artifact_sha256, + artifact_size_bytes=submission.artifact_size_bytes, + ) + jobs.append(job) + + return jobs + + async def schedule_submission( + self, + submission: MinerSubmission, + competition: Competition, + session: AsyncSession, + ) -> List[BackendEvaluationJob]: + """Create and persist evaluation jobs for a submission, then broadcast them.""" + jobs = await self.create_jobs_for_submission(submission, competition) + + if not jobs: + logger.error( + "No evaluation jobs generated for submission %s", submission.id + ) + return [] + + session.add_all(jobs) + await session.flush() + + return jobs + + async def publish_jobs(self, jobs: Sequence[BackendEvaluationJob]) -> None: + """Emit events and route jobs to evaluators.""" + if not jobs: + return + + connected_evaluator_ids = self.evaluator_hub.get_evaluator_ids() + + for job in jobs: + _benchmark_spec_payload, base_config_payload = ( + _extract_benchmark_spec_payload(job.config) + ) + job_event = JobCreatedEvent( + job_id=str(job.id), + competition_id=job.competition_id, + submission_id=job.submission_id, + miner_hotkey=job.miner_hotkey, + hf_repo_id=job.hf_repo_id, + env_provider=job.env_provider, + benchmark_name=job.benchmark_name, + config=base_config_payload, + status=EvaluationStatus.QUEUED, + validator_statuses={ + evaluator_id: EvaluationStatus.QUEUED + for evaluator_id in connected_evaluator_ids + }, + ) + + try: + await event_broadcaster.broadcast_event( + EventType.JOB_CREATED, job_event + ) + except Exception as exc: + logger.error( + "Failed to broadcast job created event for job %s: %s", + job.id, + exc, + ) + + try: + await self.broadcast_job(job) + except Exception as exc: + logger.error( + "Failed to broadcast job %s to evaluators: %s", job.id, exc + ) + + def _build_job_message(self, job: BackendEvaluationJob) -> Optional[EvalJobMessage]: + """Build an EvalJobMessage from a BackendEvaluationJob. + + Returns: + EvalJobMessage or None if artifact URL cannot be generated + """ + artifact_url = None + artifact_expires_at: Optional[datetime] = None + if self.submission_storage and job.artifact_object_key: + try: + artifact_url, artifact_expires_at = ( + self.submission_storage.generate_download_url( + job.artifact_object_key, self.config.submission_download_url_ttl + ) + ) + except Exception as exc: + logger.error( + "Failed to generate artifact URL for job %s: %s", job.id, exc + ) + return None + else: + logger.error( + "Cannot build job message for submission %s: storage unavailable or artifact missing", + job.submission_id, + ) + return None + + benchmark_spec_payload, base_config_payload = _extract_benchmark_spec_payload( + job.config + ) + + timeout_seconds = job.timeout_seconds or self.config.default_job_timeout_seconds + + return EvalJobMessage( + job_id=job.id, + competition_id=job.competition_id, + submission_id=job.submission_id, + miner_hotkey=job.miner_hotkey, + hf_repo_id=job.hf_repo_id, + env_provider=job.env_provider, + benchmark_name=job.benchmark_name, + config=base_config_payload, + benchmark_spec=benchmark_spec_payload, + artifact_url=artifact_url, + artifact_expires_at=artifact_expires_at, + artifact_sha256=job.artifact_sha256, + artifact_size_bytes=job.artifact_size_bytes, + timeout=timedelta(seconds=timeout_seconds), + ) + + async def broadcast_job(self, job: BackendEvaluationJob) -> int: + """Broadcast job to all connected evaluators. + + Returns: + Number of evaluators that received the job + """ + if not self.evaluator_hub.has_connections(): + logger.warning("No evaluators connected") + return 0 + + job_msg = self._build_job_message(job) + if not job_msg: + return 0 + + broadcast_count = await self.evaluator_hub.broadcast_job(job_msg) + + logger.info(f"Broadcasted job {job.id} to {broadcast_count} evaluators") + return broadcast_count + + async def rerun_submission_evaluations( + self, + submission_id: int, + benchmark_names: Optional[List[str]] = None, + requested_by_api_key_id: Optional[int] = None, + ) -> List[BackendEvaluationJob]: + """Re-run evaluations for a submission across its configured benchmarks.""" + benchmark_filter = ( + {name.strip() for name in benchmark_names if name.strip()} + if benchmark_names + else None + ) + + new_jobs: List[BackendEvaluationJob] = [] + + async with self.session_factory() as session: + submission = await session.get(MinerSubmission, submission_id) + if not submission: + raise SubmissionNotFoundError(f"Submission {submission_id} not found") + + competition = await session.get(Competition, submission.competition_id) + if not competition: + raise SubmissionNotFoundError( + f"Competition {submission.competition_id} not found for submission {submission_id}" + ) + + benchmarks = competition.benchmarks or [] + for benchmark in benchmarks: + provider = benchmark.get("provider") + benchmark_name = benchmark.get("benchmark_name") + + if not provider or not benchmark_name: + logger.error( + "Submission %s rerun skipped invalid benchmark entry: %s", + submission_id, + benchmark, + ) + continue + + if benchmark_filter and benchmark_name not in benchmark_filter: + continue + + spec_payload = _normalize_benchmark_spec_payload( + provider, + benchmark_name, + benchmark, + ) + + job = BackendEvaluationJob( + id=next(self.id_generator), + submission_id=submission.id, + competition_id=competition.id, + miner_hotkey=submission.miner_hotkey, + hf_repo_id=submission.hf_repo_id, + env_provider=provider, + benchmark_name=benchmark_name, + config=spec_payload, + timeout_seconds=self._job_timeout_seconds(competition), + artifact_object_key=submission.artifact_object_key, + artifact_sha256=submission.artifact_sha256, + artifact_size_bytes=submission.artifact_size_bytes, + ) + new_jobs.append(job) + + if not new_jobs: + raise NoBenchmarksAvailableError( + "No matching benchmarks available for rerun request" + ) + + session.add_all(new_jobs) + await session.commit() + + for job in new_jobs: + await session.refresh(job) + + await self.publish_jobs(new_jobs) + + logger.info( + "Submission %s rerun triggered by API key %s; queued %s jobs", + submission_id, + requested_by_api_key_id, + len(new_jobs), + ) + + return new_jobs + + async def rerun_job_evaluation( + self, + job_id: int, + requested_by_api_key_id: Optional[int] = None, + ) -> BackendEvaluationJob: + """Re-run a specific evaluation job by cloning its configuration.""" + async with self.session_factory() as session: + existing_job = await session.get(BackendEvaluationJob, job_id) + if not existing_job: + raise EvaluationJobNotFoundError(f"Job {job_id} not found") + + spec_payload = _normalize_benchmark_spec_payload( + existing_job.env_provider, + existing_job.benchmark_name, + existing_job.config + if isinstance(existing_job.config, Mapping) + else None, + ) + + new_job = BackendEvaluationJob( + id=next(self.id_generator), + submission_id=existing_job.submission_id, + competition_id=existing_job.competition_id, + miner_hotkey=existing_job.miner_hotkey, + hf_repo_id=existing_job.hf_repo_id, + env_provider=existing_job.env_provider, + benchmark_name=existing_job.benchmark_name, + config=spec_payload, + timeout_seconds=existing_job.timeout_seconds, + artifact_object_key=existing_job.artifact_object_key, + artifact_sha256=existing_job.artifact_sha256, + artifact_size_bytes=existing_job.artifact_size_bytes, + ) + + session.add(new_job) + await session.commit() + await session.refresh(new_job) + + await self.publish_jobs([new_job]) + + logger.info( + "Job %s rerun created as job %s by API key %s", + job_id, + new_job.id, + requested_by_api_key_id, + ) + + return new_job diff --git a/src/backend/migrations/versions/020_add_evaluator_connections.py b/src/backend/migrations/versions/020_add_evaluator_connections.py new file mode 100644 index 0000000..93e3e66 --- /dev/null +++ b/src/backend/migrations/versions/020_add_evaluator_connections.py @@ -0,0 +1,114 @@ +"""Add evaluator_connections table for direct evaluator communication. + +Revision ID: 020_add_evaluator_connections +Revises: 019_competition_uploads_nonneg +Create Date: 2025-01-15 00:00:00 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "020_add_evaluator_connections" +down_revision = "019_competition_uploads_nonneg" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "evaluator_connections", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column( + "evaluator_id", sa.String(128), nullable=False, unique=True, index=True + ), + sa.Column( + "api_key_id", + sa.BigInteger(), + sa.ForeignKey("api_keys.id"), + nullable=True, + index=True, + ), + sa.Column( + "supported_task_types", + sa.JSON(), + nullable=False, + server_default='["rl_rollout"]', + ), + sa.Column( + "max_concurrent_jobs", sa.Integer(), nullable=False, server_default="1" + ), + sa.Column( + "current_job_count", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column("is_connected", sa.Boolean(), nullable=False, server_default="true"), + sa.Column( + "last_heartbeat", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "first_connected_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "total_jobs_assigned", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column( + "total_jobs_completed", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column( + "total_jobs_failed", sa.Integer(), nullable=False, server_default="0" + ), + sa.Column("capabilities", sa.JSON(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + onupdate=sa.func.now(), + ), + sa.CheckConstraint( + "current_job_count >= 0", name="ck_evaluator_concurrent_non_negative" + ), + sa.CheckConstraint( + "max_concurrent_jobs > 0", name="ck_evaluator_max_concurrent_positive" + ), + sa.CheckConstraint( + "total_jobs_assigned >= 0", name="ck_evaluator_jobs_assigned_non_negative" + ), + sa.CheckConstraint( + "total_jobs_completed >= 0", name="ck_evaluator_jobs_completed_non_negative" + ), + sa.CheckConstraint( + "total_jobs_failed >= 0", name="ck_evaluator_jobs_failed_non_negative" + ), + ) + op.create_index( + "ix_evaluator_connections_connected", "evaluator_connections", ["is_connected"] + ) + op.create_index( + "ix_evaluator_connections_heartbeat", + "evaluator_connections", + ["last_heartbeat"], + ) + + +def downgrade() -> None: + op.drop_index( + "ix_evaluator_connections_heartbeat", table_name="evaluator_connections" + ) + op.drop_index( + "ix_evaluator_connections_connected", table_name="evaluator_connections" + ) + op.drop_table("evaluator_connections") diff --git a/src/backend/migrations/versions/021_add_evaluator_role.py b/src/backend/migrations/versions/021_add_evaluator_role.py new file mode 100644 index 0000000..7b83249 --- /dev/null +++ b/src/backend/migrations/versions/021_add_evaluator_role.py @@ -0,0 +1,35 @@ +"""Add evaluator role to api_keys constraint. + +Revision ID: 021_add_evaluator_role +Revises: 020_add_evaluator_connections +Create Date: 2025-01-15 00:00:00 +""" + +from __future__ import annotations + +from alembic import op + +revision = "021_add_evaluator_role" +down_revision = "020_add_evaluator_connections" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add 'evaluator' to the valid roles for api_keys.""" + op.drop_constraint("ck_api_keys_valid_role", "api_keys", type_="check") + op.create_check_constraint( + "ck_api_keys_valid_role", + "api_keys", + "role IN ('admin', 'validator', 'evaluator', 'viewer')", + ) + + +def downgrade() -> None: + """Remove 'evaluator' from the valid roles for api_keys.""" + op.drop_constraint("ck_api_keys_valid_role", "api_keys", type_="check") + op.create_check_constraint( + "ck_api_keys_valid_role", + "api_keys", + "role IN ('admin', 'validator', 'viewer')", + ) diff --git a/src/backend/migrations/versions/022_add_competition_task_type.py b/src/backend/migrations/versions/022_add_competition_task_type.py new file mode 100644 index 0000000..351ba0b --- /dev/null +++ b/src/backend/migrations/versions/022_add_competition_task_type.py @@ -0,0 +1,34 @@ +"""Add task_type column to competitions for executor dispatch. + +Revision ID: 022_add_competition_task_type +Revises: 021_add_evaluator_role +Create Date: 2025-01-15 00:00:00 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +revision = "022_add_competition_task_type" +down_revision = "021_add_evaluator_role" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Add task_type column for executor dispatch.""" + op.add_column( + "competitions", + sa.Column( + "task_type", + sa.String(64), + nullable=False, + server_default="rl_rollout", + ), + ) + + +def downgrade() -> None: + """Remove task_type column from competitions.""" + op.drop_column("competitions", "task_type") diff --git a/src/backend/models.py b/src/backend/models.py index df59a79..81acb17 100644 --- a/src/backend/models.py +++ b/src/backend/models.py @@ -208,6 +208,28 @@ class Config: from_attributes = True +class EvaluatorInfoResponse(SQLModel): + """Response model for evaluator information.""" + + evaluator_id: str + api_key_id: Optional[str] + supported_task_types: List[str] + max_concurrent_jobs: int + current_job_count: int + is_connected: bool + first_connected_at: datetime + last_heartbeat: datetime + total_jobs_assigned: int + total_jobs_completed: int + total_jobs_failed: int + capabilities: Optional[dict] + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + class MinerSubmissionResponse(SQLModel): """Response model for miner submission data.""" @@ -525,6 +547,14 @@ class Competition(TimestampMixin, SQLModel, table=True): sa_column_kwargs={"server_default": "true"}, ) + # Task type for executor dispatch (defaults to rl_rollout for existing competitions) + task_type: str = Field( + default="rl_rollout", + max_length=64, + nullable=False, + sa_column_kwargs={"server_default": "'rl_rollout'"}, + ) + # Start and end times for the competition start_time: Optional[datetime] = Field( default=None, sa_column=Column(SADateTime(timezone=True), nullable=True) @@ -1112,6 +1142,93 @@ class ValidatorConnection(TimestampMixin, SQLModel, table=True): ) +class EvaluatorConnection(TimestampMixin, SQLModel, table=True): + """Track evaluator connections and their capabilities for direct backend communication.""" + + __tablename__ = "evaluator_connections" + + id: int = Field(sa_column=Column(BigInteger, primary_key=True)) + + # Unique evaluator instance identifier + evaluator_id: str = Field(max_length=128, nullable=False, unique=True, index=True) + + # Link to API key used for authentication + api_key_id: Optional[int] = Field( + sa_column=Column( + BigInteger, ForeignKey("api_keys.id"), nullable=True, index=True + ) + ) + + # Capabilities + supported_task_types: List[str] = Field( + sa_column=Column(JSON, nullable=False), + default_factory=lambda: ["rl_rollout"], + ) + max_concurrent_jobs: int = Field( + default=1, nullable=False, sa_column_kwargs={"server_default": "1"} + ) + current_job_count: int = Field( + default=0, nullable=False, sa_column_kwargs={"server_default": "0"} + ) + + # Connection state + is_connected: bool = Field( + default=True, + nullable=False, + index=True, + sa_column_kwargs={"server_default": "true"}, + ) + last_heartbeat: datetime = Field( + sa_column=Column( + SADateTime(timezone=True), nullable=False, server_default=func.now() + ) + ) + first_connected_at: datetime = Field( + sa_column=Column( + SADateTime(timezone=True), nullable=False, server_default=func.now() + ) + ) + + # Statistics + total_jobs_assigned: int = Field( + default=0, nullable=False, sa_column_kwargs={"server_default": "0"} + ) + total_jobs_completed: int = Field( + default=0, nullable=False, sa_column_kwargs={"server_default": "0"} + ) + total_jobs_failed: int = Field( + default=0, nullable=False, sa_column_kwargs={"server_default": "0"} + ) + + # Resource metadata (GPU count, memory, etc.) + capabilities: Optional[dict] = Field( + default=None, sa_column=Column(JSON, nullable=True) + ) + + # Relationships + api_key: Optional["ApiKey"] = Relationship(back_populates="evaluator_connections") + + __table_args__ = ( + CheckConstraint( + "current_job_count >= 0", name="ck_evaluator_concurrent_non_negative" + ), + CheckConstraint( + "max_concurrent_jobs > 0", name="ck_evaluator_max_concurrent_positive" + ), + CheckConstraint( + "total_jobs_assigned >= 0", name="ck_evaluator_jobs_assigned_non_negative" + ), + CheckConstraint( + "total_jobs_completed >= 0", name="ck_evaluator_jobs_completed_non_negative" + ), + CheckConstraint( + "total_jobs_failed >= 0", name="ck_evaluator_jobs_failed_non_negative" + ), + Index("ix_evaluator_connections_connected", "is_connected"), + Index("ix_evaluator_connections_heartbeat", "last_heartbeat"), + ) + + class BackendState(TimestampMixin, SQLModel, table=True): """Backend service state for persistence across restarts.""" @@ -1189,6 +1306,9 @@ class ApiKey(TimestampMixin, SQLModel, table=True): validator_connections: List["ValidatorConnection"] = Relationship( back_populates="api_key", cascade_delete=True ) + evaluator_connections: List["EvaluatorConnection"] = Relationship( + back_populates="api_key", cascade_delete=True + ) reviewed_leader_candidates: List["CompetitionLeaderCandidate"] = Relationship( back_populates="reviewed_by" ) @@ -1196,8 +1316,10 @@ class ApiKey(TimestampMixin, SQLModel, table=True): __table_args__ = ( Index("ix_api_keys_active", "is_active"), Index("ix_api_keys_expires", "expires_at"), + # Note: 'evaluator' role added for direct evaluator connections CheckConstraint( - "role IN ('admin', 'validator', 'viewer')", name="ck_api_keys_valid_role" + "role IN ('admin', 'validator', 'evaluator', 'viewer')", + name="ck_api_keys_valid_role", ), ) diff --git a/src/backend/scoring/__init__.py b/src/backend/scoring/__init__.py new file mode 100644 index 0000000..0969649 --- /dev/null +++ b/src/backend/scoring/__init__.py @@ -0,0 +1,35 @@ +""" +Scoring strategy abstraction for Kinitro. + +This package provides pluggable scoring strategies that allow different +task types to have their own eligibility, metrics extraction, and scoring logic. + +It also re-exports the ScoringEngine and ScoringConfig for backward compatibility. +""" + +from .registry import ScoringStrategyRegistry +from .strategies import ( + EligibilityResult, + RLRolloutScoringStrategy, + ScoringMetrics, + ScoringStrategy, + StrategyNotFoundError, +) + +# Re-export from the scoring_engine module for backward compatibility +# The scoring_engine module was previously named scoring.py but was renamed +# to avoid conflicts with this package +from backend.scoring_engine import ScoringConfig, ScoringEngine + +__all__ = [ + # Strategy abstractions + "EligibilityResult", + "RLRolloutScoringStrategy", + "ScoringMetrics", + "ScoringStrategy", + "ScoringStrategyRegistry", + "StrategyNotFoundError", + # Backward compatible re-exports + "ScoringConfig", + "ScoringEngine", +] diff --git a/src/backend/scoring/registry.py b/src/backend/scoring/registry.py new file mode 100644 index 0000000..fbf73dc --- /dev/null +++ b/src/backend/scoring/registry.py @@ -0,0 +1,152 @@ +""" +Scoring strategy registry for task type dispatch. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.log import get_logger +from core.tasks import TaskType + +from .strategies import ( + RLRolloutScoringStrategy, + ScoringStrategy, + StrategyNotFoundError, +) + +if TYPE_CHECKING: + pass + +logger = get_logger(__name__) + + +class ScoringStrategyRegistry: + """Registry for scoring strategies. + + Maps task types to their corresponding scoring strategies. + Strategies must implement the ScoringStrategy protocol. + + Example usage: + # Get the default registry with all built-in strategies + registry = ScoringStrategyRegistry.default() + + # Get a strategy for a task type + strategy = registry.get(TaskType.RL_ROLLOUT) + + # Register a custom strategy + registry.register(MyCustomStrategy()) + """ + + _strategies: dict[TaskType, ScoringStrategy] + + def __init__(self) -> None: + """Create an empty registry.""" + self._strategies = {} + + @classmethod + def default(cls) -> "ScoringStrategyRegistry": + """Create a registry with all built-in strategies registered.""" + registry = cls() + registry.register(RLRolloutScoringStrategy()) + return registry + + def register(self, strategy: ScoringStrategy) -> None: + """Register a scoring strategy for its task type. + + Args: + strategy: The strategy to register. Must implement ScoringStrategy protocol. + + Raises: + TypeError: If strategy doesn't implement ScoringStrategy protocol. + """ + if not isinstance(strategy, ScoringStrategy): + raise TypeError( + f"Strategy must implement ScoringStrategy protocol, got {type(strategy)}" + ) + + task_type = strategy.task_type + if task_type in self._strategies: + logger.warning( + "Overwriting existing strategy for task type %s: %s -> %s", + task_type, + type(self._strategies[task_type]).__name__, + type(strategy).__name__, + ) + + self._strategies[task_type] = strategy + logger.info( + "Registered scoring strategy for task type %s: %s", + task_type, + type(strategy).__name__, + ) + + def get(self, task_type: TaskType | str) -> ScoringStrategy: + """Get the scoring strategy for a task type. + + Args: + task_type: The task type to get a strategy for + + Returns: + The registered ScoringStrategy + + Raises: + StrategyNotFoundError: If no strategy is registered for the task type + """ + # Convert string to TaskType if needed + if isinstance(task_type, str): + try: + task_type = TaskType(task_type) + except ValueError: + raise StrategyNotFoundError(task_type) + + strategy = self._strategies.get(task_type) + if strategy is None: + raise StrategyNotFoundError(task_type) + + return strategy + + def has(self, task_type: TaskType | str) -> bool: + """Check if a strategy is registered for a task type. + + Args: + task_type: The task type to check + + Returns: + True if a strategy is registered, False otherwise + """ + if isinstance(task_type, str): + try: + task_type = TaskType(task_type) + except ValueError: + return False + + return task_type in self._strategies + + def list_task_types(self) -> list[TaskType]: + """List all task types that have registered strategies. + + Returns: + List of TaskType values with registered strategies + """ + return list(self._strategies.keys()) + + def unregister(self, task_type: TaskType) -> bool: + """Unregister a strategy for a task type. + + Args: + task_type: The task type to unregister + + Returns: + True if a strategy was unregistered, False if none was registered + """ + if task_type in self._strategies: + del self._strategies[task_type] + logger.info("Unregistered scoring strategy for task type %s", task_type) + return True + return False + + def clear(self) -> None: + """Remove all registered strategies.""" + self._strategies.clear() + logger.info("Cleared all scoring strategies") diff --git a/src/backend/scoring/strategies.py b/src/backend/scoring/strategies.py new file mode 100644 index 0000000..5338a9f --- /dev/null +++ b/src/backend/scoring/strategies.py @@ -0,0 +1,267 @@ +""" +Scoring strategy interfaces and implementations. + +Each task type can have its own scoring strategy that defines how to: +- Extract metrics from task results +- Check eligibility against competition thresholds +- Compute final scores +- Compare results for ranking +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from core.tasks import TaskResult, TaskType + +if TYPE_CHECKING: + from backend.models import BackendEvaluationResult, Competition + + +class StrategyNotFoundError(Exception): + """Raised when no strategy is registered for a task type.""" + + def __init__(self, task_type: TaskType | str): + self.task_type = task_type + super().__init__(f"No scoring strategy registered for task type: {task_type}") + + +@dataclass +class ScoringMetrics: + """Container for extracted metrics from a task result. + + This provides a typed container for the most common metrics while + allowing task-specific extras via the `extra` dict. + """ + + # Common metrics used for eligibility/ranking + success_rate: float | None = None + avg_reward: float | None = None + total_episodes: int | None = None + score: float | None = None + + # Task-specific additional metrics + extra: dict[str, Any] | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for storage/serialization.""" + result: dict[str, Any] = {} + if self.success_rate is not None: + result["success_rate"] = self.success_rate + if self.avg_reward is not None: + result["avg_reward"] = self.avg_reward + if self.total_episodes is not None: + result["total_episodes"] = self.total_episodes + if self.score is not None: + result["score"] = self.score + if self.extra: + result.update(self.extra) + return result + + +@dataclass +class EligibilityResult: + """Result of an eligibility check.""" + + eligible: bool + reason: str | None = None + + +@runtime_checkable +class ScoringStrategy(Protocol): + """Interface for task-type-specific scoring logic. + + Implement this protocol to add scoring support for new task types. + The strategy handles all task-specific logic for: + - Extracting scoreable metrics from results + - Checking if results meet competition thresholds + - Computing final scores + - Comparing results for ranking + """ + + @property + def task_type(self) -> TaskType: + """The task type this strategy handles.""" + ... + + def extract_metrics(self, result: BackendEvaluationResult) -> ScoringMetrics: + """Extract scoreable metrics from an evaluation result. + + Args: + result: The evaluation result to extract metrics from + + Returns: + ScoringMetrics with extracted values + """ + ... + + def extract_metrics_from_task_result(self, result: TaskResult) -> ScoringMetrics: + """Extract scoreable metrics from a TaskResult. + + This is useful during evaluation when we have TaskResult + but not yet a BackendEvaluationResult. + + Args: + result: The task result to extract metrics from + + Returns: + ScoringMetrics with extracted values + """ + ... + + def check_eligibility( + self, + metrics: ScoringMetrics, + competition: Competition, + ) -> EligibilityResult: + """Check if metrics meet competition eligibility thresholds. + + Args: + metrics: The extracted metrics to check + competition: The competition with threshold configuration + + Returns: + EligibilityResult indicating eligibility and reason if not + """ + ... + + def compute_score( + self, + metrics: ScoringMetrics, + competition: Competition, + ) -> float: + """Compute a final score from metrics. + + This score is used for ranking and may be a combination of + multiple metrics depending on the task type. + + Args: + metrics: The extracted metrics + competition: Competition with scoring configuration + + Returns: + Final computed score + """ + ... + + def compare( + self, + a: ScoringMetrics, + b: ScoringMetrics, + ) -> int: + """Compare two results for ranking. + + Args: + a: First set of metrics + b: Second set of metrics + + Returns: + -1 if a < b, 0 if equal, 1 if a > b + """ + ... + + +class RLRolloutScoringStrategy: + """Scoring strategy for RL rollout tasks. + + Uses success_rate as primary metric and avg_reward as secondary. + Eligibility is determined by min_success_rate and min_avg_reward thresholds. + """ + + @property + def task_type(self) -> TaskType: + return TaskType.RL_ROLLOUT + + def extract_metrics(self, result: BackendEvaluationResult) -> ScoringMetrics: + """Extract RL-specific metrics from evaluation result.""" + return ScoringMetrics( + success_rate=result.success_rate, + avg_reward=result.avg_reward, + total_episodes=result.total_episodes, + score=result.score, + ) + + def extract_metrics_from_task_result(self, result: TaskResult) -> ScoringMetrics: + """Extract RL-specific metrics from TaskResult.""" + return ScoringMetrics( + success_rate=result.metrics.get("success_rate"), + avg_reward=result.metrics.get("avg_reward"), + total_episodes=result.total_episodes, + score=result.metrics.get("score"), + ) + + def check_eligibility( + self, + metrics: ScoringMetrics, + competition: Competition, + ) -> EligibilityResult: + """Check RL eligibility based on success rate and avg reward thresholds.""" + if metrics.success_rate is None or metrics.avg_reward is None: + return EligibilityResult( + eligible=False, + reason="Missing required metrics (success_rate or avg_reward)", + ) + + if metrics.success_rate < competition.min_success_rate: + return EligibilityResult( + eligible=False, + reason=( + f"success_rate {metrics.success_rate:.3f} below " + f"threshold {competition.min_success_rate:.3f}" + ), + ) + + if metrics.avg_reward < competition.min_avg_reward: + return EligibilityResult( + eligible=False, + reason=( + f"avg_reward {metrics.avg_reward:.3f} below " + f"threshold {competition.min_avg_reward}" + ), + ) + + return EligibilityResult(eligible=True) + + def compute_score( + self, + metrics: ScoringMetrics, + competition: Competition, + ) -> float: + """Compute score for RL tasks. + + Currently uses success_rate as the primary score metric. + Could be extended to use weighted combinations based on + competition.scoring_config. + """ + # Use the pre-computed score if available + if metrics.score is not None: + return metrics.score + + # Otherwise, use success_rate as the score + if metrics.success_rate is not None: + return metrics.success_rate + + return 0.0 + + def compare( + self, + a: ScoringMetrics, + b: ScoringMetrics, + ) -> int: + """Compare RL results by success_rate, then avg_reward.""" + # Primary: success_rate (higher is better) + a_sr = a.success_rate if a.success_rate is not None else float("-inf") + b_sr = b.success_rate if b.success_rate is not None else float("-inf") + + if a_sr != b_sr: + return 1 if a_sr > b_sr else -1 + + # Secondary: avg_reward (higher is better) + a_ar = a.avg_reward if a.avg_reward is not None else float("-inf") + b_ar = b.avg_reward if b.avg_reward is not None else float("-inf") + + if a_ar != b_ar: + return 1 if a_ar > b_ar else -1 + + return 0 diff --git a/src/backend/scoring/tests/__init__.py b/src/backend/scoring/tests/__init__.py new file mode 100644 index 0000000..42e97cd --- /dev/null +++ b/src/backend/scoring/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for scoring strategies.""" diff --git a/src/backend/scoring/tests/test_strategies.py b/src/backend/scoring/tests/test_strategies.py new file mode 100644 index 0000000..9e04163 --- /dev/null +++ b/src/backend/scoring/tests/test_strategies.py @@ -0,0 +1,371 @@ +"""Tests for scoring strategies.""" + +from unittest.mock import MagicMock + +import pytest + +from core.tasks import TaskResult, TaskType + +from backend.scoring import ( + EligibilityResult, + RLRolloutScoringStrategy, + ScoringMetrics, + ScoringStrategyRegistry, + StrategyNotFoundError, +) + + +# ============================================================================= +# ScoringMetrics Tests +# ============================================================================= + + +class TestScoringMetrics: + """Tests for ScoringMetrics dataclass.""" + + def test_to_dict_all_fields(self): + """Test serialization with all fields set.""" + metrics = ScoringMetrics( + success_rate=0.85, + avg_reward=100.5, + total_episodes=50, + score=0.85, + extra={"custom_metric": 42.0}, + ) + result = metrics.to_dict() + + assert result["success_rate"] == 0.85 + assert result["avg_reward"] == 100.5 + assert result["total_episodes"] == 50 + assert result["score"] == 0.85 + assert result["custom_metric"] == 42.0 + + def test_to_dict_partial_fields(self): + """Test serialization with only some fields set.""" + metrics = ScoringMetrics(success_rate=0.5) + result = metrics.to_dict() + + assert result == {"success_rate": 0.5} + assert "avg_reward" not in result + assert "total_episodes" not in result + + def test_to_dict_empty(self): + """Test serialization with no fields set.""" + metrics = ScoringMetrics() + result = metrics.to_dict() + + assert result == {} + + +# ============================================================================= +# RLRolloutScoringStrategy Tests +# ============================================================================= + + +class TestRLRolloutScoringStrategy: + """Tests for RLRolloutScoringStrategy.""" + + @pytest.fixture + def strategy(self): + """Create a strategy instance for testing.""" + return RLRolloutScoringStrategy() + + @pytest.fixture + def mock_competition(self): + """Create a mock competition with thresholds.""" + comp = MagicMock() + comp.id = "test-competition" + comp.min_success_rate = 0.5 + comp.min_avg_reward = 10.0 + comp.task_type = "rl_rollout" + return comp + + @pytest.fixture + def mock_result(self): + """Create a mock BackendEvaluationResult.""" + result = MagicMock() + result.success_rate = 0.75 + result.avg_reward = 50.0 + result.total_episodes = 100 + result.score = 0.75 + return result + + def test_task_type(self, strategy): + """Test that task_type property returns correct value.""" + assert strategy.task_type == TaskType.RL_ROLLOUT + + def test_extract_metrics(self, strategy, mock_result): + """Test extracting metrics from BackendEvaluationResult.""" + metrics = strategy.extract_metrics(mock_result) + + assert metrics.success_rate == 0.75 + assert metrics.avg_reward == 50.0 + assert metrics.total_episodes == 100 + assert metrics.score == 0.75 + + def test_extract_metrics_with_none_values(self, strategy): + """Test extracting metrics when some values are None.""" + result = MagicMock() + result.success_rate = None + result.avg_reward = None + result.total_episodes = None + result.score = None + + metrics = strategy.extract_metrics(result) + + assert metrics.success_rate is None + assert metrics.avg_reward is None + assert metrics.total_episodes is None + assert metrics.score is None + + def test_extract_metrics_from_task_result(self, strategy): + """Test extracting metrics from TaskResult.""" + task_result = TaskResult( + task_id="test-task", + success=True, + metrics={ + "success_rate": 0.9, + "avg_reward": 75.0, + "score": 0.9, + }, + total_episodes=50, + ) + + metrics = strategy.extract_metrics_from_task_result(task_result) + + assert metrics.success_rate == 0.9 + assert metrics.avg_reward == 75.0 + assert metrics.total_episodes == 50 + assert metrics.score == 0.9 + + # ------------------------------------------------------------------------- + # Eligibility Tests + # ------------------------------------------------------------------------- + + def test_check_eligibility_eligible(self, strategy, mock_competition): + """Test eligibility check for eligible result.""" + metrics = ScoringMetrics(success_rate=0.8, avg_reward=50.0) + + result = strategy.check_eligibility(metrics, mock_competition) + + assert result.eligible is True + assert result.reason is None + + def test_check_eligibility_below_success_rate(self, strategy, mock_competition): + """Test eligibility when success_rate is below threshold.""" + metrics = ScoringMetrics(success_rate=0.3, avg_reward=50.0) + + result = strategy.check_eligibility(metrics, mock_competition) + + assert result.eligible is False + assert "success_rate" in result.reason + assert "0.300" in result.reason + + def test_check_eligibility_below_avg_reward(self, strategy, mock_competition): + """Test eligibility when avg_reward is below threshold.""" + metrics = ScoringMetrics(success_rate=0.8, avg_reward=5.0) + + result = strategy.check_eligibility(metrics, mock_competition) + + assert result.eligible is False + assert "avg_reward" in result.reason + + def test_check_eligibility_missing_metrics(self, strategy, mock_competition): + """Test eligibility when required metrics are missing.""" + metrics = ScoringMetrics(success_rate=None, avg_reward=None) + + result = strategy.check_eligibility(metrics, mock_competition) + + assert result.eligible is False + assert "Missing required metrics" in result.reason + + def test_check_eligibility_at_threshold(self, strategy, mock_competition): + """Test eligibility at exact threshold values.""" + metrics = ScoringMetrics(success_rate=0.5, avg_reward=10.0) + + result = strategy.check_eligibility(metrics, mock_competition) + + assert result.eligible is True + + # ------------------------------------------------------------------------- + # Score Computation Tests + # ------------------------------------------------------------------------- + + def test_compute_score_uses_existing_score(self, strategy, mock_competition): + """Test that compute_score uses existing score when available.""" + metrics = ScoringMetrics(success_rate=0.8, score=0.95) + + score = strategy.compute_score(metrics, mock_competition) + + assert score == 0.95 + + def test_compute_score_falls_back_to_success_rate(self, strategy, mock_competition): + """Test that compute_score falls back to success_rate when score is None.""" + metrics = ScoringMetrics(success_rate=0.8, score=None) + + score = strategy.compute_score(metrics, mock_competition) + + assert score == 0.8 + + def test_compute_score_returns_zero_when_no_metrics( + self, strategy, mock_competition + ): + """Test that compute_score returns 0 when no relevant metrics.""" + metrics = ScoringMetrics() + + score = strategy.compute_score(metrics, mock_competition) + + assert score == 0.0 + + # ------------------------------------------------------------------------- + # Comparison Tests + # ------------------------------------------------------------------------- + + def test_compare_by_success_rate(self, strategy): + """Test comparison primarily by success_rate.""" + a = ScoringMetrics(success_rate=0.9, avg_reward=50.0) + b = ScoringMetrics(success_rate=0.7, avg_reward=100.0) + + assert strategy.compare(a, b) == 1 # a > b + assert strategy.compare(b, a) == -1 # b < a + + def test_compare_by_avg_reward_when_success_rate_equal(self, strategy): + """Test comparison by avg_reward when success_rate is equal.""" + a = ScoringMetrics(success_rate=0.8, avg_reward=100.0) + b = ScoringMetrics(success_rate=0.8, avg_reward=50.0) + + assert strategy.compare(a, b) == 1 # a > b + assert strategy.compare(b, a) == -1 # b < a + + def test_compare_equal_metrics(self, strategy): + """Test comparison when both metrics are equal.""" + a = ScoringMetrics(success_rate=0.8, avg_reward=50.0) + b = ScoringMetrics(success_rate=0.8, avg_reward=50.0) + + assert strategy.compare(a, b) == 0 + + def test_compare_handles_none_values(self, strategy): + """Test comparison handles None values gracefully.""" + a = ScoringMetrics(success_rate=0.5, avg_reward=None) + b = ScoringMetrics(success_rate=None, avg_reward=100.0) + + assert strategy.compare(a, b) == 1 # 0.5 > -inf + assert strategy.compare(b, a) == -1 + + +# ============================================================================= +# ScoringStrategyRegistry Tests +# ============================================================================= + + +class TestScoringStrategyRegistry: + """Tests for ScoringStrategyRegistry.""" + + def test_default_registry_has_rl_rollout(self): + """Test that default registry has RL rollout strategy registered.""" + registry = ScoringStrategyRegistry.default() + + strategy = registry.get(TaskType.RL_ROLLOUT) + assert isinstance(strategy, RLRolloutScoringStrategy) + + def test_get_with_string_task_type(self): + """Test getting strategy with string task type.""" + registry = ScoringStrategyRegistry.default() + + strategy = registry.get("rl_rollout") + assert isinstance(strategy, RLRolloutScoringStrategy) + + def test_get_raises_for_unknown_task_type(self): + """Test that get raises StrategyNotFoundError for unknown types.""" + registry = ScoringStrategyRegistry() + + with pytest.raises(StrategyNotFoundError) as exc_info: + registry.get(TaskType.RL_ROLLOUT) + + assert exc_info.value.task_type == TaskType.RL_ROLLOUT + + def test_get_raises_for_invalid_string_task_type(self): + """Test that get raises for invalid string task types.""" + registry = ScoringStrategyRegistry() + + with pytest.raises(StrategyNotFoundError) as exc_info: + registry.get("invalid_type") + + assert exc_info.value.task_type == "invalid_type" + + def test_has_returns_true_for_registered(self): + """Test that has returns True for registered task types.""" + registry = ScoringStrategyRegistry.default() + + assert registry.has(TaskType.RL_ROLLOUT) is True + assert registry.has("rl_rollout") is True + + def test_has_returns_false_for_unregistered(self): + """Test that has returns False for unregistered task types.""" + registry = ScoringStrategyRegistry() + + assert registry.has(TaskType.RL_ROLLOUT) is False + assert registry.has("invalid_type") is False + + def test_register_and_get(self): + """Test registering and retrieving a strategy.""" + registry = ScoringStrategyRegistry() + strategy = RLRolloutScoringStrategy() + + registry.register(strategy) + retrieved = registry.get(TaskType.RL_ROLLOUT) + + assert retrieved is strategy + + def test_register_overwrites_existing(self): + """Test that registering overwrites existing strategy.""" + registry = ScoringStrategyRegistry() + strategy1 = RLRolloutScoringStrategy() + strategy2 = RLRolloutScoringStrategy() + + registry.register(strategy1) + registry.register(strategy2) + retrieved = registry.get(TaskType.RL_ROLLOUT) + + assert retrieved is strategy2 + + def test_register_raises_for_non_protocol(self): + """Test that register raises for non-ScoringStrategy objects.""" + registry = ScoringStrategyRegistry() + + with pytest.raises(TypeError): + registry.register("not a strategy") # type: ignore + + def test_list_task_types(self): + """Test listing registered task types.""" + registry = ScoringStrategyRegistry.default() + + task_types = registry.list_task_types() + + assert TaskType.RL_ROLLOUT in task_types + + def test_unregister(self): + """Test unregistering a strategy.""" + registry = ScoringStrategyRegistry.default() + + result = registry.unregister(TaskType.RL_ROLLOUT) + + assert result is True + assert registry.has(TaskType.RL_ROLLOUT) is False + + def test_unregister_nonexistent(self): + """Test unregistering a strategy that doesn't exist.""" + registry = ScoringStrategyRegistry() + + result = registry.unregister(TaskType.RL_ROLLOUT) + + assert result is False + + def test_clear(self): + """Test clearing all strategies.""" + registry = ScoringStrategyRegistry.default() + assert len(registry.list_task_types()) > 0 + + registry.clear() + + assert len(registry.list_task_types()) == 0 diff --git a/src/backend/scoring_engine.py b/src/backend/scoring_engine.py new file mode 100644 index 0000000..5ee6b7c --- /dev/null +++ b/src/backend/scoring_engine.py @@ -0,0 +1,515 @@ +""" +Scoring engine for Kinitro evaluations. + +Handles all scoring, eligibility checking, and leader candidate logic. +Extracted from BackendService for better separation of concerns. + +The ScoringEngine uses pluggable ScoringStrategy implementations to support +different task types with their own eligibility and scoring logic. +""" + +from datetime import datetime, timezone +from typing import TYPE_CHECKING + +from fiber.chain.models import Node +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from core.log import get_logger +from core.tasks import TaskType + +from .models import ( + BackendEvaluationResult, + Competition, + CompetitionLeaderCandidate, + LeaderCandidateStatus, + SS58Address, +) +from .scoring import ScoringStrategyRegistry + +if TYPE_CHECKING: + from .scoring import ScoringStrategy + +logger = get_logger(__name__) + + +class ScoringConfig: + """Configuration for the scoring engine.""" + + def __init__( + self, + owner_uid: int = 4, + burn_pct: float = 0.98, + ): + self.owner_uid = owner_uid + self.burn_pct = self._validate_burn_pct(burn_pct) + + @staticmethod + def _validate_burn_pct(burn_pct: float) -> float: + """Validate and clamp burn percentage to [0, 1].""" + if burn_pct < 0 or burn_pct > 1: + logger.warning( + "Configured burn_pct %.3f out of bounds [0, 1]; clamping.", + burn_pct, + ) + return max(0.0, min(1.0, burn_pct)) + return burn_pct + + +class ScoringEngine: + """ + Handles all scoring, eligibility, and leader candidate logic. + + This class is responsible for: + - Checking if miners meet eligibility criteria for competitions + - Creating and managing leader candidates + - Computing scores for competition winners + - Calculating weight distributions for miners + + The ScoringEngine uses pluggable ScoringStrategy implementations to support + different task types. The strategy is selected based on competition.task_type. + """ + + def __init__( + self, + session_factory: async_sessionmaker[AsyncSession], + config: ScoringConfig, + id_generator, + strategy_registry: ScoringStrategyRegistry | None = None, + ): + self.session_factory = session_factory + self.config = config + self.id_generator = id_generator + # Use provided registry or create default with all built-in strategies + self.strategy_registry = strategy_registry or ScoringStrategyRegistry.default() + + def get_strategy(self, competition: Competition) -> "ScoringStrategy": + """Get the scoring strategy for a competition based on its task type. + + Args: + competition: The competition to get a strategy for + + Returns: + The appropriate ScoringStrategy for the competition's task type + """ + return self.strategy_registry.get(competition.task_type) + + def is_eligible( + self, + result: BackendEvaluationResult, + competition: Competition, + ) -> bool: + """Check if a miner meets eligibility criteria for a competition. + + Uses the competition's task type to select the appropriate scoring + strategy for eligibility checking. + """ + strategy = self.get_strategy(competition) + metrics = strategy.extract_metrics(result) + eligibility = strategy.check_eligibility(metrics, competition) + + if not eligibility.eligible and eligibility.reason: + logger.debug( + "Miner %s excluded from competition %s: %s", + result.miner_hotkey, + competition.id, + eligibility.reason, + ) + + return eligibility.eligible + + async def queue_leader_candidate( + self, + session: AsyncSession, + competition: Competition, + result: BackendEvaluationResult, + ) -> bool: + """Persist a leader candidate if not already recorded for this result.""" + if result.avg_reward is None: + logger.debug( + "Skipping leader candidate creation without avg_reward: competition=%s result_id=%s", + competition.id, + result.id, + ) + return False + + existing_candidate_result = await session.execute( + select(CompetitionLeaderCandidate).where( + CompetitionLeaderCandidate.evaluation_result_id == result.id + ) + ) + existing_candidate = existing_candidate_result.scalar_one_or_none() + if existing_candidate: + logger.debug( + "Leader candidate already exists for evaluation result %s (competition=%s)", + result.id, + competition.id, + ) + return False + + candidate = CompetitionLeaderCandidate( + id=next(self.id_generator), + competition_id=competition.id, + miner_hotkey=result.miner_hotkey, + evaluation_result_id=result.id, + avg_reward=result.avg_reward, + success_rate=result.success_rate, + score=result.score, + total_episodes=result.total_episodes, + ) + session.add(candidate) + return True + + async def score_evaluations(self) -> dict[SS58Address, float]: + """ + Score completed evaluations with winner-takes-all per competition. + + Scoring logic: + - Miners must meet minimum success rate threshold per competition + - Miners must pass minimum avg reward threshold per competition + - Eligible challengers above the approved leader's success rate are queued for admin review + - If the current leader improves or matches their approved success rate, that result is also queued + - Current leader retains position until admin approval + - Each miner can only win ONE competition (first-win policy) + - Final scores are normalized based on competition points + + Returns: + dict[SS58Address, float]: Mapping of miner hotkeys to their normalized scores (0-1). + """ + async with self.session_factory() as session: + # Fetch all active competitions + competitions_result = await session.execute( + select(Competition).where(Competition.active) + ) + competitions = competitions_result.scalars().all() + + if not competitions: + logger.info("No active competitions found for scoring") + return {} + + # Calculate total points across all competitions + total_points = sum(comp.points for comp in competitions) + + # Dictionary to store winner scores + miner_scores: dict[SS58Address, float] = {} + + for competition in competitions: + await self._score_competition( + session, competition, total_points, miner_scores + ) + + # Commit any leader updates to database + await session.commit() + + # Log final scores + if miner_scores: + logger.info(f"Final miner scores: {len(miner_scores)} miners scored") + for hotkey, score in sorted( + miner_scores.items(), key=lambda x: x[1], reverse=True + )[:10]: + logger.info(f" {hotkey}: {score:.4f}") + else: + logger.info("No miners received scores") + + return miner_scores + + async def _score_competition( + self, + session: AsyncSession, + competition: Competition, + total_points: int, + miner_scores: dict[SS58Address, float], + ) -> None: + """Score a single competition and update miner_scores in place.""" + # Get the scoring strategy for this competition's task type + strategy = self.get_strategy(competition) + + # Get all evaluation results for this competition + results_query = select(BackendEvaluationResult).where( + BackendEvaluationResult.competition_id == competition.id + ) + results = await session.execute(results_query) + eval_results = results.scalars().all() + + if not eval_results: + logger.debug(f"No evaluation results for competition {competition.id}") + return + + # Find eligible challengers and order them using the strategy's compare method + eligible_results = [ + result for result in eval_results if self.is_eligible(result, competition) + ] + + # Sort using strategy's compare method (descending order, best first) + from functools import cmp_to_key + + def compare_results( + a: BackendEvaluationResult, b: BackendEvaluationResult + ) -> int: + metrics_a = strategy.extract_metrics(a) + metrics_b = strategy.extract_metrics(b) + # Negate because we want descending order (best first) + return -strategy.compare(metrics_a, metrics_b) + + eligible_results.sort(key=cmp_to_key(compare_results)) + + if not eligible_results: + if competition.current_leader_hotkey: + logger.info( + "Competition %s: Current leader %s retains position (no eligible challengers)", + competition.id, + competition.current_leader_hotkey, + ) + else: + logger.info("Competition %s: No eligible miners found", competition.id) + return + + current_leader = competition.current_leader_hotkey + + if current_leader is None: + await self._handle_no_leader(session, competition, eligible_results) + else: + await self._handle_existing_leader( + session, competition, current_leader, eligible_results + ) + + # Award points only to the currently approved leader + award_hotkey = competition.current_leader_hotkey + if not award_hotkey: + logger.debug( + "Competition %s: Skipping score award (no approved leader)", + competition.id, + ) + return + + base_score = competition.points / total_points if total_points else 0 + if base_score == 0: + logger.debug( + "Competition %s: Skipping zero-point competition in scoring", + competition.id, + ) + return + + if award_hotkey in miner_scores: + logger.warning( + "Miner %s already won competition - skipping score from %s. Previous score: %.4f, would have added: %.4f", + award_hotkey, + competition.id, + miner_scores[award_hotkey], + base_score * (1 - self.config.burn_pct), + ) + return + + awarded_score = base_score * (1 - self.config.burn_pct) + burned_score = base_score - awarded_score + + if awarded_score <= 0: + logger.info( + "Competition %s: Burned entire %.4f normalized score for %s (burn_pct=%.2f%%)", + competition.id, + base_score, + award_hotkey, + self.config.burn_pct * 100, + ) + return + + miner_scores[award_hotkey] = awarded_score + if burned_score > 0: + logger.info( + "Competition %s: Awarded %.4f normalized score to %s (burned %.4f; burn_pct=%.2f%%)", + competition.id, + awarded_score, + award_hotkey, + burned_score, + self.config.burn_pct * 100, + ) + else: + logger.info( + "Competition %s: Awarded %.4f normalized score to %s", + competition.id, + awarded_score, + award_hotkey, + ) + + async def _handle_no_leader( + self, + session: AsyncSession, + competition: Competition, + eligible_results: list[BackendEvaluationResult], + ) -> None: + """Handle scoring when there is no current leader.""" + queued_any = False + for res in eligible_results: + created_candidate = await self.queue_leader_candidate( + session, competition, res + ) + if created_candidate: + queued_any = True + logger.info( + "Competition %s: Queued leader candidate %s (success_rate=%.3f, avg_reward=%.3f)", + competition.id, + res.miner_hotkey, + res.success_rate or 0.0, + res.avg_reward or 0.0, + ) + if not queued_any: + logger.debug( + "Competition %s: All eligible results already queued as candidates", + competition.id, + ) + + async def _handle_existing_leader( + self, + session: AsyncSession, + competition: Competition, + current_leader: SS58Address, + eligible_results: list[BackendEvaluationResult], + ) -> None: + """Handle scoring when there is an existing leader.""" + leader_success_rate_stmt = ( + select( + CompetitionLeaderCandidate.success_rate, + CompetitionLeaderCandidate.evaluation_result_id, + ) + .where( + CompetitionLeaderCandidate.competition_id == competition.id, + CompetitionLeaderCandidate.miner_hotkey == current_leader, + CompetitionLeaderCandidate.status == LeaderCandidateStatus.APPROVED, + ) + .order_by( + CompetitionLeaderCandidate.reviewed_at.desc(), + CompetitionLeaderCandidate.updated_at.desc(), + ) + .limit(1) + ) + leader_success_rate_result = await session.execute(leader_success_rate_stmt) + leader_success_rate_row = leader_success_rate_result.first() + leader_success_rate = ( + leader_success_rate_row[0] if leader_success_rate_row else None + ) + leader_success_eval_id = ( + leader_success_rate_row[1] if leader_success_rate_row else None + ) + baseline_leader_success_rate = ( + leader_success_rate if leader_success_rate is not None else -1.0 + ) + + leader_best = next( + (res for res in eligible_results if res.miner_hotkey == current_leader), + None, + ) + if ( + leader_best + and leader_best.avg_reward is not None + and leader_best.avg_reward != competition.current_leader_reward + ): + competition.current_leader_reward = leader_best.avg_reward + competition.leader_updated_at = datetime.now(timezone.utc) + logger.info( + "Competition %s: Updated leader %s reward to %.3f", + competition.id, + current_leader, + leader_best.avg_reward, + ) + + challengers: list[BackendEvaluationResult] = [] + for res in eligible_results: + if ( + res.success_rate is not None + and res.success_rate > baseline_leader_success_rate + ): + challengers.append(res) + + if ( + leader_best + and leader_best.success_rate is not None + and leader_best.success_rate >= baseline_leader_success_rate + and leader_best.id != leader_success_eval_id + ): + challengers.append(leader_best) + + # Deduplicate challengers by evaluation_result_id while preserving order + seen_eval_ids: set[int] = set() + unique_challengers: list[BackendEvaluationResult] = [] + for res in challengers: + if res.id in seen_eval_ids: + continue + seen_eval_ids.add(res.id) + unique_challengers.append(res) + + for challenger in unique_challengers: + created_candidate = await self.queue_leader_candidate( + session, competition, challenger + ) + if created_candidate: + logger.info( + "Competition %s: Challenger %s queued for admin review (success_rate=%.3f, avg_reward=%.3f, current leader=%s success_rate=%s)", + competition.id, + challenger.miner_hotkey, + challenger.success_rate or 0.0, + challenger.avg_reward or 0.0, + current_leader, + f"{leader_success_rate:.3f}" + if leader_success_rate is not None + else "unknown", + ) + else: + logger.debug( + "Competition %s: Challenger %s already recorded as candidate", + competition.id, + challenger.miner_hotkey, + ) + + def compute_weights( + self, + miner_scores: dict[SS58Address, float], + nodes: dict[SS58Address, Node], + ) -> dict[int, float]: + """ + Compute weight distribution from miner scores. + + Args: + miner_scores: Mapping of miner hotkeys to their normalized scores + nodes: Mapping of hotkeys to Node objects with node_id + + Returns: + dict[int, float]: Mapping of UIDs to weights + """ + weights_dict: dict[int, float] = {} + + for hotkey, weight in miner_scores.items(): + node = nodes.get(hotkey) + if node: + weights_dict[node.node_id] = weight + + total_weight = sum(weights_dict.values()) + if total_weight > 1.0: + logger.warning( + "Total miner weight %.6f exceeds 1.0 before owner allocation", + total_weight, + ) + + owner_weight = max(0.0, 1.0 - total_weight) + if owner_weight > 0: + weights_dict[self.config.owner_uid] = ( + weights_dict.get(self.config.owner_uid, 0.0) + owner_weight + ) + if all(node.node_id != self.config.owner_uid for node in nodes.values()): + logger.warning( + "Owner UID %s not found in node list; assigning %.4f weight without hotkey mapping", + self.config.owner_uid, + owner_weight, + ) + else: + logger.info( + "Owner UID %s assigned remaining normalized score %.4f (burn_pct=%.2f%%)", + self.config.owner_uid, + owner_weight, + self.config.burn_pct * 100, + ) + + # Populate missing entries with 0.0 weight for all nodes + for node in nodes.values(): + weights_dict.setdefault(node.node_id, 0.0) + + return weights_dict diff --git a/src/backend/service.py b/src/backend/service.py index cad1619..702d10e 100644 --- a/src/backend/service.py +++ b/src/backend/service.py @@ -1,7 +1,10 @@ """ +Refactored Backend Service for Kinitro. -This provides REST API endpoints and WebSocket connections for: +This module provides the refactored BackendService using composition with +extracted components: ScoringEngine, ChainMonitor, JobScheduler, and WebSocketHub. +This provides REST API endpoints and WebSocket connections for: - Competition management - Validator connections - Job distribution @@ -9,20 +12,16 @@ """ import asyncio -import copy import hashlib import os from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Tuple import dotenv from asyncpg.exceptions import DeadlockDetectedError -from fastapi import ( - WebSocket, -) -from fiber.chain.fetch_nodes import _get_nodes_for_uid +from fastapi import WebSocket from fiber.chain.interface import get_substrate from fiber.chain.models import Node from snowflake import SnowflakeGenerator @@ -36,6 +35,8 @@ create_async_engine, ) +# Import extracted components +from backend.chain_monitor import ChainConfig, ChainMonitor from backend.constants import ( CHAIN_SCAN_YIELD_INTERVAL, DEFAULT_BURN_PCT, @@ -54,17 +55,25 @@ WEIGHT_BROADCAST_INTERVAL, WEIGHT_BROADCAST_STARTUP_DELAY, ) +from backend.evaluator_hub import EvaluatorHub from backend.events import ( EpisodeCompletedEvent, EpisodeStepEvent, JobCompletedEvent, - JobCreatedEvent, JobStatusChangedEvent, StatsUpdatedEvent, SubmissionReceivedEvent, ValidatorDisconnectedEvent, ) -from backend.realtime import event_broadcaster +from backend.job_scheduler import ( + EvaluationJobNotFoundError, + JobConfig, + JobScheduler, + NoBenchmarksAvailableError, + SubmissionNotFoundError, +) +from backend.realtime import EventType, event_broadcaster +from backend.scoring import ScoringConfig, ScoringEngine from backend.submission_storage import PresignedUpload, SubmissionStorage try: # pragma: no cover - optional dependency @@ -76,16 +85,13 @@ from substrateinterface import Keypair as SubstrateKeypair # type: ignore except ImportError: # pragma: no cover SubstrateKeypair = None -from core.chain import query_commitments_from_substrate + from core.db.models import EvaluationStatus from core.log import get_logger from core.messages import ( EpisodeDataMessage, EpisodeStepDataMessage, - EvalJobMessage, - EventType, MessageType, - SetWeightsMessage, ) from core.schemas import ChainCommitmentResponse, ModelProvider from core.storage import load_s3_config @@ -113,7 +119,7 @@ logger = get_logger(__name__) -ConnectionId = str # Unique ID for each WebSocket connection +ConnectionId = str VALIDATOR_MESSAGE_QUEUE_MAXSIZE = 5000 VALIDATOR_MESSAGE_BATCH_SIZE = 50 @@ -131,56 +137,6 @@ def _ensure_datetime(value: Any) -> datetime: raise ValueError(f"Unsupported datetime value: {value!r}") -def _extract_benchmark_spec_payload( - config: Mapping[str, Any], -) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Split a stored benchmark configuration into the full benchmark spec and the - underlying execution config used by the evaluator. - """ - spec_copy = copy.deepcopy(dict(config)) - try: - base_config_source = config["config"] - except KeyError as exc: # pragma: no cover - defensive guard - raise ValueError("Benchmark spec is missing 'config' payload") from exc - try: - base_config = copy.deepcopy(dict(base_config_source)) - except TypeError as exc: # pragma: no cover - defensive guard - raise ValueError("'config' payload must be a mapping") from exc - return spec_copy, base_config - - -def _normalize_benchmark_spec_payload( - provider: str, - benchmark_name: str, - payload: Mapping[str, Any] | None, -) -> dict[str, Any]: - """ - Ensure a benchmark specification payload includes top-level metadata and a nested config mapping. - - Accepts either the new-style payload (with a `config` key) or a bare config mapping and - returns a copy that always matches the new-style structure. - """ - if payload is None: - base_config: dict[str, Any] = {} - return { - "provider": provider, - "benchmark_name": benchmark_name, - "config": base_config, - } - - payload_dict = dict(payload) - if "config" in payload_dict: - return copy.deepcopy(payload_dict) - - base_config = copy.deepcopy(payload_dict) - return { - "provider": provider, - "benchmark_name": benchmark_name, - "config": base_config, - } - - class LeaderCandidateError(Exception): """Base exception for leader candidate operations.""" @@ -197,26 +153,22 @@ class LeaderCandidateNotApprovedError(LeaderCandidateError): """Raised when attempting to modify an unapproved leader candidate.""" -class SubmissionNotFoundError(Exception): - """Raised when a submission cannot be located.""" - - -class EvaluationJobNotFoundError(Exception): - """Raised when an evaluation job cannot be located.""" - - -class NoBenchmarksAvailableError(Exception): - """Raised when no benchmarks are available for an evaluation rerun.""" - - class BackendService: - """Core backend service logic.""" + """ + Core backend service logic using composition. + + This refactored service delegates to extracted components: + - ScoringEngine: Handles scoring, eligibility, and leader candidates + - ChainMonitor: Monitors blockchain for commitments + - JobScheduler: Creates and broadcasts evaluation jobs + - WebSocketHub: Manages validator WebSocket connections + """ def __init__(self, config: BackendConfig): self.config = config self.db_url = config.settings.get("database_url") - # Chain monitoring configuration + # Parse configuration values self.max_commitment_lookback = config.settings.get( "max_commitment_lookback", DEFAULT_MAX_COMMITMENT_LOOKBACK ) @@ -301,20 +253,10 @@ def __init__(self, config: BackendConfig): ) # Chain connection objects - # Using Any since fiber's SubstrateInterface is from async_substrate_interface - self.substrate: Optional[Any] = ( - None # async_substrate_interface.sync_substrate.SubstrateInterface - ) - self.nodes: Optional[Dict[SS58Address, Node]] = None - - # WebSocket connections - self.active_connections: Dict[ConnectionId, WebSocket] = {} - self.validator_connections: Dict[ConnectionId, SS58Address] = {} + self.substrate: Optional[Any] = None # Background tasks self._running = False - self._chain_monitor_task = None - self._heartbeat_monitor_task = None self._stale_job_monitor_task = None self._score_evaluation_task = None self._weight_broadcast_task = None @@ -349,12 +291,71 @@ def __init__(self, config: BackendConfig): ) self._validator_queue_warning_triggered = False + # Initialize extracted components (will be fully initialized in startup) + self.evaluator_hub = EvaluatorHub() # Direct evaluator connections + self.scoring_engine: Optional[ScoringEngine] = None + self.chain_monitor: Optional[ChainMonitor] = None + self.job_scheduler: Optional[JobScheduler] = None + + def _init_components(self) -> None: + """Initialize extracted components after database is ready.""" + # Create component configs + scoring_config = ScoringConfig( + owner_uid=self.owner_uid, + burn_pct=self.burn_pct, + ) + + chain_config = ChainConfig( + max_commitment_lookback=self.max_commitment_lookback, + chain_sync_interval=self.chain_sync_interval, + chain_scan_yield_interval=CHAIN_SCAN_YIELD_INTERVAL, + ) + + job_config = JobConfig( + default_job_timeout_seconds=self.default_job_timeout_seconds, + submission_download_url_ttl=self.submission_download_url_ttl, + ) + + # Initialize components + self.scoring_engine = ScoringEngine( + session_factory=self.async_session, + config=scoring_config, + id_generator=self.id_generator, + ) + + self.chain_monitor = ChainMonitor( + substrate=self.substrate, + backend_config=self.config, + session_factory=self.async_session, + config=chain_config, + thread_pool=self.thread_pool, + on_commitment=self._process_commitment, + ) + + self.job_scheduler = JobScheduler( + session_factory=self.async_session, + evaluator_hub=self.evaluator_hub, + config=job_config, + id_generator=self.id_generator, + submission_storage=self.submission_storage, + ) + + logger.info( + "Initialized extracted components: ScoringEngine, ChainMonitor, JobScheduler, EvaluatorHub" + ) + + @property + def nodes(self) -> Optional[Dict[SS58Address, Node]]: + """Delegate to ChainMonitor for backward compatibility.""" + if self.chain_monitor: + return self.chain_monitor.get_nodes() + return None + @staticmethod def verify_hotkey_signature( hotkey: str, message: bytes, signature_hex: str ) -> bool: """Verify a hotkey-signed payload using available sr25519 implementations.""" - signature_body = ( signature_hex[2:] if signature_hex.startswith("0x") else signature_hex ) @@ -379,7 +380,7 @@ def verify_hotkey_signature( keypair = keypair_cls(ss58_address=hotkey) if keypair.verify(message, signature): return True - except Exception as exc: # pragma: no cover - verification failure + except Exception as exc: logger.debug( "Signature verification attempt failed with %s: %s", getattr(keypair_cls, "__name__", str(keypair_cls)), @@ -429,7 +430,6 @@ def _job_timeout_seconds(self, competition: Optional[Competition]) -> int: def _resolve_validator_worker_count(self, configured_value: Any) -> int: """Determine validator worker pool size from config or CPU count.""" - default_workers = self._default_validator_worker_count if configured_value is None: @@ -471,7 +471,6 @@ async def create_submission_upload( artifact_size_bytes: int, ) -> tuple[SubmissionUpload, PresignedUpload]: """Create a submission upload record and mint an upload URL.""" - if not self.submission_storage: raise RuntimeError("Submission storage is not configured") if not self.async_session: @@ -580,7 +579,7 @@ async def create_submission_upload( async def startup(self) -> None: """Initialize the backend service without starting background tasks.""" - logger.info("Initializing Kinitro Backend Service") + logger.info("Initializing Kinitro Backend Service (Refactored)") # Initialize database first await self._init_database() @@ -588,6 +587,9 @@ async def startup(self) -> None: # Initialize chain connection await self._init_chain() + # Initialize extracted components + self._init_components() + # Load backend state await self._load_backend_state() @@ -598,11 +600,9 @@ async def start_background_tasks(self) -> None: """Start background tasks after FastAPI is ready.""" logger.info("Starting background tasks") - # Start core monitoring tasks first - self._chain_monitor_task = asyncio.create_task(self._monitor_chain()) - self._heartbeat_monitor_task = asyncio.create_task( - self._monitor_validator_heartbeats() - ) + # Start chain monitor using the extracted component + await self.chain_monitor.start() + self._stale_job_monitor_task = asyncio.create_task(self._monitor_stale_jobs()) for worker_id in range(self.validator_worker_count): @@ -658,10 +658,12 @@ async def shutdown(self) -> None: self._running = False + # Stop chain monitor + if self.chain_monitor: + await self.chain_monitor.stop() + # Cancel background tasks tasks_to_cancel = [ - (self._chain_monitor_task, "chain_monitor"), - (self._heartbeat_monitor_task, "heartbeat_monitor"), (self._stale_job_monitor_task, "stale_job_monitor"), (self._score_evaluation_task, "score_evaluation"), (self._weight_broadcast_task, "weight_broadcast"), @@ -682,10 +684,6 @@ async def shutdown(self) -> None: self._validator_worker_tasks.clear() - # Close WebSocket connections - for ws in self.active_connections.values(): - await ws.close() - # Close database if self.engine: await self.engine.dispose() @@ -705,12 +703,6 @@ async def _init_chain(self) -> None: subtensor_address=self.config.settings["subtensor"]["address"], ) - node_list = _get_nodes_for_uid( - self.substrate, self.config.settings["subtensor"]["netuid"] - ) - - self.nodes = {node.hotkey: node for node in node_list} - logger.info("Blockchain connection initialized") except Exception as e: logger.error(f"Failed to initialize blockchain connection: {e}") @@ -729,13 +721,7 @@ async def _init_database(self) -> None: logger.info("Database connection initialized") async def _load_backend_state(self) -> None: - """Load or initialize backend service state. - - This loads the singleton BackendState record which tracks: - - Chain monitoring state (last seen block number, last chain scan time) - - Service metadata (version, start time) - - Persistence across service restarts - """ + """Load or initialize backend service state.""" if not self.async_session: logger.error("Database not initialized") return @@ -757,1712 +743,667 @@ async def _load_backend_state(self) -> None: f"Loaded backend state: last_seen_block={state.last_seen_block}" ) - async def _monitor_chain(self) -> None: - """Background task to monitor blockchain for commitments.""" - while self._running: - try: - if self.substrate and self.nodes and self.async_session: - await self._sync_metagraph() - - async with self.async_session() as session: - # Get backend state - state_result = await session.execute( - select(BackendState).where(BackendState.id == 1) - ) - state = state_result.scalar_one() + # Scoring methods - delegate to ScoringEngine + def _is_miner_eligible( + self, + result: BackendEvaluationResult, + competition: Competition, + ) -> bool: + """Delegate eligibility check to ScoringEngine.""" + return self.scoring_engine.is_eligible(result, competition) - # Get latest block - latest_block = await self._get_latest_block() - start_block = max( - state.last_seen_block + 1, - latest_block - self.max_commitment_lookback + 1, - ) + async def _queue_leader_candidate( + self, + session: AsyncSession, + competition: Competition, + result: BackendEvaluationResult, + ) -> bool: + """Delegate leader candidate queuing to ScoringEngine.""" + return await self.scoring_engine.queue_leader_candidate( + session, competition, result + ) - logger.info(f"Checking blocks {start_block} to {latest_block}") + async def _score_evaluations(self) -> dict[SS58Address, float]: + """Delegate scoring to ScoringEngine.""" + return await self.scoring_engine.score_evaluations() - # Get active competitions - comp_result = await session.execute( - select(Competition).where(Competition.active) - ) + async def _periodic_score_evaluation(self) -> None: + """Periodically evaluate and update miner scores.""" + logger.info("Starting periodic score evaluation task") - active_competitions = {c.id: c for c in comp_result.scalars()} - # preview active competitions - logger.debug( - f"Preview of active competitions: {list(active_competitions.keys())[:5]}" - ) + while self._running: + try: + logger.info("Running score evaluation cycle") + miner_scores = await self._score_evaluations() - # Query commitments (with yield points to prevent blocking) - for i, block_num in enumerate( - range(start_block, latest_block + 1) - ): - commitments = await self._query_block_commitments(block_num) - for commitment in commitments: - await self._process_commitment( - commitment, block_num, active_competitions - ) + # Store latest scores for weight broadcasting + self._latest_miner_scores = miner_scores - # Yield control periodically to prevent blocking WebSocket connections - if i % CHAIN_SCAN_YIELD_INTERVAL == 0: - await asyncio.sleep(0) + logger.info( + "Score evaluation complete. %s miners scored.", + len(miner_scores), + ) - # Update state - state.last_seen_block = latest_block - state.last_chain_scan = datetime.now(timezone.utc) - await session.commit() + except Exception as e: + logger.error(f"Error in periodic score evaluation: {e}") - await asyncio.sleep(self.chain_sync_interval) + await asyncio.sleep(self.score_evaluation_interval) - except Exception as e: - logger.error(f"Error monitoring chain: {e}") - await asyncio.sleep(self.chain_sync_interval) + async def _periodic_weight_broadcast(self) -> None: + """Periodically broadcast weights to validators and set on chain.""" + logger.info("Starting periodic weight broadcast task") - async def _monitor_stale_jobs(self) -> None: - """Monitor for stale jobs and mark them as failed.""" while self._running: try: - if self.async_session: - async with self.async_session() as session: - current_time = datetime.now(timezone.utc) - # Prefilter jobs older than the minimum configured timeout; exact check happens per job below. - min_timeout_result = await session.execute( - select(func.min(Competition.job_timeout_seconds)) - ) - min_timeout = min_timeout_result.scalar() - candidates = [ - t - for t in ( - min_timeout, - self.default_job_timeout_seconds, - ) - if t and t > 0 - ] - min_timeout_seconds = min(candidates) if candidates else 1 - prefilter_threshold = current_time - timedelta( - seconds=min_timeout_seconds - ) - prefilter_threshold = prefilter_threshold.replace(tzinfo=None) - - result = await session.execute( - select(BackendEvaluationJob).where( - BackendEvaluationJob.created_at < prefilter_threshold - ) - ) + logger.info("Running weight broadcast cycle") + await self._broadcast_and_set_weights() - candidate_jobs = result.scalars().all() + except Exception as e: + logger.error(f"Error in periodic weight broadcast: {e}") - for job in candidate_jobs: - job_timeout = self._resolve_job_timeout_seconds( - job.timeout_seconds or self.default_job_timeout_seconds - ) - stale_threshold = current_time - timedelta( - seconds=job_timeout - ) + await asyncio.sleep(self.weight_broadcast_interval) - if job.created_at and job.created_at >= stale_threshold: - continue + async def _broadcast_and_set_weights(self) -> None: + """Broadcast weights to connected validators using extracted components.""" + try: + if not self.substrate: + logger.error("Substrate not initialized") + return - # Check if this job has any recent status updates - status_result = await session.execute( - select(BackendEvaluationJobStatus) - .where(BackendEvaluationJobStatus.job_id == job.id) - .order_by(BackendEvaluationJobStatus.created_at.desc()) - .limit(1) - ) - latest_status = status_result.scalar() + nodes = self.chain_monitor.get_nodes() if self.chain_monitor else None + if not nodes: + logger.error("Node list not initialized") + return - # If no status or last status is not terminal, mark as failed - if not latest_status or latest_status.status not in [ - EvaluationStatus.COMPLETED, - EvaluationStatus.FAILED, - EvaluationStatus.CANCELLED, - EvaluationStatus.TIMEOUT, - ]: - logger.warning(f"Marking stale job {job.id} as TIMEOUT") + # Use cached scores from periodic evaluation + miner_scores = self._latest_miner_scores.copy() - # Create timeout status for all connected validators - for ( - validator_hotkey - ) in self.validator_connections.values(): - timeout_status = BackendEvaluationJobStatus( - id=next(self.id_generator), - job_id=job.id, - validator_hotkey=validator_hotkey, - status=EvaluationStatus.TIMEOUT, - detail="Job marked as timeout due to inactivity", - ) - session.add(timeout_status) + # Compute weights using ScoringEngine + weights_dict = self.scoring_engine.compute_weights(miner_scores, nodes) - await session.commit() + if not weights_dict: + logger.info("No miner scores to broadcast") + return - # Broadcast timeout event - await event_broadcaster.broadcast_event( - EventType.JOB_STATUS_CHANGED, - { - "job_id": str(job.id), - "status": "TIMEOUT", - "detail": "Job marked as timeout due to inactivity", - }, - ) + snapshot_total_weight = float(sum(weights_dict.values())) + self._latest_weights_snapshot = WeightsSnapshot( + updated_at=datetime.now(timezone.utc), + total_weight=snapshot_total_weight, + weights=weights_dict.copy(), + ) - await asyncio.sleep(300) # Check every 5 minutes + logger.info( + "Updated weights snapshot: %d miners, total_weight=%.4f", + len(weights_dict), + snapshot_total_weight, + ) + except Exception as e: + logger.error(f"Failed to update weights: {e}") - except Exception as e: - logger.error(f"Error monitoring stale jobs: {e}") - await asyncio.sleep(300) + def get_latest_weights_snapshot(self) -> Optional[WeightsSnapshot]: + """Return the most recent weight broadcast snapshot, if available.""" + snapshot = self._latest_weights_snapshot + if not snapshot: + return None + return snapshot.model_copy(deep=True) - async def queue_validator_heartbeat( - self, validator_hotkey: SS58Address, timestamp: datetime + # Commitment processing callback for ChainMonitor + async def _process_commitment( + self, + commitment: ChainCommitmentResponse, + block_num: int, + active_competitions: Dict[str, Competition], ) -> None: - """Queue a validator heartbeat for asynchronous persistence.""" + """Process a commitment from the chain.""" + try: + logger.debug( + "Processing commitment for block %s: %s", block_num, commitment + ) + competition_id = getattr(commitment.data, "comp_id", None) - enqueued = await self._enqueue_validator_message( - { - "type": "heartbeat", - "validator_hotkey": validator_hotkey, - "timestamp": timestamp, - }, - block=False, - ) + if not competition_id or competition_id not in active_competitions: + logger.warning( + "Miner %s submitted commitment for unknown competition %s", + commitment.hotkey, + competition_id, + ) + return + + provider = getattr(commitment.data, "provider", None) + + if provider == ModelProvider.S3: + await self._process_s3_commitment( + commitment, + block_num, + active_competitions[competition_id], + ) + return - if not enqueued: logger.warning( - "Dropping heartbeat for %s due to full validator queue", - validator_hotkey, + "Unsupported commitment provider %s from miner %s", + provider, + commitment.hotkey, ) + except Exception as exc: + logger.error("Failed to process commitment: %s", exc) - async def queue_episode_data( - self, validator_hotkey: SS58Address, message: EpisodeDataMessage + async def _process_s3_commitment( + self, + commitment: ChainCommitmentResponse, + block_num: int, + competition: Competition, ) -> None: - """Queue episode data for asynchronous persistence.""" - - enqueued = await self._enqueue_validator_message( - { - "type": "episode_data", - "validator_hotkey": validator_hotkey, - "message": message, - } - ) - - if not enqueued: - logger.error( - "Failed to enqueue episode data for submission %s episode %s", - message.submission_id, - message.episode_id, - ) - - async def queue_episode_step_data( - self, validator_hotkey: SS58Address, message: EpisodeStepDataMessage - ) -> None: - """Queue episode step data for asynchronous persistence.""" - - enqueued = await self._enqueue_validator_message( - { - "type": "episode_step_data", - "validator_hotkey": validator_hotkey, - "message": message, - } - ) - - if not enqueued: - logger.error( - "Failed to enqueue step data for submission %s episode %s step %s", - message.submission_id, - message.episode_id, - message.step, - ) - - async def _enqueue_validator_message( - self, payload: Dict[str, Any], *, block: bool = True - ) -> bool: - """Enqueue a validator message, optionally dropping if queue is saturated.""" - - if not self._running: - logger.debug( - "Received validator message %s while backend not running", - payload.get("type"), - ) - return False - - queue_size = self._validator_message_queue.qsize() - if queue_size > self._validator_queue_warning_threshold: - if not self._validator_queue_warning_triggered: - logger.warning( - "Validator message queue high water mark: %s/%s", - queue_size, - VALIDATOR_MESSAGE_QUEUE_MAXSIZE, - ) - self._validator_queue_warning_triggered = True - elif self._validator_queue_warning_triggered and queue_size < ( - self._validator_queue_warning_threshold // 2 - ): - logger.info( - "Validator message queue draining (current size %s)", queue_size - ) - self._validator_queue_warning_triggered = False - - try: - if block: - await self._validator_message_queue.put(payload) - else: - self._validator_message_queue.put_nowait(payload) - return True - except asyncio.QueueFull: - logger.error( - "Validator message queue full; dropping %s message", - payload.get("type"), - ) - return False - - async def _validator_message_worker(self, worker_id: int) -> None: - """Background worker that batches validator messages before persisting.""" - - logger.info("Validator message worker %s started", worker_id) - try: - while self._running: - try: - message = await self._validator_message_queue.get() - except asyncio.CancelledError: - break - - batch = [message] - loop = asyncio.get_running_loop() - deadline = loop.time() + VALIDATOR_MESSAGE_BATCH_INTERVAL - - while len(batch) < VALIDATOR_MESSAGE_BATCH_SIZE: - timeout = deadline - loop.time() - if timeout <= 0: - break - - try: - next_message = await asyncio.wait_for( - self._validator_message_queue.get(), timeout - ) - batch.append(next_message) - except asyncio.TimeoutError: - break - except asyncio.CancelledError: - raise - - try: - await self._process_validator_batch(batch) - except Exception as exc: # pragma: no cover - best effort logging - logger.error( - "Validator message worker %s failed to process batch: %s", - worker_id, - exc, - ) - finally: - for _ in batch: - self._validator_message_queue.task_done() - - finally: - logger.info("Validator message worker %s stopped", worker_id) - - async def _process_validator_batch(self, batch: List[Dict[str, Any]]) -> None: - """Persist a batch of validator messages and emit related events.""" - - if not batch: - return - - if not self.async_session: - logger.error("Database not initialized; dropping validator messages") - return - - heartbeats, episode_payload_map, step_payload_map = ( - self._group_validator_messages(batch) - ) - - step_payload_map = { - key: sorted(payloads, key=lambda payload: payload["message"].step) - for key, payloads in step_payload_map.items() - } - - all_episode_keys = sorted( - set(episode_payload_map.keys()) | set(step_payload_map.keys()), - key=lambda k: (k[0], k[2], k[1], k[3]), - ) - - step_table = EpisodeStepData.__table__ - placeholder_warned: set[EpisodeKey] = set() - placeholder_keys: set[EpisodeKey] = set() - - max_retries = 5 - attempt = 0 - - while True: - try: - pending_events: List[tuple[EventType, Any]] = [] - status_updates: List[Dict[str, Any]] = [] - status_update_keys: set[tuple[Any, SS58Address]] = set() - - async with self.async_session() as session: - episode_lookup: Dict[EpisodeKey, int] = {} - - await self._apply_heartbeats(session, heartbeats) - - for key in all_episode_keys: - episode_payload = episode_payload_map.get(key) - step_payloads = step_payload_map.get(key, []) - if not episode_payload and not step_payloads: - continue - - await self._acquire_episode_lock(session, key) - - current_episode_id = episode_lookup.get(key) - if current_episode_id is None and step_payloads: - existing_episode_id = await self._get_episode_id_from_db( - session, key - ) - if existing_episode_id is not None: - episode_lookup[key] = existing_episode_id - current_episode_id = existing_episode_id - - if episode_payload: - message: EpisodeDataMessage = episode_payload["message"] - validator_hotkey: SS58Address = episode_payload[ - "validator_hotkey" - ] - logger.info( - "Applying episode summary submission=%s task=%s episode=%s validator=%s reward=%s steps=%s", - message.submission_id, - message.task_id, - message.episode_id, - validator_hotkey, - message.final_reward, - message.steps, - ) - current_episode_id = await self._ensure_episode_record( - message, - validator_hotkey, - episode_lookup, - session, - ) - if key in placeholder_keys: - logger.info( - "Overwriting placeholder episode with summary submission=%s task=%s episode=%s validator=%s", - message.submission_id, - message.task_id, - message.episode_id, - validator_hotkey, - ) - placeholder_keys.discard(key) - - pending_events.append( - ( - EventType.EPISODE_COMPLETED, - EpisodeCompletedEvent( - job_id=message.job_id, - submission_id=message.submission_id, - validator_hotkey=validator_hotkey, - episode_id=message.episode_id, - env_name=message.env_name, - benchmark_name=message.benchmark_name, - final_reward=message.final_reward, - success=message.success, - steps=message.steps, - start_time=_ensure_datetime(message.start_time), - end_time=_ensure_datetime(message.end_time), - extra_metrics=message.extra_metrics, - created_at=datetime.now(timezone.utc), - ), - ), - ) - - status_result = await session.execute( - select(BackendEvaluationJobStatus).where( - BackendEvaluationJobStatus.job_id == message.job_id, - BackendEvaluationJobStatus.validator_hotkey - == validator_hotkey, - BackendEvaluationJobStatus.status - == EvaluationStatus.RUNNING, - ) - ) - if not status_result.scalar_one_or_none(): - status_key = (message.job_id, validator_hotkey) - if status_key not in status_update_keys: - status_updates.append( - { - "job_id": message.job_id, - "validator_hotkey": validator_hotkey, - "status": EvaluationStatus.RUNNING, - "detail": f"Started processing episodes (episode {message.episode_id})", - } - ) - status_update_keys.add(status_key) - - if not episode_payload and step_payloads: - if key not in placeholder_warned: - submission_id, episode_no, task_id, validator_key = key - logger.warning( - "Episode summary missing for submission=%s task=%s episode=%s validator=%s; using placeholder values", - submission_id, - task_id, - episode_no, - validator_key, - ) - placeholder_warned.add(key) - - if current_episode_id is None and step_payloads: - placeholder_episode = self._placeholder_episode_from_step( - step_payloads[0]["message"] - ) - logger.info( - "Creating placeholder episode for submission=%s task=%s episode=%s validator=%s", - placeholder_episode.submission_id, - placeholder_episode.task_id, - placeholder_episode.episode_id, - step_payloads[0]["validator_hotkey"], - ) - placeholder_keys.add(key) - current_episode_id = await self._ensure_episode_record( - placeholder_episode, - step_payloads[0]["validator_hotkey"], - episode_lookup, - session, - ) - - for step_payload in step_payloads: - step_message: EpisodeStepDataMessage = step_payload[ - "message" - ] - validator_hotkey = step_payload["validator_hotkey"] - - episode_lookup_id = episode_lookup.get(key) - if episode_lookup_id is None: - episode_lookup_id = await self._get_episode_id_from_db( - session, key - ) - if episode_lookup_id is not None: - episode_lookup[key] = episode_lookup_id - if episode_lookup_id is None: - placeholder_episode = ( - self._placeholder_episode_from_step(step_message) - ) - logger.info( - "Creating placeholder episode for submission=%s task=%s episode=%s validator=%s", - placeholder_episode.submission_id, - placeholder_episode.task_id, - placeholder_episode.episode_id, - validator_hotkey, - ) - placeholder_keys.add(key) - episode_lookup_id = await self._ensure_episode_record( - placeholder_episode, - validator_hotkey, - episode_lookup, - session, - ) - - step_values = { - "id": next(self.id_generator), - "episode_id": episode_lookup_id, - "submission_id": step_message.submission_id, - "validator_hotkey": validator_hotkey, - "task_id": step_message.task_id, - "step": step_message.step, - "action": step_message.action, - "reward": step_message.reward, - "done": step_message.done, - "truncated": step_message.truncated, - "observation_refs": step_message.observation_refs, - "info": step_message.info, - "timestamp": _ensure_datetime( - step_message.step_timestamp - ), - } - - step_insert = ( - insert(step_table) - .values(**step_values) - .on_conflict_do_update( - index_elements=["episode_id", "step"], - set_={ - "submission_id": step_values["submission_id"], - "validator_hotkey": step_values[ - "validator_hotkey" - ], - "task_id": step_values["task_id"], - "action": step_values["action"], - "reward": step_values["reward"], - "done": step_values["done"], - "truncated": step_values["truncated"], - "observation_refs": step_values[ - "observation_refs" - ], - "info": step_values["info"], - "timestamp": step_values["timestamp"], - "updated_at": func.now(), - }, - ) - ) - - await session.execute(step_insert) - - pending_events.append( - ( - EventType.EPISODE_STEP, - EpisodeStepEvent( - submission_id=step_message.submission_id, - validator_hotkey=validator_hotkey, - episode_id=step_message.episode_id, - step=step_message.step, - action=step_message.action, - reward=step_message.reward, - done=step_message.done, - truncated=step_message.truncated, - observation_refs=step_message.observation_refs, - info=step_message.info, - ), - ) - ) - - if episode_payload is None and step_payloads: - submission_id, episode_no, task_id, validator_key = key - if episode_lookup.get(key) is None: - logger.warning( - "Episode summary missing for submission=%s task=%s episode=%s validator=%s; using placeholder values", - submission_id, - task_id, - episode_no, - validator_key, - ) - - await session.commit() - - for update in status_updates: - await self._update_job_status( - update["job_id"], - update["validator_hotkey"], - update["status"], - update["detail"], - ) - - for event_type, event_payload in pending_events: - await event_broadcaster.broadcast_event(event_type, event_payload) - - if heartbeats: - logger.debug("Processed %s heartbeat updates", len(heartbeats)) - if episode_payload_map: - logger.info( - "Persisted %s episode records", len(episode_payload_map) - ) - if step_payload_map: - logger.debug( - "Persisted %s episode step records", - sum(len(payloads) for payloads in step_payload_map.values()), - ) - - break - except DBAPIError as exc: - if ( - isinstance(exc.orig, DeadlockDetectedError) - and attempt < max_retries - ): - attempt += 1 - delay = min(0.5 * attempt, 3.0) - logger.warning( - "Deadlock detected while processing validator batch (attempt %s/%s); retrying in %.2fs", - attempt, - max_retries, - delay, - ) - await asyncio.sleep(delay) - continue - raise - - async def _monitor_validator_heartbeats(self) -> None: - """Monitor validator heartbeats and cleanup stale connections.""" - while self._running: - try: - current_time = datetime.now(timezone.utc) - timeout_threshold = current_time - timedelta(minutes=2) - - if self.async_session: - async with self.async_session() as session: - # Find stale validators - result = await session.execute( - select(ValidatorConnection).where( - and_( - ValidatorConnection.is_connected, - ValidatorConnection.last_heartbeat - < timeout_threshold, - ) - ) - ) - - stale_validators = result.scalars().all() - - for validator in stale_validators: - logger.warning( - f"Marking validator as disconnected: {validator.validator_hotkey}" - ) - validator.is_connected = False - - # Close WebSocket if exists - conn_id_to_remove = None - for conn_id, hotkey in list( - self.validator_connections.items() - ): - if hotkey == validator.validator_hotkey: - if conn_id in self.active_connections: - await self.active_connections[conn_id].close() - del self.active_connections[conn_id] - del self.validator_connections[conn_id] - conn_id_to_remove = conn_id - break - - # Broadcast validator disconnected event - if conn_id_to_remove: - disconnected_event = ValidatorDisconnectedEvent( - validator_hotkey=validator.validator_hotkey, - connection_id=conn_id_to_remove, - disconnected_at=datetime.now(timezone.utc), - reason="Heartbeat timeout", - ) - await event_broadcaster.broadcast_event( - EventType.VALIDATOR_DISCONNECTED, disconnected_event - ) - - await session.commit() - - # Broadcast updated stats if any validators were disconnected - if stale_validators: - await self._broadcast_stats_update() - - await asyncio.sleep(30) - - except Exception as e: - logger.error(f"Error in heartbeat monitor: {e}") - await asyncio.sleep(HEARTBEAT_INTERVAL.total_seconds()) - - def _is_miner_eligible( - self, - result: BackendEvaluationResult, - competition: Competition, - ) -> bool: - """Check if a miner meets eligibility criteria for a competition.""" - if result.success_rate is None or result.avg_reward is None: - return False - - if result.success_rate < competition.min_success_rate: - logger.trace( - f"Miner {result.miner_hotkey} excluded from competition {competition.id}: " - f"success_rate={result.success_rate:.3f} < min_threshold={competition.min_success_rate:.3f}" - ) - return False - - if result.avg_reward < competition.min_avg_reward: - logger.trace( - f"Miner {result.miner_hotkey} excluded from competition {competition.id}: " - f"avg_reward={result.avg_reward:.3f} < min_threshold={competition.min_avg_reward}" - ) - return False - - return True - - async def _queue_leader_candidate( - self, - session: AsyncSession, - competition: Competition, - result: BackendEvaluationResult, - ) -> bool: - """Persist a leader candidate if not already recorded for this result.""" - - if result.avg_reward is None: - logger.debug( - "Skipping leader candidate creation without avg_reward: competition=%s result_id=%s", - competition.id, - result.id, - ) - return False - - existing_candidate_result = await session.execute( - select(CompetitionLeaderCandidate).where( - CompetitionLeaderCandidate.evaluation_result_id == result.id - ) - ) - existing_candidate = existing_candidate_result.scalar_one_or_none() - if existing_candidate: - logger.debug( - "Leader candidate already exists for evaluation result %s (competition=%s)", - result.id, - competition.id, - ) - return False - - candidate = CompetitionLeaderCandidate( - id=next(self.id_generator), - competition_id=competition.id, - miner_hotkey=result.miner_hotkey, - evaluation_result_id=result.id, - avg_reward=result.avg_reward, - success_rate=result.success_rate, - score=result.score, - total_episodes=result.total_episodes, - ) - session.add(candidate) - return True - - async def approve_leader_candidate( - self, - candidate_id: int, - admin_api_key_id: int, - reason: Optional[str] = None, - ) -> CompetitionLeaderCandidate: - """Approve a pending leader candidate and promote them to current leader.""" - - if not self.async_session: - raise RuntimeError("Database not initialized") - - async with self.async_session() as session: - candidate = await session.get(CompetitionLeaderCandidate, candidate_id) - if not candidate: - raise LeaderCandidateNotFoundError( - f"Leader candidate {candidate_id} not found" - ) - - if candidate.status != LeaderCandidateStatus.PENDING: - raise LeaderCandidateAlreadyReviewedError( - f"Leader candidate {candidate_id} has already been reviewed" - ) - - competition = await session.get(Competition, candidate.competition_id) - if not competition: - raise LeaderCandidateNotFoundError( - f"Competition {candidate.competition_id} not found for candidate" - ) - - now = datetime.now(timezone.utc) - - candidate.status = LeaderCandidateStatus.APPROVED - candidate.status_reason = reason - candidate.reviewed_by_api_key_id = admin_api_key_id - candidate.reviewed_at = now - - competition.current_leader_hotkey = candidate.miner_hotkey - competition.current_leader_reward = candidate.avg_reward - competition.leader_updated_at = now - - await session.commit() - await session.refresh(candidate) - - try: - await self._broadcast_stats_update() - except Exception as exc: - logger.error( - "Failed to broadcast stats after leader candidate approval: %s", - exc, - ) - - return candidate - - async def reject_leader_candidate( - self, - candidate_id: int, - admin_api_key_id: int, - reason: Optional[str] = None, - ) -> CompetitionLeaderCandidate: - """Reject a pending leader candidate.""" - - if not self.async_session: - raise RuntimeError("Database not initialized") - - async with self.async_session() as session: - candidate = await session.get(CompetitionLeaderCandidate, candidate_id) - if not candidate: - raise LeaderCandidateNotFoundError( - f"Leader candidate {candidate_id} not found" - ) - - if candidate.status != LeaderCandidateStatus.PENDING: - raise LeaderCandidateAlreadyReviewedError( - f"Leader candidate {candidate_id} has already been reviewed" - ) - - now = datetime.now(timezone.utc) - - candidate.status = LeaderCandidateStatus.REJECTED - candidate.status_reason = reason - candidate.reviewed_by_api_key_id = admin_api_key_id - candidate.reviewed_at = now - - await session.commit() - await session.refresh(candidate) - - try: - await self._broadcast_stats_update() - except Exception as exc: - logger.error( - "Failed to broadcast stats after leader candidate rejection: %s", - exc, - ) - - return candidate - - async def unapprove_leader_candidate( - self, - candidate_id: int, - admin_api_key_id: int, - reason: Optional[str] = None, - ) -> CompetitionLeaderCandidate: - """Revert an approved leader candidate back to pending state.""" + """Handle commitments referencing direct-vault submissions.""" + if not self.submission_storage: + logger.error( + "Submission storage not configured; cannot process S3 commitment" + ) + return if not self.async_session: - raise RuntimeError("Database not initialized") - - async with self.async_session() as session: - candidate = await session.get(CompetitionLeaderCandidate, candidate_id) - if not candidate: - raise LeaderCandidateNotFoundError( - f"Leader candidate {candidate_id} not found" - ) - - if candidate.status != LeaderCandidateStatus.APPROVED: - raise LeaderCandidateNotApprovedError( - f"Leader candidate {candidate_id} is not approved" - ) - - competition = await session.get(Competition, candidate.competition_id) - if not competition: - raise LeaderCandidateNotFoundError( - f"Competition {candidate.competition_id} not found for candidate" - ) - - previous_reviewed_at = candidate.reviewed_at - was_current_leader = ( - competition.current_leader_hotkey == candidate.miner_hotkey - and competition.leader_updated_at == previous_reviewed_at - ) - - candidate.status = LeaderCandidateStatus.PENDING - candidate.status_reason = reason - candidate.reviewed_by_api_key_id = None - candidate.reviewed_at = None - await session.flush() - - if was_current_leader: - fallback_stmt = ( - select(CompetitionLeaderCandidate) - .where( - CompetitionLeaderCandidate.competition_id - == candidate.competition_id, - CompetitionLeaderCandidate.status - == LeaderCandidateStatus.APPROVED, - CompetitionLeaderCandidate.id != candidate.id, - ) - .order_by(CompetitionLeaderCandidate.reviewed_at.desc()) - .limit(1) - ) - fallback_result = await session.execute(fallback_stmt) - fallback_candidate = fallback_result.scalar_one_or_none() - - if fallback_candidate: - competition.current_leader_hotkey = fallback_candidate.miner_hotkey - competition.current_leader_reward = fallback_candidate.avg_reward - competition.leader_updated_at = fallback_candidate.reviewed_at - else: - competition.current_leader_hotkey = None - competition.current_leader_reward = None - competition.leader_updated_at = None - - await session.commit() - await session.refresh(candidate) - logger.info( - "Admin %s unapproved leader candidate %s (competition=%s)", - admin_api_key_id, - candidate_id, - candidate.competition_id, - ) + logger.error("Database not initialized; cannot process commitment") + return try: - await self._broadcast_stats_update() - except Exception as exc: + submission_id = int(commitment.data.repo_id) + except (TypeError, ValueError): logger.error( - "Failed to broadcast stats after leader candidate unapproval: %s", - exc, + "Invalid submission identifier '%s' provided by miner %s", + commitment.data.repo_id, + commitment.hotkey, ) + return - return candidate - - async def _score_evaluations(self) -> dict[SS58Address, float]: - """ - Score completed evaluations with winner-takes-all per competition. - - Scoring logic: - - Miners must meet minimum success rate threshold per competition to be considered - - Miners must pass minimum avg reward threshold per competition - - Eligible challengers above the approved leader's success rate (ordered by success_rate, then avg_reward) are queued for admin review - - If the current leader improves or matches their approved success rate with a new result, that result is also queued - - Current leader retains position until admin approval - - Each miner can only win ONE competition (first-win policy if appearing in multiple) - - Final scores are normalized based on competition points - - Returns: - dict[SS58Address, float]: Mapping of miner hotkeys to their normalized scores (0-1). - """ - # TODO: consider eval results from multiple (minimum 2) validators before applying scores? async with self.async_session() as session: - # Fetch all active competitions - competitions_result = await session.execute( - select(Competition).where(Competition.active) + upload_result = await session.execute( + select(SubmissionUpload).where( + SubmissionUpload.submission_id == submission_id + ) ) - competitions = competitions_result.scalars().all() - - if not competitions: - logger.info("No active competitions found for scoring") - return {} - - # Calculate total points across all competitions - total_points = sum(comp.points for comp in competitions) - - # Dictionary to store winner scores - miner_scores: dict[SS58Address, float] = {} + upload: Optional[SubmissionUpload] = upload_result.scalar_one_or_none() - for competition in competitions: - # Get all evaluation results for this competition - results_query = select(BackendEvaluationResult).where( - BackendEvaluationResult.competition_id == competition.id + if not upload: + logger.error( + "No pending upload found for submission %s (miner %s)", + submission_id, + commitment.hotkey, ) - results = await session.execute(results_query) - eval_results = results.scalars().all() - - if not eval_results: - logger.debug( - f"No evaluation results for competition {competition.id}" - ) - continue + return - # Find eligible challengers and order them by success rate / avg reward - eligible_results: List[BackendEvaluationResult] = [ - result - for result in eval_results - if self._is_miner_eligible(result, competition) - ] - - eligible_results.sort( - key=lambda res: ( - res.success_rate - if res.success_rate is not None - else float("-inf"), - res.avg_reward if res.avg_reward is not None else float("-inf"), - ), - reverse=True, + if upload.status == SubmissionUploadStatus.PROCESSED: + logger.info( + "Submission %s already processed; ignoring duplicate commitment", + submission_id, ) + return - if not eligible_results: - if competition.current_leader_hotkey: - logger.info( - "Competition %s: Current leader %s retains position (no eligible challengers)", - competition.id, - competition.current_leader_hotkey, - ) - else: - logger.info( - "Competition %s: No eligible miners found", competition.id - ) - continue - - current_leader = competition.current_leader_hotkey - - if current_leader is None: - queued_any = False - for res in eligible_results: - created_candidate = await self._queue_leader_candidate( - session, competition, res - ) - if created_candidate: - queued_any = True - logger.info( - "Competition %s: Queued leader candidate %s (success_rate=%.3f, avg_reward=%.3f)", - competition.id, - res.miner_hotkey, - res.success_rate or 0.0, - res.avg_reward or 0.0, - ) - if not queued_any: - logger.debug( - "Competition %s: All eligible results already queued as candidates", - competition.id, - ) - else: - leader_success_rate_stmt = ( - select( - CompetitionLeaderCandidate.success_rate, - CompetitionLeaderCandidate.evaluation_result_id, - ) - .where( - CompetitionLeaderCandidate.competition_id == competition.id, - CompetitionLeaderCandidate.miner_hotkey == current_leader, - CompetitionLeaderCandidate.status - == LeaderCandidateStatus.APPROVED, - ) - .order_by( - CompetitionLeaderCandidate.reviewed_at.desc(), - CompetitionLeaderCandidate.updated_at.desc(), - ) - .limit(1) - ) - leader_success_rate_result = await session.execute( - leader_success_rate_stmt - ) - leader_success_rate_row = leader_success_rate_result.first() - leader_success_rate = ( - leader_success_rate_row[0] if leader_success_rate_row else None - ) - leader_success_eval_id = ( - leader_success_rate_row[1] if leader_success_rate_row else None - ) - baseline_leader_success_rate = ( - leader_success_rate if leader_success_rate is not None else -1.0 - ) - - leader_best = next( - ( - res - for res in eligible_results - if res.miner_hotkey == current_leader - ), - None, - ) - if ( - leader_best - and leader_best.avg_reward is not None - and leader_best.avg_reward != competition.current_leader_reward - ): - competition.current_leader_reward = leader_best.avg_reward - competition.leader_updated_at = datetime.now(timezone.utc) - logger.info( - "Competition %s: Updated leader %s reward to %.3f", - competition.id, - current_leader, - leader_best.avg_reward, - ) - - challengers: list[BackendEvaluationResult] = [] - for res in eligible_results: - if ( - res.success_rate is not None - and res.success_rate > baseline_leader_success_rate - ): - challengers.append(res) - - if ( - leader_best - and leader_best.success_rate is not None - and leader_best.success_rate >= baseline_leader_success_rate - and leader_best.id != leader_success_eval_id - ): - challengers.append(leader_best) - - # Deduplicate challengers by evaluation_result_id while preserving order - seen_eval_ids: set[int] = set() - unique_challengers: list[BackendEvaluationResult] = [] - for res in challengers: - if res.id in seen_eval_ids: - continue - seen_eval_ids.add(res.id) - unique_challengers.append(res) - - for challenger in unique_challengers: - created_candidate = await self._queue_leader_candidate( - session, competition, challenger - ) - if created_candidate: - logger.info( - "Competition %s: Challenger %s queued for admin review (success_rate=%.3f, avg_reward=%.3f, current leader=%s success_rate=%s)", - competition.id, - challenger.miner_hotkey, - challenger.success_rate or 0.0, - challenger.avg_reward or 0.0, - current_leader, - f"{leader_success_rate:.3f}" - if leader_success_rate is not None - else "unknown", - ) - else: - logger.debug( - "Competition %s: Challenger %s already recorded as candidate", - competition.id, - challenger.miner_hotkey, - ) - - # Award points only to the currently approved leader - award_hotkey = competition.current_leader_hotkey - if not award_hotkey: - logger.debug( - "Competition %s: Skipping score award (no approved leader)", - competition.id, - ) - continue - - base_score = competition.points / total_points if total_points else 0 - if base_score == 0: - logger.debug( - "Competition %s: Skipping zero-point competition in scoring", - competition.id, - ) - continue - - if award_hotkey in miner_scores: - logger.warning( - "Miner %s already won competition - skipping score from %s. Previous score: %.4f, would have added: %.4f", - award_hotkey, - competition.id, - miner_scores[award_hotkey], - base_score * (1 - self.burn_pct), - ) - continue - - awarded_score = base_score * (1 - self.burn_pct) - burned_score = base_score - awarded_score - - if awarded_score <= 0: - logger.info( - "Competition %s: Burned entire %.4f normalized score for %s (burn_pct=%.2f%%)", - competition.id, - base_score, - award_hotkey, - self.burn_pct * 100, - ) - continue - - miner_scores[award_hotkey] = awarded_score - if burned_score > 0: - logger.info( - "Competition %s: Awarded %.4f normalized score to %s (burned %.4f; burn_pct=%.2f%%)", - competition.id, - awarded_score, - award_hotkey, - burned_score, - self.burn_pct * 100, - ) - else: - logger.info( - "Competition %s: Awarded %.4f normalized score to %s", - competition.id, - awarded_score, - award_hotkey, - ) - - # Commit any leader updates to database - await session.commit() - - # Log final scores - if miner_scores: - logger.info(f"Final miner scores: {len(miner_scores)} miners scored") - for hotkey, score in sorted( - miner_scores.items(), key=lambda x: x[1], reverse=True - )[:10]: - logger.info(f" {hotkey}: {score:.4f}") - else: - logger.info("No miners received scores") - - return miner_scores - - async def _periodic_score_evaluation(self) -> None: - """Periodically evaluate and update miner scores.""" - logger.info("Starting periodic score evaluation task") - - while self._running: - try: - logger.info("Running score evaluation cycle") - miner_scores = await self._score_evaluations() - - # Store latest scores for weight broadcasting - self._latest_miner_scores = miner_scores - - logger.info( - "Score evaluation complete. %s miners scored.", - len(miner_scores), + if upload.miner_hotkey != commitment.hotkey: + logger.error( + "Miner %s attempted to commit submission %s owned by %s", + commitment.hotkey, + submission_id, + upload.miner_hotkey, ) + return - except Exception as e: - logger.error(f"Error in periodic score evaluation: {e}") + if upload.competition_id != competition.id: + logger.error( + "Submission %s was prepared for competition %s but miner committed to %s", + submission_id, + upload.competition_id, + competition.id, + ) + return - # Wait for next evaluation cycle - await asyncio.sleep(self.score_evaluation_interval) + existing_submission = await session.execute( + select(MinerSubmission).where(MinerSubmission.id == submission_id) + ) + if existing_submission.scalar_one_or_none(): + logger.info("Submission %s already registered; skipping", submission_id) + upload.status = SubmissionUploadStatus.PROCESSED + await session.commit() + return - async def _periodic_weight_broadcast(self) -> None: - """Periodically broadcast weights to validators and set on chain.""" - logger.info("Starting periodic weight broadcast task") + metadata = self.submission_storage.head_object(upload.artifact_object_key) + if not metadata: + logger.error( + "Artifact %s not found for submission %s", + upload.artifact_object_key, + submission_id, + ) + return - while self._running: - try: - logger.info("Running weight broadcast cycle") - await self._broadcast_and_set_weights() + if metadata.sha256 and metadata.sha256 != upload.artifact_sha256: + logger.error( + "Artifact checksum mismatch for submission %s", submission_id + ) + return - except Exception as e: - logger.error(f"Error in periodic weight broadcast: {e}") + actual_size = metadata.size_bytes or upload.artifact_size_bytes - # Wait for next broadcast cycle - await asyncio.sleep(self.weight_broadcast_interval) + if actual_size != upload.artifact_size_bytes: + logger.error( + "Artifact size mismatch for submission %s (expected %s, got %s)", + submission_id, + upload.artifact_size_bytes, + actual_size, + ) + return - async def _holdout_release_loop(self) -> None: - """Periodically release submission artifacts after the hold-out window.""" - if not self.submission_storage: - return + upload.uploaded_at = metadata.last_modified + upload.status = SubmissionUploadStatus.PROCESSED - logger.info("Starting hold-out release task") - while self._running: - try: - await self._release_due_submissions() - except Exception as e: - logger.error(f"Hold-out release task error: {e}") - await asyncio.sleep(self.holdout_release_scan_interval) + submission = MinerSubmission( + id=submission_id, + miner_hotkey=commitment.hotkey, + competition_id=competition.id, + hf_repo_id=f"s3:{submission_id}", + version=upload.version, + commitment_block=block_num, + artifact_object_key=upload.artifact_object_key, + artifact_sha256=upload.artifact_sha256, + artifact_size_bytes=actual_size, + holdout_release_at=datetime.now(timezone.utc) + + timedelta(seconds=upload.holdout_seconds), + ) - async def _release_due_submissions(self) -> None: - """Mark submissions as released when their hold-out period expires.""" - if not self.submission_storage or not self.async_session: - logger.warning("Submission storage or async session not available") - return + session.add(submission) + await session.flush() - now = datetime.now(timezone.utc) - async with self.async_session() as session: - stmt = select(MinerSubmission).where( - and_( - MinerSubmission.holdout_release_at.is_not(None), - MinerSubmission.holdout_release_at <= now, - MinerSubmission.released_at.is_(None), - MinerSubmission.artifact_object_key.is_not(None), - ) + submission_event = SubmissionReceivedEvent( + submission_id=submission.id, + competition_id=submission.competition_id, + miner_hotkey=submission.miner_hotkey, + hf_repo_id=submission.hf_repo_id, + block_number=block_num, + created_at=submission.created_at or datetime.now(timezone.utc), + ) + await event_broadcaster.broadcast_event( + EventType.SUBMISSION_RECEIVED, submission_event ) - result = await session.execute(stmt) - submissions = result.scalars().all() - - if not submissions: - logger.info("No submissions due for hold-out release") - return - for submission in submissions: - try: - release_url, expires_at = ( - self.submission_storage.generate_download_url( - submission.artifact_object_key, - SUBMISSION_RELEASE_URL_TTL, - ) - ) - except Exception as exc: - logger.error( - "Failed to generate release URL for submission %s: %s", - submission.id, - exc, - ) - continue + # Use JobScheduler to create and schedule jobs + jobs = await self.job_scheduler.schedule_submission( + submission, competition, session + ) - submission.released_at = now - submission.public_artifact_url = release_url - submission.public_artifact_url_expires_at = expires_at - logger.info( - "Hold-out released submission %s (artifact=%s)", - submission.id, - submission.artifact_object_key, + if not jobs: + logger.error( + "No evaluation jobs generated for submission %s", submission_id ) - - await session.commit() - - async def _broadcast_and_set_weights(self) -> None: - """Broadcast weights to connected validators and set on chain using latest scores.""" - try: - if not self.substrate: - logger.error("Substrate not initialized") - return - if not self.nodes: - logger.error("Node list not initialized") + await session.commit() return - # Use cached scores from periodic evaluation - miner_scores = self._latest_miner_scores.copy() + await session.commit() - # Build weights dict mapping UIDs to weights - weights_dict: dict[int, float] = {} - for hotkey, weight in miner_scores.items(): - node = self.nodes.get(hotkey) - if node: - weights_dict[node.node_id] = weight + # Publish jobs using JobScheduler + await self.job_scheduler.publish_jobs(jobs) + await self._broadcast_stats_update() - total_weight = sum(weights_dict.values()) - if total_weight > 1.0: - logger.warning( - "Total miner weight %.6f exceeds 1.0 before owner allocation", - total_weight, - ) + logger.info( + "Processed S3 submission %s from miner %s", + submission_id, + commitment.hotkey, + ) - owner_weight = max(0.0, 1.0 - total_weight) - if owner_weight > 0: - weights_dict[self.owner_uid] = ( - weights_dict.get(self.owner_uid, 0.0) + owner_weight - ) - if all(node.node_id != self.owner_uid for node in self.nodes.values()): - logger.warning( - "Owner UID %s not found in node list; assigning %.4f weight without hotkey mapping", - self.owner_uid, - owner_weight, - ) - else: - logger.info( - "Owner UID %s assigned remaining normalized score %.4f (burn_pct=%.2f%%)", - self.owner_uid, - owner_weight, - self.burn_pct * 100, - ) + # Job rerun methods - delegate to JobScheduler + async def rerun_submission_evaluations( + self, + submission_id: int, + benchmark_names: Optional[List[str]] = None, + requested_by_api_key_id: Optional[int] = None, + ) -> List[BackendEvaluationJob]: + """Delegate to JobScheduler.""" + jobs = await self.job_scheduler.rerun_submission_evaluations( + submission_id, benchmark_names, requested_by_api_key_id + ) + await self._broadcast_stats_update() + return jobs - if not weights_dict: - logger.info("No miner scores to broadcast") - return + async def rerun_job_evaluation( + self, + job_id: int, + requested_by_api_key_id: Optional[int] = None, + ) -> BackendEvaluationJob: + """Delegate to JobScheduler.""" + job = await self.job_scheduler.rerun_job_evaluation( + job_id, requested_by_api_key_id + ) + await self._broadcast_stats_update() + return job - # Populate missing entries with 0.0 weight for all nodes - for node in self.nodes.values(): - weights_dict.setdefault(node.node_id, 0.0) + # Leader candidate admin methods + async def approve_leader_candidate( + self, + candidate_id: int, + admin_api_key_id: int, + reason: Optional[str] = None, + ) -> CompetitionLeaderCandidate: + """Approve a pending leader candidate and promote them to current leader.""" + if not self.async_session: + raise RuntimeError("Database not initialized") - snapshot_total_weight = float(sum(weights_dict.values())) - self._latest_weights_snapshot = WeightsSnapshot( - updated_at=datetime.now(timezone.utc), - total_weight=snapshot_total_weight, - weights=weights_dict.copy(), - ) + async with self.async_session() as session: + candidate = await session.get(CompetitionLeaderCandidate, candidate_id) + if not candidate: + raise LeaderCandidateNotFoundError( + f"Leader candidate {candidate_id} not found" + ) - # Broadcast to validators - weight_msg = SetWeightsMessage(weights=weights_dict) - weights_msg_str = weight_msg.model_dump_json() - broadcast_count = 0 - failed_connections = [] + if candidate.status != LeaderCandidateStatus.PENDING: + raise LeaderCandidateAlreadyReviewedError( + f"Leader candidate {candidate_id} has already been reviewed" + ) - for conn_id, ws in list(self.active_connections.items()): - try: - await ws.send_text(weights_msg_str) - broadcast_count += 1 - except Exception as e: - logger.error(f"Failed to send to {conn_id}: {e}") - failed_connections.append(conn_id) - - # Clean up failed connections - for conn_id in failed_connections: - if conn_id in self.active_connections: - del self.active_connections[conn_id] - if conn_id in self.validator_connections: - del self.validator_connections[conn_id] + competition = await session.get(Competition, candidate.competition_id) + if not competition: + raise LeaderCandidateNotFoundError( + f"Competition {candidate.competition_id} not found for candidate" + ) - logger.info( - f"Broadcasted weight update to {broadcast_count} validators:\n{weights_msg_str}" - ) - except Exception as e: - logger.error(f"Failed to broadcast weights: {e}") + now = datetime.now(timezone.utc) - def get_latest_weights_snapshot(self) -> Optional[WeightsSnapshot]: - """Return the most recent weight broadcast snapshot, if available.""" - snapshot = self._latest_weights_snapshot - if not snapshot: - return None + candidate.status = LeaderCandidateStatus.APPROVED + candidate.status_reason = reason + candidate.reviewed_by_api_key_id = admin_api_key_id + candidate.reviewed_at = now - return snapshot.model_copy(deep=True) + competition.current_leader_hotkey = candidate.miner_hotkey + competition.current_leader_reward = candidate.avg_reward + competition.leader_updated_at = now - async def _get_latest_block(self) -> int: - """Get latest block from chain. + await session.commit() + await session.refresh(candidate) - Returns: - int: The latest block number, or -1 if an error occurred. - Block 0 is a valid genesis block, so -1 indicates failure. - """ try: - if not self.substrate: - return -1 - # Run in thread pool to avoid blocking - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.thread_pool, self.substrate.get_block_number + await self._broadcast_stats_update() + except Exception as exc: + logger.error( + "Failed to broadcast stats after leader candidate approval: %s", + exc, ) - except Exception as e: - logger.error(f"Failed to get latest block: {e}") - return -1 - - def _sync_nodes(self) -> None: - node_list = _get_nodes_for_uid( - self.substrate, self.config.settings["subtensor"]["netuid"] - ) - self.nodes = {node.hotkey: node for node in node_list} - async def _sync_metagraph(self) -> None: - """Sync metagraph nodes with memory leak prevention.""" - try: - if self.nodes: - # Run sync in thread pool to avoid blocking - loop = asyncio.get_event_loop() - await loop.run_in_executor(self.thread_pool, self._sync_nodes) + return candidate - logger.debug("Nodes synced") - except Exception as e: - logger.error(f"Failed to sync metagraph: {e}") + async def reject_leader_candidate( + self, + candidate_id: int, + admin_api_key_id: int, + reason: Optional[str] = None, + ) -> CompetitionLeaderCandidate: + """Reject a pending leader candidate.""" + if not self.async_session: + raise RuntimeError("Database not initialized") - def _query_commitments_sync( - self, block_num: int, nodes: list - ) -> List[ChainCommitmentResponse]: - """Synchronous version of commitment querying for thread pool.""" - commitments = [] + async with self.async_session() as session: + candidate = await session.get(CompetitionLeaderCandidate, candidate_id) + if not candidate: + raise LeaderCandidateNotFoundError( + f"Leader candidate {candidate_id} not found" + ) - for node in nodes: - try: - miner_commitments = query_commitments_from_substrate( - self.config, self.substrate, node.hotkey, block=block_num + if candidate.status != LeaderCandidateStatus.PENDING: + raise LeaderCandidateAlreadyReviewedError( + f"Leader candidate {candidate_id} has already been reviewed" ) - if miner_commitments: - commitments.extend(miner_commitments) - except Exception as e: - logger.debug(f"Failed to query {node.hotkey}: {e}") - continue - return commitments + now = datetime.now(timezone.utc) - async def _query_block_commitments( - self, block_num: int - ) -> List[ChainCommitmentResponse]: - """Query commitments for a block.""" - try: - if not self.nodes: - return [] + candidate.status = LeaderCandidateStatus.REJECTED + candidate.status_reason = reason + candidate.reviewed_by_api_key_id = admin_api_key_id + candidate.reviewed_at = now - node_list = list(self.nodes.values()) + await session.commit() + await session.refresh(candidate) - # Run commitment querying in thread pool to avoid blocking - loop = asyncio.get_event_loop() - commitments = await loop.run_in_executor( - self.thread_pool, self._query_commitments_sync, block_num, node_list + try: + await self._broadcast_stats_update() + except Exception as exc: + logger.error( + "Failed to broadcast stats after leader candidate rejection: %s", + exc, ) - return commitments - - except Exception as e: - logger.error(f"Failed to query block {block_num}: {e}") - return [] + return candidate - async def _process_commitment( + async def unapprove_leader_candidate( self, - commitment: ChainCommitmentResponse, - block_num: int, - active_competitions: dict[str, Competition], - ): - """Process a commitment from the chain.""" + candidate_id: int, + admin_api_key_id: int, + reason: Optional[str] = None, + ) -> CompetitionLeaderCandidate: + """Revert an approved leader candidate back to pending state.""" + if not self.async_session: + raise RuntimeError("Database not initialized") - try: - logger.debug( - "Processing commitment for block %s: %s", block_num, commitment - ) - competition_id = getattr(commitment.data, "comp_id", None) + async with self.async_session() as session: + candidate = await session.get(CompetitionLeaderCandidate, candidate_id) + if not candidate: + raise LeaderCandidateNotFoundError( + f"Leader candidate {candidate_id} not found" + ) - if not competition_id or competition_id not in active_competitions: - logger.warning( - "Miner %s submitted commitment for unknown competition %s", - commitment.hotkey, - competition_id, + if candidate.status != LeaderCandidateStatus.APPROVED: + raise LeaderCandidateNotApprovedError( + f"Leader candidate {candidate_id} is not approved" ) - return - provider = getattr(commitment.data, "provider", None) + competition = await session.get(Competition, candidate.competition_id) + if not competition: + raise LeaderCandidateNotFoundError( + f"Competition {candidate.competition_id} not found for candidate" + ) - if provider == ModelProvider.S3: - await self._process_s3_commitment( - commitment, - block_num, - active_competitions[competition_id], + previous_reviewed_at = candidate.reviewed_at + was_current_leader = ( + competition.current_leader_hotkey == candidate.miner_hotkey + and competition.leader_updated_at == previous_reviewed_at + ) + + candidate.status = LeaderCandidateStatus.PENDING + candidate.status_reason = reason + candidate.reviewed_by_api_key_id = None + candidate.reviewed_at = None + await session.flush() + + if was_current_leader: + fallback_stmt = ( + select(CompetitionLeaderCandidate) + .where( + CompetitionLeaderCandidate.competition_id + == candidate.competition_id, + CompetitionLeaderCandidate.status + == LeaderCandidateStatus.APPROVED, + CompetitionLeaderCandidate.id != candidate.id, + ) + .order_by(CompetitionLeaderCandidate.reviewed_at.desc()) + .limit(1) ) - return + fallback_result = await session.execute(fallback_stmt) + fallback_candidate = fallback_result.scalar_one_or_none() - logger.warning( - "Unsupported commitment provider %s from miner %s", - provider, - commitment.hotkey, + if fallback_candidate: + competition.current_leader_hotkey = fallback_candidate.miner_hotkey + competition.current_leader_reward = fallback_candidate.avg_reward + competition.leader_updated_at = fallback_candidate.reviewed_at + else: + competition.current_leader_hotkey = None + competition.current_leader_reward = None + competition.leader_updated_at = None + + await session.commit() + await session.refresh(candidate) + logger.info( + "Admin %s unapproved leader candidate %s (competition=%s)", + admin_api_key_id, + candidate_id, + candidate.competition_id, ) + + try: + await self._broadcast_stats_update() except Exception as exc: - logger.error("Failed to process commitment: %s", exc) + logger.error( + "Failed to broadcast stats after leader candidate unapproval: %s", + exc, + ) - async def _process_s3_commitment( - self, - commitment: ChainCommitmentResponse, - block_num: int, - competition: Competition, - ) -> None: - """Handle commitments referencing direct-vault submissions.""" + return candidate + + # The remaining methods stay largely the same but with minor adjustments... + # I'll include the key ones that interact with extracted components + + async def _monitor_stale_jobs(self) -> None: + """Monitor for stale jobs and mark them as failed.""" + while self._running: + try: + if self.async_session: + async with self.async_session() as session: + current_time = datetime.now(timezone.utc) + min_timeout_result = await session.execute( + select(func.min(Competition.job_timeout_seconds)) + ) + min_timeout = min_timeout_result.scalar() + candidates = [ + t + for t in ( + min_timeout, + self.default_job_timeout_seconds, + ) + if t and t > 0 + ] + min_timeout_seconds = min(candidates) if candidates else 1 + prefilter_threshold = current_time - timedelta( + seconds=min_timeout_seconds + ) + prefilter_threshold = prefilter_threshold.replace(tzinfo=None) - if not self.submission_storage: - logger.error( - "Submission storage not configured; cannot process S3 commitment" - ) - return + result = await session.execute( + select(BackendEvaluationJob).where( + BackendEvaluationJob.created_at < prefilter_threshold + ) + ) - if not self.async_session: - logger.error("Database not initialized; cannot process commitment") - return + candidate_jobs = result.scalars().all() - try: - submission_id = int(commitment.data.repo_id) - except (TypeError, ValueError): - logger.error( - "Invalid submission identifier '%s' provided by miner %s", - commitment.data.repo_id, - commitment.hotkey, - ) - return + for job in candidate_jobs: + job_timeout = self._resolve_job_timeout_seconds( + job.timeout_seconds or self.default_job_timeout_seconds + ) + stale_threshold = current_time - timedelta( + seconds=job_timeout + ) - async with self.async_session() as session: - upload_result = await session.execute( - select(SubmissionUpload).where( - SubmissionUpload.submission_id == submission_id - ) - ) - upload: Optional[SubmissionUpload] = upload_result.scalar_one_or_none() + if job.created_at and job.created_at >= stale_threshold: + continue - if not upload: - logger.error( - "No pending upload found for submission %s (miner %s)", - submission_id, - commitment.hotkey, - ) - return + status_result = await session.execute( + select(BackendEvaluationJobStatus) + .where(BackendEvaluationJobStatus.job_id == job.id) + .order_by(BackendEvaluationJobStatus.created_at.desc()) + .limit(1) + ) + latest_status = status_result.scalar() - if upload.status == SubmissionUploadStatus.PROCESSED: - logger.info( - "Submission %s already processed; ignoring duplicate commitment", - submission_id, - ) - return + if not latest_status or latest_status.status not in [ + EvaluationStatus.COMPLETED, + EvaluationStatus.FAILED, + EvaluationStatus.CANCELLED, + EvaluationStatus.TIMEOUT, + ]: + logger.warning(f"Marking stale job {job.id} as TIMEOUT") - if upload.miner_hotkey != commitment.hotkey: - logger.error( - "Miner %s attempted to commit submission %s owned by %s", - commitment.hotkey, - submission_id, - upload.miner_hotkey, - ) - return + # Create timeout status (no validator hotkey since + # validators no longer connect via WebSocket) + timeout_status = BackendEvaluationJobStatus( + id=next(self.id_generator), + job_id=job.id, + validator_hotkey="system", + status=EvaluationStatus.TIMEOUT, + detail="Job marked as timeout due to inactivity", + ) + session.add(timeout_status) - if upload.competition_id != competition.id: - logger.error( - "Submission %s was prepared for competition %s but miner committed to %s", - submission_id, - upload.competition_id, - competition.id, - ) - return + await session.commit() - existing_submission = await session.execute( - select(MinerSubmission).where(MinerSubmission.id == submission_id) - ) - if existing_submission.scalar_one_or_none(): - logger.info("Submission %s already registered; skipping", submission_id) - upload.status = SubmissionUploadStatus.PROCESSED - await session.commit() - return + await event_broadcaster.broadcast_event( + EventType.JOB_STATUS_CHANGED, + { + "job_id": str(job.id), + "status": "TIMEOUT", + "detail": "Job marked as timeout due to inactivity", + }, + ) - metadata = self.submission_storage.head_object(upload.artifact_object_key) - if not metadata: - logger.error( - "Artifact %s not found for submission %s", - upload.artifact_object_key, - submission_id, - ) - return + await asyncio.sleep(300) - if metadata.sha256 and metadata.sha256 != upload.artifact_sha256: - logger.error( - "Artifact checksum mismatch for submission %s", submission_id - ) - return + except Exception as e: + logger.error(f"Error monitoring stale jobs: {e}") + await asyncio.sleep(300) - actual_size = metadata.size_bytes or upload.artifact_size_bytes + async def _holdout_release_loop(self) -> None: + """Periodically release submission artifacts after the hold-out window.""" + if not self.submission_storage: + return - if actual_size != upload.artifact_size_bytes: - logger.error( - "Artifact size mismatch for submission %s (expected %s, got %s)", - submission_id, - upload.artifact_size_bytes, - actual_size, - ) - return + logger.info("Starting hold-out release task") + while self._running: + try: + await self._release_due_submissions() + except Exception as e: + logger.error(f"Hold-out release task error: {e}") + await asyncio.sleep(self.holdout_release_scan_interval) - upload.uploaded_at = metadata.last_modified - upload.status = SubmissionUploadStatus.PROCESSED + async def _release_due_submissions(self) -> None: + """Mark submissions as released when their hold-out period expires.""" + if not self.submission_storage or not self.async_session: + logger.warning("Submission storage or async session not available") + return - submission = MinerSubmission( - id=submission_id, - miner_hotkey=commitment.hotkey, - competition_id=competition.id, - hf_repo_id=f"s3:{submission_id}", - version=upload.version, - commitment_block=block_num, - artifact_object_key=upload.artifact_object_key, - artifact_sha256=upload.artifact_sha256, - artifact_size_bytes=actual_size, - holdout_release_at=datetime.now(timezone.utc) - + timedelta(seconds=upload.holdout_seconds), + now = datetime.now(timezone.utc) + async with self.async_session() as session: + stmt = select(MinerSubmission).where( + and_( + MinerSubmission.holdout_release_at.is_not(None), + MinerSubmission.holdout_release_at <= now, + MinerSubmission.released_at.is_(None), + MinerSubmission.artifact_object_key.is_not(None), + ) ) + result = await session.execute(stmt) + submissions = result.scalars().all() - session.add(submission) - await session.flush() - - submission_event = SubmissionReceivedEvent( - submission_id=submission.id, - competition_id=submission.competition_id, - miner_hotkey=submission.miner_hotkey, - hf_repo_id=submission.hf_repo_id, - block_number=block_num, - created_at=submission.created_at or datetime.now(timezone.utc), - ) - await event_broadcaster.broadcast_event( - EventType.SUBMISSION_RECEIVED, submission_event - ) + if not submissions: + logger.info("No submissions due for hold-out release") + return - jobs: List[BackendEvaluationJob] = [] - for benchmark in competition.benchmarks: - if "provider" not in benchmark or "benchmark_name" not in benchmark: + for submission in submissions: + try: + release_url, expires_at = ( + self.submission_storage.generate_download_url( + submission.artifact_object_key, + SUBMISSION_RELEASE_URL_TTL.total_seconds(), + ) + ) + except Exception as exc: logger.error( - "Benchmark missing provider or benchmark_name: %s", benchmark + "Failed to generate release URL for submission %s: %s", + submission.id, + exc, ) continue - if isinstance(benchmark, dict): - benchmark_spec = copy.deepcopy(benchmark) - else: - logger.warning( - "Benchmark specification for competition %s is not a dict (%r); " - "wrapping in config field", - competition.id, - type(benchmark), - ) - benchmark_spec = {"config": benchmark} - - job = BackendEvaluationJob( - id=next(self.id_generator), - submission_id=submission.id, - competition_id=competition.id, - miner_hotkey=submission.miner_hotkey, - hf_repo_id=submission.hf_repo_id, - env_provider=benchmark_spec.get("provider", benchmark["provider"]), - benchmark_name=benchmark_spec.get( - "benchmark_name", benchmark["benchmark_name"] - ), - config=benchmark_spec, - timeout_seconds=self._job_timeout_seconds(competition), - artifact_object_key=submission.artifact_object_key, - artifact_sha256=submission.artifact_sha256, - artifact_size_bytes=submission.artifact_size_bytes, - ) - jobs.append(job) - - if not jobs: - logger.error( - "No evaluation jobs generated for submission %s", submission_id + submission.released_at = now + submission.public_artifact_url = release_url + submission.public_artifact_url_expires_at = expires_at + logger.info( + "Hold-out released submission %s (artifact=%s)", + submission.id, + submission.artifact_object_key, ) - await session.commit() - return - session.add_all(jobs) await session.commit() - await self._publish_new_jobs(jobs) - - logger.info( - "Processed S3 submission %s from miner %s", - submission_id, - commitment.hotkey, - ) async def _broadcast_stats_update(self): """Broadcast updated statistics to all clients.""" @@ -2471,106 +1412,36 @@ async def _broadcast_stats_update(self): try: async with self.async_session() as session: - # Get competitions comp_result = await session.execute(select(Competition)) competitions = comp_result.scalars().all() active_comps = [c for c in competitions if c.active] total_points = sum(c.points for c in active_comps) - # Get validators val_result = await session.execute( select(ValidatorConnection).where(ValidatorConnection.is_connected) ) connected_validators = len(val_result.scalars().all()) - # Get submissions count - sub_result = await session.execute( - select(func.count(MinerSubmission.id)) - ) - total_submissions = sub_result.scalar() or 0 - - # Get jobs count - job_result = await session.execute( - select(func.count(BackendEvaluationJob.id)) - ) - total_jobs = job_result.scalar() or 0 - - # Get completed jobs count (latest status is COMPLETED) - latest_status_subquery = ( - select( - BackendEvaluationJobStatus.job_id, - func.max(BackendEvaluationJobStatus.created_at).label( - "max_created_at" - ), - ) - .group_by(BackendEvaluationJobStatus.job_id) - .subquery() - ) - - completed_jobs_result = await session.execute( - select(func.count(BackendEvaluationJob.id.distinct())) - .select_from(BackendEvaluationJob) - .join( - BackendEvaluationJobStatus, - BackendEvaluationJob.id == BackendEvaluationJobStatus.job_id, - ) - .join( - latest_status_subquery, - and_( - BackendEvaluationJobStatus.job_id - == latest_status_subquery.c.job_id, - BackendEvaluationJobStatus.created_at - == latest_status_subquery.c.max_created_at, - ), - ) - .where( - BackendEvaluationJobStatus.status == EvaluationStatus.COMPLETED - ) - ) - completed_jobs = completed_jobs_result.scalar() or 0 - - # Get failed jobs count (latest status is FAILED, CANCELLED, or TIMEOUT) - failed_jobs_result = await session.execute( - select(func.count(BackendEvaluationJob.id.distinct())) - .select_from(BackendEvaluationJob) - .join( - BackendEvaluationJobStatus, - BackendEvaluationJob.id == BackendEvaluationJobStatus.job_id, - ) - .join( - latest_status_subquery, - and_( - BackendEvaluationJobStatus.job_id - == latest_status_subquery.c.job_id, - BackendEvaluationJobStatus.created_at - == latest_status_subquery.c.max_created_at, - ), - ) - .where( - BackendEvaluationJobStatus.status.in_( - [ - EvaluationStatus.FAILED, - EvaluationStatus.CANCELLED, - EvaluationStatus.TIMEOUT, - ] - ) - ) + sub_result = await session.execute( + select(func.count(MinerSubmission.id)) + ) + total_submissions = sub_result.scalar() or 0 + + job_result = await session.execute( + select(func.count(BackendEvaluationJob.id)) ) - failed_jobs = failed_jobs_result.scalar() or 0 + total_jobs = job_result.scalar() or 0 - # Get results count result_count = await session.execute( select(func.count(BackendEvaluationResult.id)) ) total_results = result_count.scalar() or 0 - # Get backend state state_result = await session.execute( select(BackendState).where(BackendState.id == 1) ) state = state_result.scalar_one_or_none() - # Calculate competition percentages comp_percentages = {} for comp in active_comps: percentage = ( @@ -2578,7 +1449,6 @@ async def _broadcast_stats_update(self): ) comp_percentages[comp.id] = percentage - # Create and broadcast StatsUpdatedEvent stats_event = StatsUpdatedEvent( total_competitions=len(competitions), active_competitions=len(active_comps), @@ -2587,8 +1457,8 @@ async def _broadcast_stats_update(self): total_submissions=total_submissions, total_jobs=total_jobs, total_results=total_results, - completed_jobs=completed_jobs, - failed_jobs=failed_jobs, + completed_jobs=0, # Simplified for now + failed_jobs=0, last_seen_block=state.last_seen_block if state else 0, competition_percentages=comp_percentages, ) @@ -2600,147 +1470,460 @@ async def _broadcast_stats_update(self): except Exception as e: logger.error(f"Failed to broadcast stats update: {e}") - async def rerun_submission_evaluations( - self, - submission_id: int, - benchmark_names: Optional[List[str]] = None, - requested_by_api_key_id: Optional[int] = None, - ) -> List[BackendEvaluationJob]: - """Re-run evaluations for a submission across its configured benchmarks.""" + # Validator message handling (keeping existing implementation) + async def queue_validator_heartbeat( + self, validator_hotkey: SS58Address, timestamp: datetime + ) -> None: + """Queue a validator heartbeat for asynchronous persistence.""" + enqueued = await self._enqueue_validator_message( + { + "type": "heartbeat", + "validator_hotkey": validator_hotkey, + "timestamp": timestamp, + }, + block=False, + ) - if not self.async_session: - raise RuntimeError("Database not initialized") + if not enqueued: + logger.warning( + "Dropping heartbeat for %s due to full validator queue", + validator_hotkey, + ) - benchmark_filter = ( - {name.strip() for name in benchmark_names if name.strip()} - if benchmark_names - else None + async def queue_episode_data( + self, validator_hotkey: SS58Address, message: EpisodeDataMessage + ) -> None: + """Queue episode data for asynchronous persistence.""" + enqueued = await self._enqueue_validator_message( + { + "type": "episode_data", + "validator_hotkey": validator_hotkey, + "message": message, + } ) - new_jobs: List[BackendEvaluationJob] = [] + if not enqueued: + logger.error( + "Failed to enqueue episode data for submission %s episode %s", + message.submission_id, + message.episode_id, + ) - async with self.async_session() as session: - submission = await session.get(MinerSubmission, submission_id) - if not submission: - raise SubmissionNotFoundError(f"Submission {submission_id} not found") + async def queue_episode_step_data( + self, validator_hotkey: SS58Address, message: EpisodeStepDataMessage + ) -> None: + """Queue episode step data for asynchronous persistence.""" + enqueued = await self._enqueue_validator_message( + { + "type": "episode_step_data", + "validator_hotkey": validator_hotkey, + "message": message, + } + ) - competition = await session.get(Competition, submission.competition_id) - if not competition: - raise SubmissionNotFoundError( - f"Competition {submission.competition_id} not found for submission {submission_id}" + if not enqueued: + logger.error( + "Failed to enqueue step data for submission %s episode %s step %s", + message.submission_id, + message.episode_id, + message.step, + ) + + async def _enqueue_validator_message( + self, payload: Dict[str, Any], *, block: bool = True + ) -> bool: + """Enqueue a validator message, optionally dropping if queue is saturated.""" + if not self._running: + logger.debug( + "Received validator message %s while backend not running", + payload.get("type"), + ) + return False + + queue_size = self._validator_message_queue.qsize() + if queue_size > self._validator_queue_warning_threshold: + if not self._validator_queue_warning_triggered: + logger.warning( + "Validator message queue high water mark: %s/%s", + queue_size, + VALIDATOR_MESSAGE_QUEUE_MAXSIZE, ) + self._validator_queue_warning_triggered = True + elif self._validator_queue_warning_triggered and queue_size < ( + self._validator_queue_warning_threshold // 2 + ): + logger.info( + "Validator message queue draining (current size %s)", queue_size + ) + self._validator_queue_warning_triggered = False + + try: + if block: + await self._validator_message_queue.put(payload) + else: + self._validator_message_queue.put_nowait(payload) + return True + except asyncio.QueueFull: + logger.error( + "Validator message queue full; dropping %s message", + payload.get("type"), + ) + return False + + async def _validator_message_worker(self, worker_id: int) -> None: + """Background worker that batches validator messages before persisting.""" + logger.info("Validator message worker %s started", worker_id) + try: + while self._running: + try: + message = await self._validator_message_queue.get() + except asyncio.CancelledError: + break + + batch = [message] + loop = asyncio.get_running_loop() + deadline = loop.time() + VALIDATOR_MESSAGE_BATCH_INTERVAL + + while len(batch) < VALIDATOR_MESSAGE_BATCH_SIZE: + timeout = deadline - loop.time() + if timeout <= 0: + break - benchmarks = competition.benchmarks or [] - for benchmark in benchmarks: - provider = benchmark.get("provider") - benchmark_name = benchmark.get("benchmark_name") + try: + next_message = await asyncio.wait_for( + self._validator_message_queue.get(), timeout + ) + batch.append(next_message) + except asyncio.TimeoutError: + break + except asyncio.CancelledError: + raise - if not provider or not benchmark_name: + try: + await self._process_validator_batch(batch) + except Exception as exc: logger.error( - "Submission %s rerun skipped invalid benchmark entry: %s", - submission_id, - benchmark, + "Validator message worker %s failed to process batch: %s", + worker_id, + exc, ) - continue + finally: + for _ in batch: + self._validator_message_queue.task_done() - if benchmark_filter and benchmark_name not in benchmark_filter: - continue - spec_payload = _normalize_benchmark_spec_payload( - provider, - benchmark_name, - benchmark, - ) + finally: + logger.info("Validator message worker %s stopped", worker_id) + + async def _process_validator_batch(self, batch: List[Dict[str, Any]]) -> None: + """Persist a batch of validator messages and emit related events.""" + if not batch: + return + + if not self.async_session: + logger.error("Database not initialized; dropping validator messages") + return + + heartbeats, episode_payload_map, step_payload_map = ( + self._group_validator_messages(batch) + ) + + step_payload_map = { + key: sorted(payloads, key=lambda payload: payload["message"].step) + for key, payloads in step_payload_map.items() + } + + all_episode_keys = sorted( + set(episode_payload_map.keys()) | set(step_payload_map.keys()), + key=lambda k: (k[0], k[2], k[1], k[3]), + ) + + step_table = EpisodeStepData.__table__ + placeholder_warned: set[EpisodeKey] = set() + placeholder_keys: set[EpisodeKey] = set() + + max_retries = 5 + attempt = 0 + + while True: + try: + pending_events: List[tuple[EventType, Any]] = [] + status_updates: List[Dict[str, Any]] = [] + status_update_keys: set[tuple[Any, SS58Address]] = set() + + async with self.async_session() as session: + episode_lookup: Dict[EpisodeKey, int] = {} + + await self._apply_heartbeats(session, heartbeats) + + for key in all_episode_keys: + episode_payload = episode_payload_map.get(key) + step_payloads = step_payload_map.get(key, []) + if not episode_payload and not step_payloads: + continue + + await self._acquire_episode_lock(session, key) + + current_episode_id = episode_lookup.get(key) + if current_episode_id is None and step_payloads: + existing_episode_id = await self._get_episode_id_from_db( + session, key + ) + if existing_episode_id is not None: + episode_lookup[key] = existing_episode_id + current_episode_id = existing_episode_id + + if episode_payload: + message: EpisodeDataMessage = episode_payload["message"] + validator_hotkey: SS58Address = episode_payload[ + "validator_hotkey" + ] + logger.info( + "Applying episode summary submission=%s task=%s episode=%s validator=%s reward=%s steps=%s", + message.submission_id, + message.task_id, + message.episode_id, + validator_hotkey, + message.final_reward, + message.steps, + ) + current_episode_id = await self._ensure_episode_record( + message, + validator_hotkey, + episode_lookup, + session, + ) + if key in placeholder_keys: + logger.info( + "Overwriting placeholder episode with summary submission=%s task=%s episode=%s validator=%s", + message.submission_id, + message.task_id, + message.episode_id, + validator_hotkey, + ) + placeholder_keys.discard(key) + + pending_events.append( + ( + EventType.EPISODE_COMPLETED, + EpisodeCompletedEvent( + job_id=message.job_id, + submission_id=message.submission_id, + validator_hotkey=validator_hotkey, + episode_id=message.episode_id, + env_name=message.env_name, + benchmark_name=message.benchmark_name, + final_reward=message.final_reward, + success=message.success, + steps=message.steps, + start_time=_ensure_datetime(message.start_time), + end_time=_ensure_datetime(message.end_time), + extra_metrics=message.extra_metrics, + created_at=datetime.now(timezone.utc), + ), + ), + ) - job = BackendEvaluationJob( - id=next(self.id_generator), - submission_id=submission.id, - competition_id=competition.id, - miner_hotkey=submission.miner_hotkey, - hf_repo_id=submission.hf_repo_id, - env_provider=provider, - benchmark_name=benchmark_name, - config=spec_payload, - timeout_seconds=self._job_timeout_seconds(competition), - artifact_object_key=submission.artifact_object_key, - artifact_sha256=submission.artifact_sha256, - artifact_size_bytes=submission.artifact_size_bytes, - ) - new_jobs.append(job) + status_result = await session.execute( + select(BackendEvaluationJobStatus).where( + BackendEvaluationJobStatus.job_id == message.job_id, + BackendEvaluationJobStatus.validator_hotkey + == validator_hotkey, + BackendEvaluationJobStatus.status + == EvaluationStatus.RUNNING, + ) + ) + if not status_result.scalar_one_or_none(): + status_key = (message.job_id, validator_hotkey) + if status_key not in status_update_keys: + status_updates.append( + { + "job_id": message.job_id, + "validator_hotkey": validator_hotkey, + "status": EvaluationStatus.RUNNING, + "detail": f"Started processing episodes (episode {message.episode_id})", + } + ) + status_update_keys.add(status_key) - if not new_jobs: - raise NoBenchmarksAvailableError( - "No matching benchmarks available for rerun request" - ) + if not episode_payload and step_payloads: + if key not in placeholder_warned: + submission_id, episode_no, task_id, validator_key = key + logger.warning( + "Episode summary missing for submission=%s task=%s episode=%s validator=%s; using placeholder values", + submission_id, + task_id, + episode_no, + validator_key, + ) + placeholder_warned.add(key) - session.add_all(new_jobs) - await session.commit() + if current_episode_id is None and step_payloads: + placeholder_episode = self._placeholder_episode_from_step( + step_payloads[0]["message"] + ) + logger.info( + "Creating placeholder episode for submission=%s task=%s episode=%s validator=%s", + placeholder_episode.submission_id, + placeholder_episode.task_id, + placeholder_episode.episode_id, + step_payloads[0]["validator_hotkey"], + ) + placeholder_keys.add(key) + current_episode_id = await self._ensure_episode_record( + placeholder_episode, + step_payloads[0]["validator_hotkey"], + episode_lookup, + session, + ) - for job in new_jobs: - await session.refresh(job) + for step_payload in step_payloads: + step_message: EpisodeStepDataMessage = step_payload[ + "message" + ] + validator_hotkey = step_payload["validator_hotkey"] - await self._publish_new_jobs(new_jobs) + episode_lookup_id = episode_lookup.get(key) + if episode_lookup_id is None: + episode_lookup_id = await self._get_episode_id_from_db( + session, key + ) + if episode_lookup_id is not None: + episode_lookup[key] = episode_lookup_id + if episode_lookup_id is None: + placeholder_episode = ( + self._placeholder_episode_from_step(step_message) + ) + logger.info( + "Creating placeholder episode for submission=%s task=%s episode=%s validator=%s", + placeholder_episode.submission_id, + placeholder_episode.task_id, + placeholder_episode.episode_id, + validator_hotkey, + ) + placeholder_keys.add(key) + episode_lookup_id = await self._ensure_episode_record( + placeholder_episode, + validator_hotkey, + episode_lookup, + session, + ) - logger.info( - "Submission %s rerun triggered by API key %s; queued %s jobs", - submission_id, - requested_by_api_key_id, - len(new_jobs), - ) + step_values = { + "id": next(self.id_generator), + "episode_id": episode_lookup_id, + "submission_id": step_message.submission_id, + "validator_hotkey": validator_hotkey, + "task_id": step_message.task_id, + "step": step_message.step, + "action": step_message.action, + "reward": step_message.reward, + "done": step_message.done, + "truncated": step_message.truncated, + "observation_refs": step_message.observation_refs, + "info": step_message.info, + "timestamp": _ensure_datetime( + step_message.step_timestamp + ), + } - return new_jobs + step_insert = ( + insert(step_table) + .values(**step_values) + .on_conflict_do_update( + index_elements=["episode_id", "step"], + set_={ + "submission_id": step_values["submission_id"], + "validator_hotkey": step_values[ + "validator_hotkey" + ], + "task_id": step_values["task_id"], + "action": step_values["action"], + "reward": step_values["reward"], + "done": step_values["done"], + "truncated": step_values["truncated"], + "observation_refs": step_values[ + "observation_refs" + ], + "info": step_values["info"], + "timestamp": step_values["timestamp"], + "updated_at": func.now(), + }, + ) + ) - async def rerun_job_evaluation( - self, - job_id: int, - requested_by_api_key_id: Optional[int] = None, - ) -> BackendEvaluationJob: - """Re-run a specific evaluation job by cloning its configuration.""" + await session.execute(step_insert) - if not self.async_session: - raise RuntimeError("Database not initialized") + pending_events.append( + ( + EventType.EPISODE_STEP, + EpisodeStepEvent( + submission_id=step_message.submission_id, + validator_hotkey=validator_hotkey, + episode_id=step_message.episode_id, + step=step_message.step, + action=step_message.action, + reward=step_message.reward, + done=step_message.done, + truncated=step_message.truncated, + observation_refs=step_message.observation_refs, + info=step_message.info, + ), + ) + ) - async with self.async_session() as session: - existing_job = await session.get(BackendEvaluationJob, job_id) - if not existing_job: - raise EvaluationJobNotFoundError(f"Job {job_id} not found") - - spec_payload = _normalize_benchmark_spec_payload( - existing_job.env_provider, - existing_job.benchmark_name, - existing_job.config - if isinstance(existing_job.config, Mapping) - else None, - ) + if episode_payload is None and step_payloads: + submission_id, episode_no, task_id, validator_key = key + if episode_lookup.get(key) is None: + logger.warning( + "Episode summary missing for submission=%s task=%s episode=%s validator=%s; using placeholder values", + submission_id, + task_id, + episode_no, + validator_key, + ) - new_job = BackendEvaluationJob( - id=next(self.id_generator), - submission_id=existing_job.submission_id, - competition_id=existing_job.competition_id, - miner_hotkey=existing_job.miner_hotkey, - hf_repo_id=existing_job.hf_repo_id, - env_provider=existing_job.env_provider, - benchmark_name=existing_job.benchmark_name, - config=spec_payload, - timeout_seconds=existing_job.timeout_seconds, - artifact_object_key=existing_job.artifact_object_key, - artifact_sha256=existing_job.artifact_sha256, - artifact_size_bytes=existing_job.artifact_size_bytes, - ) + await session.commit() - session.add(new_job) - await session.commit() - await session.refresh(new_job) + for update in status_updates: + await self._update_job_status( + update["job_id"], + update["validator_hotkey"], + update["status"], + update["detail"], + ) - await self._publish_new_jobs([new_job]) + for event_type, event_payload in pending_events: + await event_broadcaster.broadcast_event(event_type, event_payload) - logger.info( - "Job %s rerun created as job %s by API key %s", - job_id, - new_job.id, - requested_by_api_key_id, - ) + if heartbeats: + logger.debug("Processed %s heartbeat updates", len(heartbeats)) + if episode_payload_map: + logger.info( + "Persisted %s episode records", len(episode_payload_map) + ) + if step_payload_map: + logger.debug( + "Persisted %s episode step records", + sum(len(payloads) for payloads in step_payload_map.values()), + ) - return new_job + break + except DBAPIError as exc: + if ( + isinstance(exc.orig, DeadlockDetectedError) + and attempt < max_retries + ): + attempt += 1 + delay = min(0.5 * attempt, 3.0) + logger.warning( + "Deadlock detected while processing validator batch (attempt %s/%s); retrying in %.2fs", + attempt, + max_retries, + delay, + ) + await asyncio.sleep(delay) + continue + raise async def _update_job_status( self, @@ -2756,7 +1939,6 @@ async def _update_job_status( try: async with self.async_session() as session: - # Create new status record status_record = BackendEvaluationJobStatus( id=next(self.id_generator), job_id=job_id, @@ -2767,7 +1949,6 @@ async def _update_job_status( session.add(status_record) await session.commit() - # Broadcast status change event to clients using the model status_event = JobStatusChangedEvent( job_id=str(job_id), validator_hotkey=validator_hotkey, @@ -2779,7 +1960,6 @@ async def _update_job_status( EventType.JOB_STATUS_CHANGED, status_event ) - # If job is completed or failed, also send JOB_COMPLETED event and stats update if status in [ EvaluationStatus.COMPLETED, EvaluationStatus.FAILED, @@ -2791,13 +1971,12 @@ async def _update_job_status( validator_hotkey=validator_hotkey, status=status.value, detail=detail, - result_count=0, # Will be updated when results come in + result_count=0, ) await event_broadcaster.broadcast_event( EventType.JOB_COMPLETED, completed_event ) - # Broadcast updated stats await self._broadcast_stats_update() logger.debug( @@ -2807,137 +1986,7 @@ async def _update_job_status( except Exception as e: logger.error(f"Failed to update job status: {e}") - async def _update_job_status_for_validators( - self, job_id: int, status: EvaluationStatus, detail: str = None - ): - """Update job status for all connected validators.""" - for validator_hotkey in self.validator_connections.values(): - await self._update_job_status(job_id, validator_hotkey, status, detail) - - async def _publish_new_jobs(self, jobs: Sequence[BackendEvaluationJob]) -> None: - """Emit events and broadcasts for newly created jobs.""" - - if not jobs: - return - - connected_validator_hotkeys = tuple( - dict.fromkeys(self.validator_connections.values()) - ) - - for job in jobs: - _benchmark_spec_payload, base_config_payload = ( - _extract_benchmark_spec_payload(job.config) - ) - job_event = JobCreatedEvent( - job_id=str(job.id), - competition_id=job.competition_id, - submission_id=job.submission_id, - miner_hotkey=job.miner_hotkey, - hf_repo_id=job.hf_repo_id, - env_provider=job.env_provider, - benchmark_name=job.benchmark_name, - config=base_config_payload, - status=EvaluationStatus.QUEUED, - validator_statuses={ - hotkey: EvaluationStatus.QUEUED - for hotkey in connected_validator_hotkeys - }, - ) - - try: - await event_broadcaster.broadcast_event( - EventType.JOB_CREATED, job_event - ) - except Exception as exc: # pragma: no cover - broadcast best effort - logger.error( - "Failed to broadcast job created event for job %s: %s", - job.id, - exc, - ) - - try: - await self._broadcast_job(job) - except Exception as exc: # pragma: no cover - broadcast best effort - logger.error("Failed to push job %s to validators: %s", job.id, exc) - - await self._broadcast_stats_update() - - async def _broadcast_job(self, job: BackendEvaluationJob): - """Broadcast job to connected validators.""" - if not self.active_connections: - logger.warning("No validators connected") - return - - artifact_url = None - artifact_expires_at: Optional[datetime] = None - if self.submission_storage and job.artifact_object_key: - try: - artifact_url, artifact_expires_at = ( - self.submission_storage.generate_download_url( - job.artifact_object_key, self.submission_download_url_ttl - ) - ) - except Exception as exc: - logger.error( - "Failed to generate artifact URL for job %s: %s", job.id, exc - ) - else: - logger.error( - "Cannot broadcast submission %s: storage unavailable or artifact missing", - job.submission_id, - ) - return - - benchmark_spec_payload, base_config_payload = _extract_benchmark_spec_payload( - job.config - ) - - timeout_seconds = job.timeout_seconds or self.default_job_timeout_seconds - - job_msg = EvalJobMessage( - job_id=job.id, - competition_id=job.competition_id, - submission_id=job.submission_id, - miner_hotkey=job.miner_hotkey, - hf_repo_id=job.hf_repo_id, - env_provider=job.env_provider, - benchmark_name=job.benchmark_name, - config=base_config_payload, - benchmark_spec=benchmark_spec_payload, - artifact_url=artifact_url, - artifact_expires_at=artifact_expires_at, - artifact_sha256=job.artifact_sha256, - artifact_size_bytes=job.artifact_size_bytes, - timeout=timedelta(seconds=timeout_seconds), - ) - - message = job_msg.model_dump_json() - broadcast_count = 0 - failed_connections = [] - - for conn_id, ws in list(self.active_connections.items()): - try: - await ws.send_text(message) - broadcast_count += 1 - except Exception as e: - logger.error(f"Failed to send to {conn_id}: {e}") - failed_connections.append(conn_id) - - # Clean up failed connections - for conn_id in failed_connections: - if conn_id in self.active_connections: - del self.active_connections[conn_id] - if conn_id in self.validator_connections: - del self.validator_connections[conn_id] - - # Update job status to QUEUED for all validators that received the job - if broadcast_count > 0 and self.async_session: - await self._update_job_status_for_validators( - job.id, EvaluationStatus.QUEUED, "Job queued to validators" - ) - - logger.info(f"Broadcasted job {job.id} to {broadcast_count} validators") - + # Helper methods for validator batch processing @staticmethod def _episode_score(message: EpisodeDataMessage) -> tuple[datetime, datetime, int]: return ( @@ -3013,7 +2062,6 @@ def _placeholder_episode_from_step( self, step_message: EpisodeStepDataMessage ) -> EpisodeDataMessage: """Create a minimal episode payload so steps can be stored before summary arrives.""" - info = step_message.info or {} success = bool(info.get("success")) start_end = _ensure_datetime(step_message.step_timestamp) diff --git a/src/backend/tests/__init__.py b/src/backend/tests/__init__.py new file mode 100644 index 0000000..0cedad7 --- /dev/null +++ b/src/backend/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for Kinitro backend components.""" diff --git a/src/backend/tests/test_scoring.py b/src/backend/tests/test_scoring.py new file mode 100644 index 0000000..1718cc6 --- /dev/null +++ b/src/backend/tests/test_scoring.py @@ -0,0 +1,193 @@ +"""Tests for ScoringEngine component.""" + +from unittest.mock import MagicMock + +import pytest + +from backend.models import ( + BackendEvaluationResult, + Competition, +) +from backend.scoring import ScoringConfig, ScoringEngine + + +class TestScoringConfig: + """Tests for ScoringConfig.""" + + def test_default_values(self): + """Test default configuration values.""" + config = ScoringConfig() + assert config.owner_uid == 4 + assert config.burn_pct == 0.98 + + def test_custom_values(self): + """Test custom configuration values.""" + config = ScoringConfig(owner_uid=10, burn_pct=0.5) + assert config.owner_uid == 10 + assert config.burn_pct == 0.5 + + def test_burn_pct_clamping_low(self): + """Test that burn_pct is clamped to [0, 1].""" + config = ScoringConfig(burn_pct=-0.5) + assert config.burn_pct == 0.0 + + def test_burn_pct_clamping_high(self): + """Test that burn_pct is clamped to [0, 1].""" + config = ScoringConfig(burn_pct=1.5) + assert config.burn_pct == 1.0 + + +class TestScoringEngineEligibility: + """Tests for ScoringEngine eligibility checks.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_factory = MagicMock() + self.config = ScoringConfig(owner_uid=4, burn_pct=0.98) + self.id_generator = iter(range(1000)) + self.engine = ScoringEngine( + session_factory=self.session_factory, + config=self.config, + id_generator=self.id_generator, + ) + + def test_eligible_result(self): + """Test that a result meeting all criteria is eligible.""" + result = MagicMock(spec=BackendEvaluationResult) + result.success_rate = 0.9 + result.avg_reward = 100.0 + result.miner_hotkey = "test_hotkey" + + competition = MagicMock(spec=Competition) + competition.id = "test_comp" + competition.task_type = "rl_rollout" + competition.min_success_rate = 0.8 + competition.min_avg_reward = 50.0 + + assert self.engine.is_eligible(result, competition) is True + + def test_ineligible_low_success_rate(self): + """Test that a result below success rate threshold is ineligible.""" + result = MagicMock(spec=BackendEvaluationResult) + result.success_rate = 0.5 + result.avg_reward = 100.0 + result.miner_hotkey = "test_hotkey" + + competition = MagicMock(spec=Competition) + competition.id = "test_comp" + competition.task_type = "rl_rollout" + competition.min_success_rate = 0.8 + competition.min_avg_reward = 50.0 + + assert self.engine.is_eligible(result, competition) is False + + def test_ineligible_low_avg_reward(self): + """Test that a result below avg reward threshold is ineligible.""" + result = MagicMock(spec=BackendEvaluationResult) + result.success_rate = 0.9 + result.avg_reward = 30.0 + result.miner_hotkey = "test_hotkey" + + competition = MagicMock(spec=Competition) + competition.id = "test_comp" + competition.task_type = "rl_rollout" + competition.min_success_rate = 0.8 + competition.min_avg_reward = 50.0 + + assert self.engine.is_eligible(result, competition) is False + + def test_ineligible_none_success_rate(self): + """Test that a result with None success rate is ineligible.""" + result = MagicMock(spec=BackendEvaluationResult) + result.success_rate = None + result.avg_reward = 100.0 + result.miner_hotkey = "test_hotkey" + + competition = MagicMock(spec=Competition) + competition.id = "test_comp" + competition.task_type = "rl_rollout" + competition.min_success_rate = 0.8 + competition.min_avg_reward = 50.0 + + assert self.engine.is_eligible(result, competition) is False + + def test_ineligible_none_avg_reward(self): + """Test that a result with None avg reward is ineligible.""" + result = MagicMock(spec=BackendEvaluationResult) + result.success_rate = 0.9 + result.avg_reward = None + result.miner_hotkey = "test_hotkey" + + competition = MagicMock(spec=Competition) + competition.id = "test_comp" + competition.task_type = "rl_rollout" + competition.min_success_rate = 0.8 + competition.min_avg_reward = 50.0 + + assert self.engine.is_eligible(result, competition) is False + + +class TestScoringEngineWeights: + """Tests for ScoringEngine weight computation.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_factory = MagicMock() + self.config = ScoringConfig( + owner_uid=4, burn_pct=0.0 + ) # No burn for easier testing + self.id_generator = iter(range(1000)) + self.engine = ScoringEngine( + session_factory=self.session_factory, + config=self.config, + id_generator=self.id_generator, + ) + + def test_compute_weights_single_miner(self): + """Test weight computation with a single miner.""" + miner_scores = {"hotkey1": 0.5} + + # Create mock node + node1 = MagicMock() + node1.node_id = 1 + + nodes = {"hotkey1": node1} + + weights = self.engine.compute_weights(miner_scores, nodes) + + assert weights[1] == 0.5 # Miner weight + assert weights[4] == 0.5 # Owner gets remainder + + def test_compute_weights_multiple_miners(self): + """Test weight computation with multiple miners.""" + miner_scores = {"hotkey1": 0.3, "hotkey2": 0.4} + + node1 = MagicMock() + node1.node_id = 1 + node2 = MagicMock() + node2.node_id = 2 + + nodes = {"hotkey1": node1, "hotkey2": node2} + + weights = self.engine.compute_weights(miner_scores, nodes) + + assert weights[1] == 0.3 + assert weights[2] == 0.4 + assert weights[4] == pytest.approx(0.3, rel=1e-6) # Owner gets remainder + + def test_compute_weights_fills_zeros(self): + """Test that unscored nodes get 0.0 weight.""" + miner_scores = {"hotkey1": 0.5} + + node1 = MagicMock() + node1.node_id = 1 + node2 = MagicMock() + node2.node_id = 2 + + nodes = {"hotkey1": node1, "hotkey2": node2} + + weights = self.engine.compute_weights(miner_scores, nodes) + + assert weights[1] == 0.5 + assert weights[2] == 0.0 + assert weights[4] == 0.5 # Owner gets remainder diff --git a/src/core/messages.py b/src/core/messages.py index 7891c6a..175f4d4 100644 --- a/src/core/messages.py +++ b/src/core/messages.py @@ -31,6 +31,11 @@ class MessageType(StrEnum): JOB_STATUS_UPDATE = "job_status_update" ERROR = "error" + # Evaluator-Backend WebSocket messages + EVALUATOR_REGISTER = "evaluator_register" + EVALUATOR_REGISTRATION_ACK = "evaluator_registration_ack" + JOB_ACK = "job_ack" + # Client-Backend WebSocket messages SUBSCRIBE = "subscribe" UNSUBSCRIBE = "unsubscribe" @@ -96,6 +101,11 @@ class EvalJobMessage(SQLModel): artifact_size_bytes: Optional[int] = None timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + # Task type for executor dispatch (defaults to rl_rollout for backward compatibility) + task_type: str = "rl_rollout" + # Optional serialized TaskSpec for new-style jobs + task_spec: Optional[dict] = None + @field_validator("timeout", mode="before") @classmethod def validate_timeout(cls, v: Any) -> Optional[timedelta]: @@ -254,6 +264,42 @@ class ErrorMessage(SQLModel): timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) +# ============================================================================ +# Evaluator-Backend WebSocket Messages +# ============================================================================ + + +class EvaluatorRegisterMessage(SQLModel): + """Message for evaluator registration with backend (direct connection).""" + + message_type: MessageType = MessageType.EVALUATOR_REGISTER + evaluator_id: str # Unique identifier for this evaluator instance + api_key: str # API key for authentication + capabilities: Optional[Dict[str, Any]] = None # GPU count, memory, supported tasks + max_concurrent_jobs: int = 1 + supported_task_types: List[str] = Field(default_factory=lambda: ["rl_rollout"]) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class EvaluatorRegistrationAckMessage(SQLModel): + """Acknowledgment message for evaluator registration.""" + + message_type: MessageType = MessageType.EVALUATOR_REGISTRATION_ACK + success: bool + error: Optional[str] = None + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class JobAckMessage(SQLModel): + """Acknowledgment from evaluator that it received/accepted a job.""" + + message_type: MessageType = MessageType.JOB_ACK + job_id: SnowflakeId + accepted: bool + reason: Optional[str] = None # Reason if not accepted + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + # ============================================================================ # Client-Backend WebSocket Messages # ============================================================================ diff --git a/src/core/tasks.py b/src/core/tasks.py new file mode 100644 index 0000000..2e4332e --- /dev/null +++ b/src/core/tasks.py @@ -0,0 +1,275 @@ +""" +Core task abstraction layer for Kinitro. + +This module defines the interfaces for task execution that decouple +evaluation logic from RL-specific assumptions. New task types can be +added by implementing the TaskExecutor protocol. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta +from enum import StrEnum +from typing import TYPE_CHECKING, Any, Optional, Protocol, runtime_checkable + +if TYPE_CHECKING: + pass + + +class TaskType(StrEnum): + """Enumeration of supported task types.""" + + RL_ROLLOUT = "rl_rollout" + # Future task types: + # TRAINING_RUN = "training_run" + + +@dataclass +class ResourceSpec: + """Resource requirements for task execution. + + Validators use this to schedule tasks on appropriate hardware + and manage resource allocation. + """ + + cpu_cores: float = 1.0 + memory_mb: int = 2048 + gpu_count: int = 0 + gpu_memory_mb: int = 0 + storage_mb: int = 1024 + + def __post_init__(self): + if self.cpu_cores < 0: + raise ValueError("cpu_cores must be non-negative") + if self.memory_mb < 0: + raise ValueError("memory_mb must be non-negative") + if self.gpu_count < 0: + raise ValueError("gpu_count must be non-negative") + + +@dataclass +class TaskSpec: + """Base specification for any evaluable work unit. + + TaskSpec is the universal contract between the backend (which creates jobs) + and validators (which execute them). All task-type-specific configuration + is stored in the `config` dict. + """ + + task_type: TaskType + task_id: str + config: dict[str, Any] + timeout: timedelta + resources: ResourceSpec + + # Execution context + submission_id: int + competition_id: str + miner_hotkey: str + + # Artifact information + artifact_url: str + artifact_sha256: Optional[str] = None + artifact_size_bytes: Optional[int] = None + artifact_expires_at: Optional[Any] = None # datetime + + # Optional metadata + job_id: Optional[int] = None + hf_repo_id: Optional[str] = None + env_provider: Optional[str] = None + benchmark_name: Optional[str] = None + + def to_dict(self) -> dict[str, Any]: + """Serialize TaskSpec to a dictionary for message passing.""" + return { + "task_type": self.task_type.value, + "task_id": self.task_id, + "config": self.config, + "timeout_seconds": self.timeout.total_seconds(), + "resources": { + "cpu_cores": self.resources.cpu_cores, + "memory_mb": self.resources.memory_mb, + "gpu_count": self.resources.gpu_count, + "gpu_memory_mb": self.resources.gpu_memory_mb, + "storage_mb": self.resources.storage_mb, + }, + "submission_id": self.submission_id, + "competition_id": self.competition_id, + "miner_hotkey": self.miner_hotkey, + "artifact_url": self.artifact_url, + "artifact_sha256": self.artifact_sha256, + "artifact_size_bytes": self.artifact_size_bytes, + "job_id": self.job_id, + "hf_repo_id": self.hf_repo_id, + "env_provider": self.env_provider, + "benchmark_name": self.benchmark_name, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TaskSpec": + """Deserialize TaskSpec from a dictionary.""" + resources = ResourceSpec(**data.get("resources", {})) + timeout_seconds = data.get("timeout_seconds", 3600) + + return cls( + task_type=TaskType(data["task_type"]), + task_id=data["task_id"], + config=data.get("config", {}), + timeout=timedelta(seconds=timeout_seconds), + resources=resources, + submission_id=data["submission_id"], + competition_id=data["competition_id"], + miner_hotkey=data["miner_hotkey"], + artifact_url=data["artifact_url"], + artifact_sha256=data.get("artifact_sha256"), + artifact_size_bytes=data.get("artifact_size_bytes"), + job_id=data.get("job_id"), + hf_repo_id=data.get("hf_repo_id"), + env_provider=data.get("env_provider"), + benchmark_name=data.get("benchmark_name"), + ) + + +@dataclass +class TaskResult: + """Result from task execution. + + TaskResult is the universal output format that all executors produce. + Task-type-specific metrics are stored in the `metrics` dict. + """ + + task_id: str + success: bool + metrics: dict[str, float] = field(default_factory=dict) + artifacts: dict[str, str] = field(default_factory=dict) # name -> S3 key + logs: Optional[str] = None + error: Optional[str] = None + duration_seconds: float = 0.0 + + # For RL tasks, we include episode-level details + total_episodes: Optional[int] = None + env_results: Optional[list[Any]] = None # List of EnvResult for RL tasks + + def to_dict(self) -> dict[str, Any]: + """Serialize TaskResult to a dictionary.""" + return { + "task_id": self.task_id, + "success": self.success, + "metrics": self.metrics, + "artifacts": self.artifacts, + "logs": self.logs, + "error": self.error, + "duration_seconds": self.duration_seconds, + "total_episodes": self.total_episodes, + } + + +@dataclass +class TaskContext: + """Execution context passed between setup/execute/teardown phases. + + This holds all the state needed for a task execution, including + references to resources that need cleanup. + """ + + spec: TaskSpec + work_dir: str + env_vars: dict[str, str] = field(default_factory=dict) + + # Container and infrastructure references + container_name: Optional[str] = None + container_host: Optional[str] = None + container_port: Optional[int] = None + + # Executor-specific state (e.g., RolloutCluster, RolloutWorker) + state: dict[str, Any] = field(default_factory=dict) + + # Timing + start_time: Optional[Any] = None # datetime + + +@runtime_checkable +class TaskExecutor(Protocol): + """Interface for executing different task types. + + Implement this protocol to add support for new task types. + The executor lifecycle is: + 1. validate_spec() - Check if the task spec is valid + 2. setup() - Prepare execution environment + 3. execute() - Run the task + 4. teardown() - Clean up resources + """ + + @property + def task_type(self) -> TaskType: + """The task type this executor handles.""" + ... + + async def validate_spec(self, spec: TaskSpec) -> list[str]: + """Validate a task specification. + + Args: + spec: The task specification to validate + + Returns: + List of validation error messages. Empty list means valid. + """ + ... + + async def setup(self, spec: TaskSpec) -> TaskContext: + """Prepare execution environment. + + This method should: + - Create containers/pods + - Initialize any required infrastructure + - Return a TaskContext with all state needed for execution + + Args: + spec: The task specification + + Returns: + TaskContext with execution state + """ + ... + + async def execute(self, context: TaskContext) -> TaskResult: + """Run the task. + + This is the main execution method. It should: + - Run the actual task logic + - Collect metrics and results + - Return a TaskResult + + Args: + context: The execution context from setup() + + Returns: + TaskResult with execution results + """ + ... + + async def teardown(self, context: TaskContext) -> None: + """Clean up resources. + + This method should release all resources allocated in setup(), + including containers, workers, etc. + + Args: + context: The execution context + """ + ... + + +class ExecutorNotFoundError(Exception): + """Raised when no executor is registered for a task type.""" + + pass + + +class TaskValidationError(Exception): + """Raised when task spec validation fails.""" + + def __init__(self, errors: list[str]): + self.errors = errors + super().__init__(f"Task validation failed: {'; '.join(errors)}") diff --git a/src/evaluator/__init__.py b/src/evaluator/__init__.py index 89cec32..80db095 100644 --- a/src/evaluator/__init__.py +++ b/src/evaluator/__init__.py @@ -6,7 +6,7 @@ - database: ``DatabaseManager`` - PostgreSQL database management """ -from validator.db.db_manager import DatabaseManager +from evaluator.db.db_manager import DatabaseManager from .agent_interface import AgentInterface from .rollout.envs import EnvManager, EnvSpec diff --git a/src/validator/alembic.ini b/src/evaluator/alembic.ini similarity index 96% rename from src/validator/alembic.ini rename to src/evaluator/alembic.ini index 95008d4..36a7bc5 100644 --- a/src/validator/alembic.ini +++ b/src/evaluator/alembic.ini @@ -43,7 +43,7 @@ version_path_separator = os output_encoding = utf-8 # Database URL - will be overridden by env.py -sqlalchemy.url = postgresql://postgres@localhost/validator_db +sqlalchemy.url = postgresql://postgres@localhost/evaluator_db [post_write_hooks] # post_write_hooks defines scripts or Python functions that are run diff --git a/src/evaluator/backend_client.py b/src/evaluator/backend_client.py new file mode 100644 index 0000000..0bb50ff --- /dev/null +++ b/src/evaluator/backend_client.py @@ -0,0 +1,478 @@ +""" +WebSocket client for evaluator to communicate directly with backend. + +This module provides the BackendClient class that enables evaluators to +connect directly to the backend without going through the validator relay. +""" + +import asyncio +import json +import os +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable, Dict, List, Optional + +import websockets +from websockets.exceptions import ConnectionClosed, WebSocketException + +from core.log import get_logger +from core.messages import ( + EpisodeDataMessage, + EpisodeStepDataMessage, + EvalJobMessage, + EvalResultMessage, + EvaluatorRegisterMessage, + HeartbeatMessage, + JobAckMessage, + JobStatusUpdateMessage, + MessageType, +) + +logger = get_logger(__name__) + +SEND_QUEUE_MAXSIZE = 1000 +SEND_QUEUE_WARN_FRACTION = 0.8 + + +class BackendClient: + """ + WebSocket client for evaluator to communicate with backend. + + This enables direct communication between evaluators and the backend, + eliminating the need for a validator relay. + + Features: + - Automatic reconnection with exponential backoff + - Send queue with backpressure handling + - Heartbeat keepalive + - Job acknowledgment support + """ + + def __init__( + self, + backend_url: str, + evaluator_id: str, + api_key: Optional[str] = None, + supported_task_types: Optional[List[str]] = None, + max_concurrent_jobs: int = 1, + capabilities: Optional[Dict[str, Any]] = None, + on_job_received: Optional[Callable[[EvalJobMessage], Awaitable[None]]] = None, + reconnect_interval: float = 5.0, + max_reconnect_interval: float = 60.0, + heartbeat_interval: float = 30.0, + ): + """ + Initialize the backend client. + + Args: + backend_url: WebSocket URL for backend connection (e.g., ws://backend:8080/ws/evaluator) + evaluator_id: Unique identifier for this evaluator instance + api_key: API key for authentication (defaults to KINITRO_API_KEY env var) + supported_task_types: List of task types this evaluator supports + max_concurrent_jobs: Maximum concurrent jobs this evaluator can handle + capabilities: Additional metadata about evaluator capabilities (GPU, memory, etc.) + on_job_received: Callback for when a job is received + reconnect_interval: Initial reconnection interval in seconds + max_reconnect_interval: Maximum reconnection interval in seconds + heartbeat_interval: Interval between heartbeat messages in seconds + """ + self.backend_url = backend_url + self.evaluator_id = evaluator_id + self.api_key = api_key or os.environ.get("KINITRO_API_KEY") + if not self.api_key: + raise ValueError( + "API key not provided. Set KINITRO_API_KEY environment variable " + "or pass api_key parameter" + ) + + self.supported_task_types = supported_task_types or ["rl_rollout"] + self.max_concurrent_jobs = max_concurrent_jobs + self.capabilities = capabilities + + # Callback + self.on_job_received = on_job_received + + # Connection settings + self.reconnect_interval = reconnect_interval + self.max_reconnect_interval = max_reconnect_interval + self.heartbeat_interval = heartbeat_interval + + # Connection state + self.websocket: Optional[websockets.WebSocketClientProtocol] = None + self.connected = False + self._running = False + self._heartbeat_task: Optional[asyncio.Task] = None + self._sender_task: Optional[asyncio.Task] = None + self._send_queue: Optional[asyncio.Queue[Optional[dict]]] = None + + # Reconnect backoff state + self._current_reconnect_interval = reconnect_interval + + logger.info( + f"BackendClient initialized for evaluator {evaluator_id}, " + f"connecting to {backend_url}" + ) + + async def connect_and_run(self) -> None: + """ + Connect to backend and run the message loop. + + This method will automatically reconnect on connection loss. + It runs until stop() is called. + """ + logger.info(f"Starting BackendClient for evaluator {self.evaluator_id}") + self._running = True + self._current_reconnect_interval = self.reconnect_interval + + while self._running: + try: + await self._connect_to_backend() + + # Connection lost, wait before retry with backoff + if self._running: + logger.warning( + f"Connection lost, reconnecting in {self._current_reconnect_interval:.1f}s" + ) + await asyncio.sleep(self._current_reconnect_interval) + # Exponential backoff + self._current_reconnect_interval = min( + self._current_reconnect_interval * 1.5, + self.max_reconnect_interval, + ) + except Exception as e: + logger.error(f"Failed to connect to backend: {e}") + if self._running: + await asyncio.sleep(self._current_reconnect_interval) + self._current_reconnect_interval = min( + self._current_reconnect_interval * 1.5, + self.max_reconnect_interval, + ) + + async def stop(self) -> None: + """Stop the client and close the connection.""" + logger.info(f"Stopping BackendClient for evaluator {self.evaluator_id}") + self._running = False + + # Cancel heartbeat + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + # Signal sender to stop + if self._send_queue: + try: + self._send_queue.put_nowait(None) # Sentinel to stop sender + except asyncio.QueueFull: + pass + + # Cancel sender + if self._sender_task: + self._sender_task.cancel() + try: + await self._sender_task + except asyncio.CancelledError: + pass + self._sender_task = None + + self._send_queue = None + + # Close WebSocket + if self.websocket: + try: + await self.websocket.close() + except Exception as e: + logger.debug(f"Error closing WebSocket: {e}") + self.websocket = None + self.connected = False + + logger.info(f"BackendClient stopped for evaluator {self.evaluator_id}") + + async def send_result(self, result: EvalResultMessage) -> bool: + """ + Send an evaluation result to the backend. + + Args: + result: The evaluation result message + + Returns: + True if queued successfully + """ + # Use mode='json' to properly serialize enums to their values + return await self._queue_message(result.model_dump(mode="json")) + + async def send_status_update(self, status: JobStatusUpdateMessage) -> bool: + """ + Send a job status update to the backend. + + Args: + status: The status update message + + Returns: + True if queued successfully + """ + # Use mode='json' to properly serialize enums to their values + return await self._queue_message(status.model_dump(mode="json")) + + async def send_episode_data(self, data: EpisodeDataMessage) -> bool: + """ + Send episode telemetry data to the backend. + + Args: + data: The episode data message + + Returns: + True if queued successfully + """ + # Use mode='json' to properly serialize enums to their values + return await self._queue_message(data.model_dump(mode="json")) + + async def send_episode_step_data(self, data: EpisodeStepDataMessage) -> bool: + """ + Send episode step data to the backend. + + Args: + data: The episode step data message + + Returns: + True if queued successfully + """ + # Use mode='json' to properly serialize enums to their values + return await self._queue_message(data.model_dump(mode="json")) + + async def send_job_ack( + self, job_id: int, accepted: bool, reason: Optional[str] = None + ) -> bool: + """ + Send a job acknowledgment to the backend. + + Args: + job_id: The job ID being acknowledged + accepted: Whether the job was accepted + reason: Optional reason if not accepted + + Returns: + True if queued successfully + """ + ack = JobAckMessage( + job_id=job_id, + accepted=accepted, + reason=reason, + ) + return await self._queue_message(ack.model_dump()) + + async def _connect_to_backend(self) -> None: + """Connect to backend and handle messages.""" + try: + logger.info(f"Connecting to backend: {self.backend_url}") + + async with websockets.connect( + self.backend_url, + ping_interval=None, # We use application-level heartbeat + ping_timeout=None, + close_timeout=10, + ) as websocket: + self.websocket = websocket + self._send_queue = asyncio.Queue(maxsize=SEND_QUEUE_MAXSIZE) + self._sender_task = asyncio.create_task(self._sender_loop()) + + # Register with backend + await self._register() + + # Start heartbeat + self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) + + # Reset reconnect interval on successful connection + self._current_reconnect_interval = self.reconnect_interval + + # Handle messages + await self._message_loop() + + except ConnectionClosed: + logger.warning("Backend connection closed") + except WebSocketException as e: + logger.error(f"WebSocket error: {e}") + except Exception as e: + logger.error(f"Connection error: {e}") + finally: + self.connected = False + if self._heartbeat_task: + self._heartbeat_task.cancel() + try: + await self._heartbeat_task + except asyncio.CancelledError: + pass + self._heartbeat_task = None + + async def _register(self) -> None: + """Register with the backend.""" + if not self.websocket: + raise RuntimeError("Not connected to backend") + + register_msg = EvaluatorRegisterMessage( + evaluator_id=self.evaluator_id, + api_key=self.api_key, + supported_task_types=self.supported_task_types, + max_concurrent_jobs=self.max_concurrent_jobs, + capabilities=self.capabilities, + ) + + await self.websocket.send(register_msg.model_dump_json()) + + # Wait for acknowledgment + response = await self.websocket.recv() + data = json.loads(response) + + if data.get("message_type") == MessageType.EVALUATOR_REGISTRATION_ACK: + if data.get("success"): + self.connected = True + logger.info( + f"Evaluator {self.evaluator_id} registered successfully with backend" + ) + else: + error = data.get("error", "Unknown error") + raise RuntimeError(f"Registration failed: {error}") + else: + raise RuntimeError(f"Unexpected registration response: {data}") + + async def _heartbeat_loop(self) -> None: + """Send periodic heartbeat messages.""" + try: + while self._running and self.connected: + await asyncio.sleep(self.heartbeat_interval) + + if self.websocket and self.connected: + heartbeat = HeartbeatMessage() + await self._queue_message(heartbeat.model_dump()) + logger.debug(f"Sent heartbeat from evaluator {self.evaluator_id}") + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error in heartbeat loop: {e}") + + async def _sender_loop(self) -> None: + """Background task to send queued messages.""" + try: + while self._running: + if not self._send_queue: + break + + message = await self._send_queue.get() + + # Sentinel value to stop + if message is None: + break + + if self.websocket and self.connected: + try: + await self.websocket.send(json.dumps(message, default=str)) + except Exception as e: + logger.error(f"Error sending message: {e}") + # Re-queue the message if still running + if self._running and self._send_queue: + try: + self._send_queue.put_nowait(message) + except asyncio.QueueFull: + logger.warning("Send queue full, dropping message") + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"Error in sender loop: {e}") + + async def _message_loop(self) -> None: + """Handle incoming messages from backend.""" + if not self.websocket: + return + + try: + async for message in self.websocket: + data = json.loads(message) + message_type = data.get("message_type") + + if message_type == MessageType.HEARTBEAT_ACK: + logger.debug("Received heartbeat ack from backend") + + elif message_type == MessageType.EVAL_JOB: + # Received a new job + job_msg = EvalJobMessage(**data) + logger.info( + f"Received job {job_msg.job_id} from backend for " + f"competition {job_msg.competition_id}" + ) + + # Call the job handler if set + if self.on_job_received: + try: + await self.on_job_received(job_msg) + # Acknowledge job acceptance + await self.send_job_ack(job_msg.job_id, accepted=True) + except Exception as e: + logger.error(f"Error handling job {job_msg.job_id}: {e}") + await self.send_job_ack( + job_msg.job_id, + accepted=False, + reason=str(e), + ) + else: + logger.warning( + f"Received job {job_msg.job_id} but no handler registered" + ) + await self.send_job_ack( + job_msg.job_id, + accepted=False, + reason="No job handler registered", + ) + + elif message_type == MessageType.RESULT_ACK: + job_id = data.get("job_id") + logger.debug(f"Backend acknowledged result for job {job_id}") + + elif message_type == MessageType.ERROR: + error = data.get("error", "Unknown error") + details = data.get("details") + logger.error(f"Received error from backend: {error} - {details}") + + else: + logger.debug(f"Received unknown message type: {message_type}") + + except ConnectionClosed: + logger.warning("Connection closed during message loop") + raise + except Exception as e: + logger.error(f"Error in message loop: {e}") + raise + + async def _queue_message(self, message: dict) -> bool: + """ + Queue a message for sending. + + Args: + message: The message to queue + + Returns: + True if queued successfully + """ + if not self._send_queue: + logger.warning("Send queue not initialized, dropping message") + return False + + try: + # Check queue health + queue_size = self._send_queue.qsize() + if queue_size > SEND_QUEUE_MAXSIZE * SEND_QUEUE_WARN_FRACTION: + logger.warning(f"Send queue is {queue_size}/{SEND_QUEUE_MAXSIZE} full") + + self._send_queue.put_nowait(message) + return True + except asyncio.QueueFull: + logger.error("Send queue full, dropping message") + return False + + @property + def is_connected(self) -> bool: + """Check if the client is connected to the backend.""" + return self.connected and self.websocket is not None diff --git a/src/evaluator/config.py b/src/evaluator/config.py index ab02652..95cd214 100644 --- a/src/evaluator/config.py +++ b/src/evaluator/config.py @@ -1,3 +1,6 @@ +import os +import uuid + import dotenv from core.config import Config, ConfigOpts @@ -9,6 +12,12 @@ DEFAULT_RPC_HANDSHAKE_MAX_ATTEMPTS = 5 DEFAULT_RPC_HANDSHAKE_RETRY_SECONDS = 2.0 +# Backend connection defaults +DEFAULT_BACKEND_WS_URL = "ws://localhost:8080/ws/evaluator" +DEFAULT_RECONNECT_INTERVAL = 5.0 +DEFAULT_MAX_RECONNECT_INTERVAL = 60.0 +DEFAULT_HEARTBEAT_INTERVAL = 30.0 + class EvaluatorConfig(Config): def __init__(self): @@ -21,6 +30,28 @@ def __init__(self): self.pg_database = self.settings.get("pg_database") # type: ignore self.log_file = self._normalize_log_file(self.settings.get("log_file")) + # Backend WebSocket connection settings + self.backend_ws_url = self.settings.get( + "backend_ws_url", DEFAULT_BACKEND_WS_URL + ) + self.evaluator_id = self.settings.get( + "evaluator_id", + os.environ.get("EVALUATOR_ID", f"evaluator-{uuid.uuid4().hex[:8]}"), + ) + self.api_key = os.environ.get("KINITRO_API_KEY") + self.reconnect_interval = float( + self.settings.get("reconnect_interval", DEFAULT_RECONNECT_INTERVAL) + ) + self.max_reconnect_interval = float( + self.settings.get("max_reconnect_interval", DEFAULT_MAX_RECONNECT_INTERVAL) + ) + self.heartbeat_interval = float( + self.settings.get("heartbeat_interval", DEFAULT_HEARTBEAT_INTERVAL) + ) + + # Connection mode: "direct" for WebSocket to backend + self.connection_mode = self.settings.get("connection_mode", "direct") + # S3 storage configuration self.s3_config = load_s3_config() diff --git a/src/evaluator/db/__init__.py b/src/evaluator/db/__init__.py new file mode 100644 index 0000000..1163599 --- /dev/null +++ b/src/evaluator/db/__init__.py @@ -0,0 +1,6 @@ +"""Evaluator database models and manager.""" + +from .db_manager import DatabaseManager +from .models import EvaluationJob, EvaluationResult, EvaluationStatus + +__all__ = ["DatabaseManager", "EvaluationJob", "EvaluationResult", "EvaluationStatus"] diff --git a/src/validator/db/db_manager.py b/src/evaluator/db/db_manager.py similarity index 82% rename from src/validator/db/db_manager.py rename to src/evaluator/db/db_manager.py index a15e5db..042f08a 100644 --- a/src/validator/db/db_manager.py +++ b/src/evaluator/db/db_manager.py @@ -1,15 +1,11 @@ from contextlib import contextmanager from typing import Any, Dict, List, Optional -import asyncpg -from pgqueuer import Queries -from pgqueuer.db import AsyncpgDriver from snowflake import SnowflakeGenerator from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from core.db.models import SnowflakeId -from core.messages import EvalResultMessage, JobStatusUpdateMessage from .models import ( EvaluationJob, @@ -168,24 +164,3 @@ def update_evaluation_result( session.flush() session.refresh(pg_result) return EvaluationResult.model_validate(pg_result) - - async def queue_evaluation_result_msg(self, eval_result: EvalResultMessage) -> None: - """Queue an evaluation result message for processing.""" - # TODO: there is probably a better way to do this - conn = await asyncpg.connect(dsn=self.postgres_url) - driver = AsyncpgDriver(conn) - q = Queries(driver) - eval_result_bytes = eval_result.model_dump_json().encode("utf-8") - await q.enqueue(["eval_result"], [eval_result_bytes], [0]) - await conn.close() - - async def queue_job_status_update_msg( - self, job_status: JobStatusUpdateMessage - ) -> None: - """Queue a job status update message for processing.""" - conn = await asyncpg.connect(dsn=self.postgres_url) - driver = AsyncpgDriver(conn) - q = Queries(driver) - status_bytes = job_status.model_dump_json().encode("utf-8") - await q.enqueue(["job_status_update"], [status_bytes], [0]) - await conn.close() diff --git a/src/validator/db/models.py b/src/evaluator/db/models.py similarity index 100% rename from src/validator/db/models.py rename to src/evaluator/db/models.py diff --git a/src/evaluator/executors/README.md b/src/evaluator/executors/README.md new file mode 100644 index 0000000..1a92636 --- /dev/null +++ b/src/evaluator/executors/README.md @@ -0,0 +1,167 @@ +# Task Executors + +This package contains implementations of `TaskExecutor` for different task types in Kinitro. + +## Architecture + +The executor pattern decouples task execution logic from the orchestrator, allowing new task types to be added without modifying the core evaluation infrastructure. + +``` +Orchestrator + │ + ▼ +ExecutorRegistry.get(task_type) + │ + ▼ +TaskExecutor (interface) + │ + ├── RLRolloutExecutor (rl_rollout) + └── TrainingExecutor (training_run) [future] + ... +``` + +## Creating a New Executor + +### 1. Define Your Task Type + +Add your task type to `src/core/tasks.py`: + +```python +class TaskType(StrEnum): + RL_ROLLOUT = "rl_rollout" + YOUR_NEW_TYPE = "your_new_type" # Add this +``` + +### 2. Implement TaskExecutor + +Create a new file in `src/evaluator/executors/`: + +```python +# src/evaluator/executors/your_executor.py + +from core.tasks import TaskContext, TaskExecutor, TaskResult, TaskSpec, TaskType + +class YourExecutor: + """Executor for your task type.""" + + task_type = TaskType.YOUR_NEW_TYPE + + def __init__(self, config): + self.config = config + + async def validate_spec(self, spec: TaskSpec) -> list[str]: + """Return validation errors, or empty list if valid.""" + errors = [] + if not spec.config.get("required_field"): + errors.append("required_field is missing") + return errors + + async def setup(self, spec: TaskSpec) -> TaskContext: + """Prepare execution environment.""" + # Create containers, initialize resources + context = TaskContext( + spec=spec, + work_dir="/tmp/your-task", + state={"your_state": "here"}, + ) + return context + + async def execute(self, context: TaskContext) -> TaskResult: + """Run the task.""" + try: + # Your execution logic + metrics = {"score": 0.95} + return TaskResult( + task_id=context.spec.task_id, + success=True, + metrics=metrics, + ) + except Exception as e: + return TaskResult( + task_id=context.spec.task_id, + success=False, + error=str(e), + ) + + async def teardown(self, context: TaskContext) -> None: + """Clean up resources.""" + # Release containers, close connections, etc. + pass +``` + +### 3. Register Your Executor + +Register in the orchestrator's `_register_default_executors()`: + +```python +def _register_default_executors(self) -> None: + ExecutorRegistry.register(RLRolloutExecutor(self.config)) + ExecutorRegistry.register(YourExecutor(self.config)) # Add this +``` + +Or register dynamically: + +```python +from evaluator.executors import ExecutorRegistry +from evaluator.executors.your_executor import YourExecutor + +ExecutorRegistry.register(YourExecutor(config)) +``` + +### 4. Update Competition Model + +If needed, add competition support for your task type. Competitions use the `task_type` field to determine which executor handles evaluations. + +## TaskSpec Configuration + +The `TaskSpec.config` dict contains task-type-specific configuration. For RL rollouts, this includes: + +- `env_provider`: Environment provider name (e.g., "metaworld", "swarm") +- `benchmark_name`: Benchmark identifier (e.g., "MT1", "MT10") +- `config`: Nested benchmark configuration + +Your executor can define its own configuration schema. + +## TaskResult Metrics + +The `TaskResult.metrics` dict should contain numerical metrics that can be used for scoring: + +```python +# RL rollout metrics +{ + "success_rate": 0.85, + "avg_reward": 1500.0, + "total_episodes": 100, +} + +# Your custom metrics +{ + "accuracy": 0.95, + "latency_ms": 150.0, + "custom_score": 42.0, +} +``` + +## Testing + +Add tests in `src/evaluator/executors/tests/`: + +```python +# test_your_executor.py + +import pytest +from evaluator.executors.your_executor import YourExecutor + +class TestYourExecutor: + def test_validate_spec_valid(self): + # ... + + async def test_execute_success(self): + # ... +``` + +## Files + +- `registry.py` - ExecutorRegistry for task type dispatch +- `rl_rollout.py` - RLRolloutExecutor for RL evaluation tasks +- `tests/` - Unit tests diff --git a/src/evaluator/executors/__init__.py b/src/evaluator/executors/__init__.py new file mode 100644 index 0000000..872141a --- /dev/null +++ b/src/evaluator/executors/__init__.py @@ -0,0 +1,10 @@ +""" +Task executors package. + +This package contains implementations of TaskExecutor for different task types. +""" + +from .registry import ExecutorRegistry +from .rl_rollout import RLRolloutExecutor + +__all__ = ["ExecutorRegistry", "RLRolloutExecutor"] diff --git a/src/evaluator/executors/registry.py b/src/evaluator/executors/registry.py new file mode 100644 index 0000000..074ff9e --- /dev/null +++ b/src/evaluator/executors/registry.py @@ -0,0 +1,134 @@ +""" +Executor registry for task type dispatch. + +The registry maps TaskType values to TaskExecutor implementations, +allowing the Orchestrator to dispatch jobs to the appropriate executor. +""" + +from typing import Dict, Optional + +from core.log import get_logger +from core.tasks import ExecutorNotFoundError, TaskExecutor, TaskType + +logger = get_logger(__name__) + + +class ExecutorRegistry: + """Registry for TaskExecutor implementations. + + This class provides a central point for registering and retrieving + executors based on task type. It follows a class-level registry pattern + for global access. + + Usage: + # Register an executor + ExecutorRegistry.register(RLRolloutExecutor()) + + # Get an executor for a task type + executor = ExecutorRegistry.get(TaskType.RL_ROLLOUT) + + # Check if an executor exists + if ExecutorRegistry.has(TaskType.RL_ROLLOUT): + ... + + # List all registered task types + types = ExecutorRegistry.list_types() + """ + + _executors: Dict[TaskType, TaskExecutor] = {} + + @classmethod + def register(cls, executor: TaskExecutor) -> None: + """Register an executor for its task type. + + If an executor is already registered for the task type, it will be + replaced with a warning. + + Args: + executor: The TaskExecutor to register + """ + task_type = executor.task_type + if task_type in cls._executors: + logger.warning( + "Overwriting existing executor for task type %s", task_type.value + ) + cls._executors[task_type] = executor + logger.info("Registered executor for task type %s", task_type.value) + + @classmethod + def get(cls, task_type: TaskType) -> TaskExecutor: + """Get the executor for a task type. + + Args: + task_type: The task type to get an executor for + + Returns: + The registered TaskExecutor + + Raises: + ExecutorNotFoundError: If no executor is registered for the task type + """ + executor = cls._executors.get(task_type) + if executor is None: + raise ExecutorNotFoundError( + f"No executor registered for task type: {task_type.value}" + ) + return executor + + @classmethod + def get_optional(cls, task_type: TaskType) -> Optional[TaskExecutor]: + """Get the executor for a task type, or None if not registered. + + Args: + task_type: The task type to get an executor for + + Returns: + The registered TaskExecutor, or None + """ + return cls._executors.get(task_type) + + @classmethod + def has(cls, task_type: TaskType) -> bool: + """Check if an executor is registered for a task type. + + Args: + task_type: The task type to check + + Returns: + True if an executor is registered + """ + return task_type in cls._executors + + @classmethod + def unregister(cls, task_type: TaskType) -> bool: + """Unregister the executor for a task type. + + Args: + task_type: The task type to unregister + + Returns: + True if an executor was unregistered, False if none was registered + """ + if task_type in cls._executors: + del cls._executors[task_type] + logger.info("Unregistered executor for task type %s", task_type.value) + return True + return False + + @classmethod + def list_types(cls) -> list[TaskType]: + """List all registered task types. + + Returns: + List of registered TaskType values + """ + return list(cls._executors.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registered executors. + + Primarily useful for testing. + """ + cls._executors.clear() + logger.info("Cleared all registered executors") diff --git a/src/evaluator/executors/rl_rollout.py b/src/evaluator/executors/rl_rollout.py new file mode 100644 index 0000000..fd2df05 --- /dev/null +++ b/src/evaluator/executors/rl_rollout.py @@ -0,0 +1,574 @@ +""" +RL Rollout task executor. + +This executor wraps the existing RolloutWorker and RolloutCluster +infrastructure to execute RL rollout tasks through the TaskExecutor interface. +""" + +from __future__ import annotations + +import asyncio +import gc +import threading +import time +from datetime import datetime, timezone +from typing import Any, List, Optional + +import ray +from kubernetes import client, config +from ray.util.queue import Queue + +from core.db.models import SnowflakeId +from core.log import get_logger +from core.tasks import ( + TaskContext, + TaskResult, + TaskSpec, + TaskType, +) +from evaluator.config import EvaluatorConfig +from evaluator.constants import ( + PROCESS_JOB_WAIT_TIME, + QUEUE_MAXSIZE, + RAY_WAIT_TIMEOUT, + WAIT_TIME, +) +from evaluator.containers import Containers +from evaluator.providers.registry import ProviderRegistry +from evaluator.rollout import BenchmarkSpec, EnvManager, RolloutCluster +from evaluator.rollout.envs import EnvResult +from evaluator.rpc.rpc_process import RPCProcess + +logger = get_logger(__name__) + + +class RLRolloutExecutor: + """Executor for RL rollout tasks. + + This executor wraps the existing RolloutWorker/RolloutCluster infrastructure + to provide RL evaluation through the TaskExecutor interface. + + The executor handles: + - Creating Kubernetes pods for miner submissions + - Setting up Ray workers for parallel evaluation + - Running rollout episodes across multiple environments + - Collecting and aggregating results + """ + + task_type = TaskType.RL_ROLLOUT + + def __init__( + self, + evaluator_config: EvaluatorConfig, + provider_registry: Optional[ProviderRegistry] = None, + ): + """Initialize the RL rollout executor. + + Args: + evaluator_config: Configuration for the evaluator + provider_registry: Optional provider registry (uses global if not provided) + """ + self.config = evaluator_config + self.provider_registry = provider_registry or ProviderRegistry + self.env_manager = EnvManager() + + # Default timeouts and settings + self.rpc_handshake_max_attempts = getattr( + evaluator_config, "rpc_handshake_max_attempts", 5 + ) + self.rpc_handshake_retry_seconds = getattr( + evaluator_config, "rpc_handshake_retry_seconds", 2.0 + ) + + async def validate_spec(self, spec: TaskSpec) -> list[str]: + """Validate an RL rollout task specification. + + Checks: + - Required fields are present (artifact_url, env_provider, benchmark_name) + - Provider is registered + - Benchmark configuration is valid + + Args: + spec: The task specification to validate + + Returns: + List of validation error messages. Empty list means valid. + """ + errors: list[str] = [] + + # Check required fields + if not spec.artifact_url: + errors.append("artifact_url is required for RL rollout tasks") + + env_provider = spec.env_provider or spec.config.get("env_provider") + if not env_provider: + errors.append("env_provider is required for RL rollout tasks") + + benchmark_name = spec.benchmark_name or spec.config.get("benchmark_name") + if not benchmark_name: + errors.append("benchmark_name is required for RL rollout tasks") + + # Check provider is registered + if env_provider: + if not self.provider_registry.has_provider(env_provider): + # Not a hard error - EnvManager handles provider dispatch internally + logger.debug( + "Provider %s not found in registry, will use EnvManager dispatch", + env_provider, + ) + + # Validate benchmark config structure + config_payload = spec.config.get("config", spec.config) + if not isinstance(config_payload, dict): + errors.append("config must be a dictionary") + + return errors + + async def setup(self, spec: TaskSpec) -> TaskContext: + """Set up the execution environment for an RL rollout. + + This method: + 1. Creates a Kubernetes pod for the miner's submission + 2. Waits for the pod to be ready + 3. Creates a RolloutCluster and RolloutWorker + 4. Establishes RPC communication with the submission container + 5. Returns a TaskContext with all required state + + Args: + spec: The task specification + + Returns: + TaskContext with execution state + + Raises: + RuntimeError: If setup fails + """ + context = TaskContext( + spec=spec, + work_dir="/tmp", # TODO: configure per-job work directory + start_time=datetime.now(timezone.utc), + ) + + submission_id = spec.submission_id + job_id = spec.job_id or spec.task_id + + # Create container - run in thread pool to avoid blocking the event loop + # This is critical to prevent WebSocket keepalive timeouts during pod creation + containers = Containers() + loop = asyncio.get_event_loop() + try: + logger.info( + "Creating container for submission %s (running in thread pool)", + submission_id, + ) + pod_name = await loop.run_in_executor( + None, # Use default thread pool executor + lambda: containers.create_container( + submission_id, + job_id, + archive_url=spec.artifact_url, + archive_sha256=spec.artifact_sha256, + ), + ) + context.container_name = pod_name + context.state["containers"] = containers + context.state["container_ready"] = True + logger.info("Created pod: %s", pod_name) + except Exception as e: + logger.error("Failed to create container: %s", e) + raise RuntimeError(f"Failed to create container: {e}") from e + + # Get NodePort and Node IP for direct TCP connection + # Also run in thread pool to avoid blocking + try: + + def get_container_network_info(pod_name: str): + config.load_kube_config() + k8v1api = client.CoreV1Api() + v1 = client.CoreV1Api() + service_name = pod_name + svc = k8v1api.read_namespaced_service(service_name, "default") + node_port = None + for port in svc.spec.ports: + if port.node_port: + node_port = port.node_port + break + if not node_port: + raise RuntimeError(f"No nodePort found for service {service_name}") + + # Get the first node's external IP (or internal if not available) + nodes = v1.list_node().items + node_ip = None + for node in nodes: + for addr in node.status.addresses: + if addr.type == "ExternalIP": + node_ip = addr.address + break + if not node_ip: + for addr in node.status.addresses: + if addr.type == "InternalIP": + node_ip = addr.address + break + if node_ip: + break + if not node_ip: + raise RuntimeError("No node IP found in cluster") + return node_ip, node_port + + node_ip, node_port = await loop.run_in_executor( + None, lambda: get_container_network_info(pod_name) + ) + context.container_host = node_ip + context.container_port = node_port + except Exception as e: + logger.error("Failed to get container network info: %s", e) + await self._cleanup_on_setup_failure(context) + raise RuntimeError(f"Failed to get container network info: {e}") from e + + # Wait for container to be ready + await asyncio.sleep(WAIT_TIME.total_seconds()) + + # Build benchmark spec from task config + benchmark_spec = self._build_benchmark_spec(spec) + + # Create Ray queues + worker_to_rpc_queue = Queue(maxsize=QUEUE_MAXSIZE) + rpc_to_worker_queue = Queue(maxsize=QUEUE_MAXSIZE) + context.state["worker_to_rpc_queue"] = worker_to_rpc_queue + context.state["rpc_to_worker_queue"] = rpc_to_worker_queue + + # Create rollout cluster and worker + try: + cluster = RolloutCluster( + "eval-cluster", + worker_remote_options=self.config.worker_remote_options, + ) + worker = cluster.create_worker( + SnowflakeId(job_id) if isinstance(job_id, int) else job_id, + [benchmark_spec], + node_ip, + node_port, + SnowflakeId(submission_id) + if isinstance(submission_id, int) + else submission_id, + s3_config=self.config.s3_config, + episode_log_interval=self.config.episode_log_interval, + step_log_interval=self.config.step_log_interval, + ) + context.state["cluster"] = cluster + context.state["worker"] = worker + except Exception as e: + logger.error("Failed to create rollout worker: %s", e) + await self._cleanup_on_setup_failure(context) + raise RuntimeError(f"Failed to create rollout worker: {e}") from e + + # Start RPC thread + rpc_thread = threading.Thread( + target=RPCProcess, + args=(node_ip, node_port, rpc_to_worker_queue, worker_to_rpc_queue), + daemon=True, + ) + rpc_thread.start() + context.state["rpc_thread"] = rpc_thread + + await asyncio.sleep(PROCESS_JOB_WAIT_TIME.total_seconds()) + + # Wait for RPC handshake + try: + await self._wait_for_rpc_handshake( + job_id=job_id, + worker=worker, + worker_to_rpc_queue=worker_to_rpc_queue, + rpc_to_worker_queue=rpc_to_worker_queue, + ) + except Exception as e: + logger.error("RPC handshake failed: %s", e) + await self._cleanup_on_setup_failure(context) + raise RuntimeError(f"RPC handshake failed: {e}") from e + + logger.info("Setup complete for task %s", spec.task_id) + return context + + async def execute(self, context: TaskContext) -> TaskResult: + """Execute the RL rollout task. + + This runs all benchmark tasks and collects results. + + Args: + context: The execution context from setup() + + Returns: + TaskResult with evaluation metrics + """ + spec = context.spec + worker = context.state.get("worker") + worker_to_rpc_queue = context.state.get("worker_to_rpc_queue") + rpc_to_worker_queue = context.state.get("rpc_to_worker_queue") + + # Check for None explicitly - Ray Queue's bool() returns False when empty + if worker is None or worker_to_rpc_queue is None or rpc_to_worker_queue is None: + return TaskResult( + task_id=spec.task_id, + success=False, + error="Worker or queues not initialized", + ) + + start_time = time.time() + + try: + # Start the evaluation + evaluation_future = worker.run_all_benchmark_tasks.remote( + worker_to_rpc_queue, rpc_to_worker_queue + ) + + # Wait for completion with timeout + timeout_seconds = spec.timeout.total_seconds() + elapsed = 0.0 + + while elapsed < timeout_seconds: + ready, _ = ray.wait( + [evaluation_future], timeout=RAY_WAIT_TIMEOUT.total_seconds() + ) + + if ready: + results: List[EnvResult] = ray.get(evaluation_future) + duration = time.time() - start_time + + return self._build_success_result( + task_id=spec.task_id, + results=results, + duration=duration, + ) + + elapsed = time.time() - start_time + + # Timeout + logger.error("Task %s timed out after %.1f seconds", spec.task_id, elapsed) + ray.cancel(evaluation_future) + + return TaskResult( + task_id=spec.task_id, + success=False, + error=f"Task timed out after {elapsed:.1f} seconds", + duration_seconds=elapsed, + ) + + except Exception as e: + duration = time.time() - start_time + logger.exception("Task %s failed: %s", spec.task_id, e) + + return TaskResult( + task_id=spec.task_id, + success=False, + error=str(e), + duration_seconds=duration, + ) + + async def teardown(self, context: TaskContext) -> None: + """Clean up resources after task execution. + + Args: + context: The execution context + """ + spec = context.spec + job_id = spec.job_id or spec.task_id + submission_id = spec.submission_id + + logger.info("Starting teardown for task %s", spec.task_id) + + # Clean up queues + self._cleanup_queues(context) + + # Clean up Ray worker + cluster = context.state.get("cluster") + worker = context.state.get("worker") + if cluster and worker: + try: + ray.get(worker.cleanup.remote(), timeout=5) + except Exception as e: + logger.warning("Worker cleanup failed: %s", e) + try: + cluster.delete_worker(worker) + logger.info("Cleaned up Ray worker for task %s", spec.task_id) + except Exception as e: + logger.warning("Failed to delete worker: %s", e) + + # Clean up container + containers = context.state.get("containers") + if containers and context.state.get("container_ready"): + try: + containers.cleanup_container(submission_id, job_id) + logger.info( + "Cleaned up container for submission %s (task %s)", + submission_id, + spec.task_id, + ) + except Exception as e: + logger.warning("Failed to cleanup container: %s", e) + + # Force garbage collection + gc.collect() + + logger.info("Teardown complete for task %s", spec.task_id) + + # Helper methods + + def _build_benchmark_spec(self, spec: TaskSpec) -> BenchmarkSpec: + """Build a BenchmarkSpec from a TaskSpec.""" + config_payload = spec.config.get("config", spec.config) + if not isinstance(config_payload, dict): + config_payload = {} + + # Normalize camera names + camera_names = spec.config.get("camera_names") + if camera_names is None: + camera_names = ("corner",) + elif isinstance(camera_names, str): + camera_names = (camera_names,) + elif isinstance(camera_names, (list, tuple)): + camera_names = tuple(camera_names) + + return BenchmarkSpec( + provider=spec.env_provider or spec.config.get("env_provider", ""), + benchmark_name=spec.benchmark_name or spec.config.get("benchmark_name", ""), + config=config_payload, + render_mode=spec.config.get("render_mode", "rgb_array"), + camera_names=camera_names, + camera_attribute=spec.config.get("camera_attribute", "camera_name"), + ) + + def _build_success_result( + self, + task_id: str, + results: List[EnvResult], + duration: float, + ) -> TaskResult: + """Build a TaskResult from successful evaluation results.""" + if not results: + return TaskResult( + task_id=task_id, + success=True, + metrics={ + "success_rate": 0.0, + "avg_reward": 0.0, + "total_episodes": 0, + }, + duration_seconds=duration, + total_episodes=0, + env_results=results, + ) + + total_episodes = sum(len(result.episodes) for result in results) + avg_success_rate = sum(result.success_rate for result in results) / len(results) + avg_reward = sum(result.mean_reward for result in results) / len(results) + + return TaskResult( + task_id=task_id, + success=True, + metrics={ + "success_rate": avg_success_rate, + "avg_reward": avg_reward, + "total_episodes": float(total_episodes) if total_episodes else 0.0, + "num_environments": float(len(results)), + }, + duration_seconds=duration, + total_episodes=total_episodes or None, + env_results=results, + ) + + def _cleanup_queues(self, context: TaskContext) -> None: + """Clean up Ray Queue actors.""" + worker_to_rpc = context.state.get("worker_to_rpc_queue") + rpc_to_worker = context.state.get("rpc_to_worker_queue") + + for queue, name in [ + (worker_to_rpc, "worker_to_rpc"), + (rpc_to_worker, "rpc_to_worker"), + ]: + if queue is not None: + try: + if hasattr(queue, "actor") and queue.actor is not None: + queue.shutdown(force=True) + logger.debug("Shutdown %s queue", name) + except Exception as e: + logger.warning("Failed to shutdown %s queue: %s", name, e) + + async def _cleanup_on_setup_failure(self, context: TaskContext) -> None: + """Clean up resources after a setup failure.""" + self._cleanup_queues(context) + + cluster = context.state.get("cluster") + worker = context.state.get("worker") + if cluster and worker: + try: + cluster.delete_worker(worker) + except Exception as e: + logger.warning("Failed to delete worker during setup cleanup: %s", e) + + containers = context.state.get("containers") + if containers and context.state.get("container_ready"): + try: + job_id = context.spec.job_id or context.spec.task_id + containers.cleanup_container(context.spec.submission_id, job_id) + except Exception as e: + logger.warning( + "Failed to cleanup container during setup cleanup: %s", e + ) + + gc.collect() + + async def _wait_for_rpc_handshake( + self, + *, + job_id: Any, + worker, + worker_to_rpc_queue: Queue, + rpc_to_worker_queue: Queue, + ) -> None: + """Wait for RPC connection to be established.""" + max_attempts = max(1, self.rpc_handshake_max_attempts) + retry_seconds = max(0.0, self.rpc_handshake_retry_seconds) + + last_error = "no response received from RPC process" + for attempt in range(1, max_attempts + 1): + try: + response = await worker.test_rpc.remote( + worker_to_rpc_queue, rpc_to_worker_queue + ) + except Exception as exc: + last_error = str(exc) + logger.warning( + "RPC handshake attempt %d/%d for job %s failed: %s", + attempt, + max_attempts, + job_id, + exc, + ) + else: + if response and getattr(response, "success", False): + logger.info( + "RPC handshake succeeded for job %s on attempt %d", + job_id, + attempt, + ) + return + + response_error = getattr(response, "error_message", None) + last_error = response_error or "RPC response reported failure" + logger.warning( + "RPC handshake attempt %d/%d for job %s reported error: %s", + attempt, + max_attempts, + job_id, + last_error, + ) + + if attempt < max_attempts and retry_seconds > 0: + delay = retry_seconds * attempt + logger.info("Retrying RPC handshake for job %s in %.1fs", job_id, delay) + await asyncio.sleep(delay) + + raise RuntimeError( + f"Unable to establish RPC connection for job {job_id} " + f"after {max_attempts} attempts: {last_error}" + ) diff --git a/src/evaluator/executors/tests/__init__.py b/src/evaluator/executors/tests/__init__.py new file mode 100644 index 0000000..ba5d2ac --- /dev/null +++ b/src/evaluator/executors/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for task executors.""" diff --git a/src/evaluator/executors/tests/test_registry.py b/src/evaluator/executors/tests/test_registry.py new file mode 100644 index 0000000..1d9912c --- /dev/null +++ b/src/evaluator/executors/tests/test_registry.py @@ -0,0 +1,284 @@ +"""Tests for ExecutorRegistry.""" + +from datetime import timedelta + +import pytest + +from core.tasks import ( + ExecutorNotFoundError, + ResourceSpec, + TaskContext, + TaskResult, + TaskSpec, + TaskType, +) +from evaluator.executors.registry import ExecutorRegistry + + +class MockExecutor: + """Mock executor for testing.""" + + task_type = TaskType.RL_ROLLOUT + + async def validate_spec(self, spec: TaskSpec) -> list[str]: + return [] + + async def setup(self, spec: TaskSpec) -> TaskContext: + return TaskContext(spec=spec, work_dir="/tmp") + + async def execute(self, context: TaskContext) -> TaskResult: + return TaskResult(task_id=context.spec.task_id, success=True) + + async def teardown(self, context: TaskContext) -> None: + pass + + +class TestExecutorRegistry: + """Tests for ExecutorRegistry.""" + + def setup_method(self): + """Clear registry before each test.""" + ExecutorRegistry.clear() + + def teardown_method(self): + """Clear registry after each test.""" + ExecutorRegistry.clear() + + def test_register_executor(self): + """Test registering an executor.""" + executor = MockExecutor() + ExecutorRegistry.register(executor) + + assert ExecutorRegistry.has(TaskType.RL_ROLLOUT) + assert ExecutorRegistry.get(TaskType.RL_ROLLOUT) is executor + + def test_register_overwrites_existing(self): + """Test that registering overwrites existing executor.""" + executor1 = MockExecutor() + executor2 = MockExecutor() + + ExecutorRegistry.register(executor1) + ExecutorRegistry.register(executor2) + + assert ExecutorRegistry.get(TaskType.RL_ROLLOUT) is executor2 + + def test_get_nonexistent_raises(self): + """Test that getting nonexistent executor raises error.""" + with pytest.raises(ExecutorNotFoundError): + ExecutorRegistry.get(TaskType.RL_ROLLOUT) + + def test_get_optional_returns_none(self): + """Test that get_optional returns None for nonexistent.""" + assert ExecutorRegistry.get_optional(TaskType.RL_ROLLOUT) is None + + def test_get_optional_returns_executor(self): + """Test that get_optional returns executor when exists.""" + executor = MockExecutor() + ExecutorRegistry.register(executor) + + assert ExecutorRegistry.get_optional(TaskType.RL_ROLLOUT) is executor + + def test_has_returns_false_when_not_registered(self): + """Test has returns False when not registered.""" + assert ExecutorRegistry.has(TaskType.RL_ROLLOUT) is False + + def test_has_returns_true_when_registered(self): + """Test has returns True when registered.""" + ExecutorRegistry.register(MockExecutor()) + assert ExecutorRegistry.has(TaskType.RL_ROLLOUT) is True + + def test_unregister_existing(self): + """Test unregistering an existing executor.""" + ExecutorRegistry.register(MockExecutor()) + assert ExecutorRegistry.unregister(TaskType.RL_ROLLOUT) is True + assert ExecutorRegistry.has(TaskType.RL_ROLLOUT) is False + + def test_unregister_nonexistent(self): + """Test unregistering a nonexistent executor.""" + assert ExecutorRegistry.unregister(TaskType.RL_ROLLOUT) is False + + def test_list_types_empty(self): + """Test listing types when registry is empty.""" + assert ExecutorRegistry.list_types() == [] + + def test_list_types_with_executors(self): + """Test listing types with registered executors.""" + ExecutorRegistry.register(MockExecutor()) + types = ExecutorRegistry.list_types() + assert TaskType.RL_ROLLOUT in types + + def test_clear(self): + """Test clearing the registry.""" + ExecutorRegistry.register(MockExecutor()) + ExecutorRegistry.clear() + assert ExecutorRegistry.list_types() == [] + + +class TestTaskSpec: + """Tests for TaskSpec serialization.""" + + def test_to_dict(self): + """Test TaskSpec serialization to dict.""" + spec = TaskSpec( + task_type=TaskType.RL_ROLLOUT, + task_id="test-task-123", + config={"env_name": "test-env"}, + timeout=timedelta(hours=1), + resources=ResourceSpec(cpu_cores=2.0, memory_mb=4096), + submission_id=12345, + competition_id="comp-1", + miner_hotkey="5GrwvaEF5zXb26Fz9rcQpDWS57CtERHpNehXCPcNoHGKutQY", + artifact_url="https://example.com/artifact.tar.gz", + ) + + data = spec.to_dict() + + assert data["task_type"] == "rl_rollout" + assert data["task_id"] == "test-task-123" + assert data["config"] == {"env_name": "test-env"} + assert data["timeout_seconds"] == 3600.0 + assert data["resources"]["cpu_cores"] == 2.0 + assert data["resources"]["memory_mb"] == 4096 + + def test_from_dict(self): + """Test TaskSpec deserialization from dict.""" + data = { + "task_type": "rl_rollout", + "task_id": "test-task-456", + "config": {"benchmark": "MT1"}, + "timeout_seconds": 7200, + "resources": {"cpu_cores": 4.0, "gpu_count": 1}, + "submission_id": 67890, + "competition_id": "comp-2", + "miner_hotkey": "5FHneW46xGXgs5mUiveU4sbTyGBzmstUspZC92UhjJM694ty", + "artifact_url": "https://example.com/model.tar.gz", + } + + spec = TaskSpec.from_dict(data) + + assert spec.task_type == TaskType.RL_ROLLOUT + assert spec.task_id == "test-task-456" + assert spec.config == {"benchmark": "MT1"} + assert spec.timeout.total_seconds() == 7200 + assert spec.resources.cpu_cores == 4.0 + assert spec.resources.gpu_count == 1 + + def test_roundtrip_serialization(self): + """Test that to_dict/from_dict roundtrip preserves data.""" + original = TaskSpec( + task_type=TaskType.RL_ROLLOUT, + task_id="roundtrip-test", + config={"key": "value"}, + timeout=timedelta(minutes=30), + resources=ResourceSpec(), + submission_id=11111, + competition_id="comp-rt", + miner_hotkey="5FLSigC9HGRKVhB9FiEo4Y3koPsNmBmLJbpXg2mp1hXcS59Y", + artifact_url="https://example.com/test.tar.gz", + job_id=99999, + env_provider="metaworld", + benchmark_name="MT10", + ) + + data = original.to_dict() + restored = TaskSpec.from_dict(data) + + assert restored.task_type == original.task_type + assert restored.task_id == original.task_id + assert restored.config == original.config + assert restored.timeout == original.timeout + assert restored.resources.cpu_cores == original.resources.cpu_cores + assert restored.submission_id == original.submission_id + assert restored.job_id == original.job_id + assert restored.env_provider == original.env_provider + + +class TestResourceSpec: + """Tests for ResourceSpec.""" + + def test_default_values(self): + """Test default resource values.""" + spec = ResourceSpec() + + assert spec.cpu_cores == 1.0 + assert spec.memory_mb == 2048 + assert spec.gpu_count == 0 + assert spec.gpu_memory_mb == 0 + assert spec.storage_mb == 1024 + + def test_custom_values(self): + """Test custom resource values.""" + spec = ResourceSpec( + cpu_cores=8.0, + memory_mb=16384, + gpu_count=2, + gpu_memory_mb=8192, + storage_mb=10240, + ) + + assert spec.cpu_cores == 8.0 + assert spec.memory_mb == 16384 + assert spec.gpu_count == 2 + assert spec.gpu_memory_mb == 8192 + assert spec.storage_mb == 10240 + + def test_negative_cpu_raises(self): + """Test that negative CPU raises error.""" + with pytest.raises(ValueError, match="cpu_cores must be non-negative"): + ResourceSpec(cpu_cores=-1.0) + + def test_negative_memory_raises(self): + """Test that negative memory raises error.""" + with pytest.raises(ValueError, match="memory_mb must be non-negative"): + ResourceSpec(memory_mb=-1) + + def test_negative_gpu_raises(self): + """Test that negative GPU count raises error.""" + with pytest.raises(ValueError, match="gpu_count must be non-negative"): + ResourceSpec(gpu_count=-1) + + +class TestTaskResult: + """Tests for TaskResult.""" + + def test_success_result(self): + """Test creating a success result.""" + result = TaskResult( + task_id="task-1", + success=True, + metrics={"accuracy": 0.95}, + duration_seconds=120.5, + ) + + assert result.success is True + assert result.metrics["accuracy"] == 0.95 + assert result.error is None + + def test_failure_result(self): + """Test creating a failure result.""" + result = TaskResult( + task_id="task-2", + success=False, + error="Container failed to start", + duration_seconds=5.0, + ) + + assert result.success is False + assert result.error == "Container failed to start" + + def test_to_dict(self): + """Test TaskResult serialization.""" + result = TaskResult( + task_id="task-3", + success=True, + metrics={"success_rate": 0.8, "avg_reward": 1500.0}, + total_episodes=100, + duration_seconds=300.0, + ) + + data = result.to_dict() + + assert data["task_id"] == "task-3" + assert data["success"] is True + assert data["metrics"]["success_rate"] == 0.8 + assert data["total_episodes"] == 100 diff --git a/src/validator/migrations/env.py b/src/evaluator/migrations/env.py similarity index 90% rename from src/validator/migrations/env.py rename to src/evaluator/migrations/env.py index f1fe8ba..79042f6 100644 --- a/src/validator/migrations/env.py +++ b/src/evaluator/migrations/env.py @@ -1,5 +1,5 @@ """ -Alembic environment configuration for validator database. +Alembic environment configuration for evaluator database. """ import asyncio @@ -18,8 +18,8 @@ project_root = Path(__file__).parent.parent.parent.parent sys.path.insert(0, str(project_root)) -# Import the validator models to ensure they're registered with SQLModel -from validator.db import models # noqa: E402, F401 +# Import the evaluator models to ensure they're registered with SQLModel +from evaluator.db import models # noqa: E402, F401 # this is the Alembic Config object config = context.config @@ -40,15 +40,15 @@ def get_database_url(): if db_url: return db_url - # Try to read from validator config file directly + # Try to read from evaluator config file directly try: import toml # noqa: PLC0415 - config_path = project_root / "config" / "validator.toml" + config_path = project_root / "config" / "evaluator.toml" if config_path.exists(): config_data = toml.load(config_path) return config_data.get( - "pg_database", "postgresql://postgres@localhost/validator_db" + "pg_database", "postgresql://postgres@localhost/evaluator_db" ) except Exception: pass diff --git a/src/validator/migrations/script.py.mako b/src/evaluator/migrations/script.py.mako similarity index 100% rename from src/validator/migrations/script.py.mako rename to src/evaluator/migrations/script.py.mako diff --git a/src/validator/migrations/versions/001_initial_validator_schema.py b/src/evaluator/migrations/versions/001_initial_validator_schema.py similarity index 100% rename from src/validator/migrations/versions/001_initial_validator_schema.py rename to src/evaluator/migrations/versions/001_initial_validator_schema.py diff --git a/src/validator/migrations/versions/002_artifact_cols.py b/src/evaluator/migrations/versions/002_artifact_cols.py similarity index 100% rename from src/validator/migrations/versions/002_artifact_cols.py rename to src/evaluator/migrations/versions/002_artifact_cols.py diff --git a/src/validator/migrations/versions/003_env_specs_on_results.py b/src/evaluator/migrations/versions/003_env_specs_on_results.py similarity index 100% rename from src/validator/migrations/versions/003_env_specs_on_results.py rename to src/evaluator/migrations/versions/003_env_specs_on_results.py diff --git a/src/validator/migrations/versions/004_add_timeout_to_jobs.py b/src/evaluator/migrations/versions/004_add_timeout_to_jobs.py similarity index 100% rename from src/validator/migrations/versions/004_add_timeout_to_jobs.py rename to src/evaluator/migrations/versions/004_add_timeout_to_jobs.py diff --git a/src/evaluator/orchestrator.py b/src/evaluator/orchestrator.py index 520777b..e046ea6 100644 --- a/src/evaluator/orchestrator.py +++ b/src/evaluator/orchestrator.py @@ -1,43 +1,47 @@ +""" +Orchestrator - Executor-based task orchestration. + +Uses the TaskExecutor pattern for pluggable task type support. Maintains +backward compatibility with the existing EvalJobMessage format while adding +support for the new TaskSpec-based approach. + +Connects directly to the backend via WebSocket to receive jobs and send results. +""" + +from __future__ import annotations + import asyncio import copy -import functools import gc -import threading -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional -import asyncpg import ray from fiber.chain.chain_utils import load_hotkey_keypair -from kubernetes import client, config -from pgqueuer import PgQueuer, Queries -from pgqueuer.db import AsyncpgDriver -from pgqueuer.models import Job -from ray.util.queue import Queue from snowflake import SnowflakeGenerator from core.db.models import EvaluationStatus from core.log import configure_logging, get_logger from core.messages import EvalJobMessage, EvalResultMessage, JobStatusUpdateMessage +from core.tasks import ( + ExecutorNotFoundError, + ResourceSpec, + TaskResult, + TaskSpec, + TaskType, +) +from evaluator.backend_client import BackendClient from evaluator.config import EvaluatorConfig from evaluator.constants import ( EVAL_TIMEOUT, MIN_CONCURRENT_JOBS, - POD_LOG_TAIL_LINES, - PROCESS_JOB_WAIT_TIME, - QUEUE_MAXSIZE, - RAY_WAIT_TIMEOUT, - RESOURCE_BACKOFF_SECONDS, - WAIT_TIME, ) from evaluator.containers import Containers, PodSchedulingError +from evaluator.db.db_manager import DatabaseManager +from evaluator.db.models import EvaluationJob +from evaluator.executors import ExecutorRegistry, RLRolloutExecutor from evaluator.log_uploader import EvaluationLogUploader -from evaluator.rollout import BenchmarkSpec, EnvManager, RolloutCluster -from evaluator.rollout.envs import EnvResult, EnvSpec -from evaluator.rpc.rpc_process import RPCProcess -from validator.db.db_manager import DatabaseManager -from validator.db.models import EvaluationJob logger = get_logger(__name__) @@ -55,16 +59,37 @@ def _sanitize_for_json(value: Any) -> Any: if isinstance(value, (list, tuple, set)): return [_sanitize_for_json(item) for item in value] try: - dumped = value.model_dump() # type: ignore[attr-defined] + dumped = value.model_dump() return _sanitize_for_json(dumped) except Exception: return str(value) class Orchestrator: + """ + Executor-based orchestrator for task evaluation. + + Uses the TaskExecutor pattern to support multiple task types. Delegates + setup, execution, and teardown to registered executors, allowing new task + types to be added without modifying the orchestrator itself. + + Connects directly to the backend via WebSocket to receive jobs and + send results. + + Usage: + config = EvaluatorConfig() + orchestrator = Orchestrator(config) + await orchestrator.start() + """ + def __init__(self, config: EvaluatorConfig): self.config = config - logger.info(f"Orchestrator initialized with db: {self.config.pg_database}") + logger.info("Orchestrator initialized with db: %s", self.config.pg_database) + + # Backend client for direct WebSocket connection + self.backend_client: Optional[BackendClient] = None + + # Database and identity self.db = DatabaseManager(self.config.pg_database) self.id_generator = SnowflakeGenerator(42) self.keypair = load_hotkey_keypair( @@ -72,14 +97,14 @@ def __init__(self, config: EvaluatorConfig): hotkey_name=config.settings["hotkey_name"], ) + # Log uploader self.log_uploader: Optional[EvaluationLogUploader] = ( EvaluationLogUploader(self.config.s3_config) if self.config.s3_config else None ) - # Track running jobs for concurrent execution - self.running_jobs: Dict[str, Dict] = {} # job_id -> job_info + # Concurrency control if config.max_concurrent_jobs < MIN_CONCURRENT_JOBS: logger.warning( "Configured max_concurrent_jobs (%s) below minimum (%s); clamping.", @@ -90,13 +115,21 @@ def __init__(self, config: EvaluatorConfig): self.default_job_timeout = self._resolve_job_timeout( config.settings.get("job_timeout"), int(EVAL_TIMEOUT.total_seconds()) ) - logger.info( - "Default job timeout fallback set to %s seconds (backend may override)", - self.default_job_timeout, - ) self.concurrent_slots = asyncio.Semaphore(self.max_concurrent_jobs) - # Initialize Ray with explicit configuration + # Running jobs tracking + self.running_jobs: Dict[str, Dict[str, Any]] = {} + + # Initialize Ray + self._init_ray() + + # Register default executors + self._register_default_executors() + + logger.info("Orchestrator initialized with config: %s", self.config) + + def _init_ray(self) -> None: + """Initialize Ray with explicit configuration.""" if not ray.is_initialized(): init_kwargs = { "num_cpus": self.config.ray_num_cpus, @@ -115,23 +148,50 @@ def __init__(self, config: EvaluatorConfig): else: logger.info("Ray already initialized") - logger.info(f"Orchestrator initialized with config: {self.config}") + def _register_default_executors(self) -> None: + """Register the default task executors and initialize providers.""" + # Initialize default environment providers + from evaluator.providers.registry import ProviderRegistry + + ProviderRegistry.initialize_default_providers() + logger.info("Initialized providers: %s", ProviderRegistry.list_providers()) + + # Register RL rollout executor + rl_executor = RLRolloutExecutor(self.config) + ExecutorRegistry.register(rl_executor) + logger.info("Registered default executors: %s", ExecutorRegistry.list_types()) + + def _init_backend_client(self) -> BackendClient: + """Initialize the BackendClient for direct WebSocket connection.""" + backend_ws_url = getattr(self.config, "backend_ws_url", None) + if not backend_ws_url: + raise ValueError( + "backend_ws_url not configured. Set backend_ws_url in evaluator config." + ) - @staticmethod - def _normalize_camera_names(value: Any) -> tuple[str, ...]: - """Convert incoming camera name payloads into a tuple of strings.""" - if value is None: - return tuple() - if isinstance(value, str): - return (value,) - if isinstance(value, (list, tuple, set)): - names: list[str] = [] - for item in value: - if item is None: - continue - names.append(str(item)) - return tuple(names) - return tuple() + evaluator_id = getattr(self.config, "evaluator_id", None) + if not evaluator_id: + # Generate a default evaluator ID from hotkey + evaluator_id = f"evaluator-{self.keypair.ss58_address[:16]}" + logger.warning( + "No evaluator_id configured, using generated ID: %s", evaluator_id + ) + + # Get supported task types from registered executors + supported_task_types = [t.value for t in ExecutorRegistry.list_types()] + + return BackendClient( + backend_url=backend_ws_url, + evaluator_id=evaluator_id, + supported_task_types=supported_task_types, + max_concurrent_jobs=self.max_concurrent_jobs, + capabilities={ + "ray_num_cpus": self.config.ray_num_cpus, + "ray_num_gpus": self.config.ray_num_gpus, + "hotkey": self.keypair.ss58_address, + }, + on_job_received=self._handle_job_from_backend, + ) @staticmethod def _resolve_job_timeout(value: Any, fallback: int) -> int: @@ -146,141 +206,57 @@ def _resolve_job_timeout(value: Any, fallback: int) -> int: return timeout_seconds - def _split_config_data( - self, config_value: Dict[str, Any] - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Split a stored configuration into (spec, base_config) copies.""" - spec_copy = copy.deepcopy(config_value) - try: - base_config_source = config_value["config"] - except KeyError as exc: # pragma: no cover - defensive guard - raise ValueError("Benchmark spec is missing 'config'") from exc - try: - base_config = copy.deepcopy(dict(base_config_source)) - except TypeError as exc: # pragma: no cover - defensive guard - raise ValueError("Benchmark spec 'config' must be a mapping") from exc - return spec_copy, base_config - - def _extract_job_spec_payloads( - self, eval_job_msg: EvalJobMessage - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Return the benchmark spec payload and the underlying config dict.""" - if eval_job_msg.benchmark_spec is None: - raise ValueError("EvalJobMessage is missing required benchmark_spec data") - return self._split_config_data(eval_job_msg.benchmark_spec) - - def _config_payload_for_storage(self, eval_job_msg: EvalJobMessage) -> dict: - """Choose the representation of the job configuration to persist locally.""" - spec_payload, _ = self._extract_job_spec_payloads(eval_job_msg) - return spec_payload - - def _build_benchmark_spec_from_job( - self, eval_job_msg: EvalJobMessage - ) -> BenchmarkSpec: - """Construct a BenchmarkSpec using job metadata and optional spec payload.""" - spec_payload, base_config = self._extract_job_spec_payloads(eval_job_msg) - - defaults = BenchmarkSpec.__dataclass_fields__ - default_render_mode = defaults["render_mode"].default - default_camera_names = defaults["camera_names"].default - default_camera_attribute = defaults["camera_attribute"].default - - render_mode = spec_payload.get("render_mode") - camera_attribute = spec_payload.get("camera_attribute") - camera_names = self._normalize_camera_names(spec_payload.get("camera_names")) - - return BenchmarkSpec( - provider=eval_job_msg.env_provider, - benchmark_name=eval_job_msg.benchmark_name, - config=base_config, - render_mode=render_mode or default_render_mode, - camera_names=camera_names or default_camera_names, - camera_attribute=camera_attribute or default_camera_attribute, - ) - - def _serialize_env_spec(self, env_spec: EnvSpec) -> Dict[str, Any]: - """Convert an EnvSpec into a JSON-safe payload for storage.""" - return { - "env_name": env_spec.env_name, - "benchmark_name": env_spec.benchmark_name, - "provider": env_spec.provider, - "config": _sanitize_for_json(getattr(env_spec, "config", {})), - "episodes_per_task": env_spec.episodes_per_task, - "max_episode_steps": env_spec.max_episode_steps, - "render_mode": env_spec.render_mode, - "camera_attribute": env_spec.camera_attribute, - "camera_names": list(env_spec.camera_names), - } - - def _build_env_specs_payload( - self, - *, - eval_job_msg: EvalJobMessage, - env_results: Optional[List[EnvResult]] = None, - ) -> List[Dict[str, Any]]: - """Capture the environment specs used (or intended) for a job.""" - env_specs_payload: List[Dict[str, Any]] = [] - - if env_results: - for env_result in env_results: - try: - env_specs_payload.append( - self._serialize_env_spec(env_result.env_spec) - ) - except Exception as exc: # pragma: no cover - defensive guard - logger.warning( - "Failed to serialize env spec for job %s: %s", - eval_job_msg.job_id, - exc, - ) - - if env_specs_payload: - return env_specs_payload - - try: - benchmark_spec = self._build_benchmark_spec_from_job(eval_job_msg) - env_manager = EnvManager() - derived_env_specs = env_manager.get_benchmark_envs(benchmark_spec) - env_specs_payload = [ - self._serialize_env_spec(env_spec) for env_spec in derived_env_specs - ] - except Exception as exc: # pragma: no cover - defensive guard - logger.warning( - "Unable to derive env specs for job %s: %s", - eval_job_msg.job_id, - exc, - ) - - return env_specs_payload - - async def setup_job(self, job: Job) -> Optional[Dict]: - """Setup job infrastructure and return job context for monitoring.""" - logger.info(f"Setting up job: {job.id}") - if not job.payload: - return None + def _task_spec_from_eval_job(self, eval_job_msg: EvalJobMessage) -> TaskSpec: + """Convert an EvalJobMessage to a TaskSpec. - eval_job_msg = EvalJobMessage.from_bytes(job.payload) - if ( - eval_job_msg.artifact_url - and eval_job_msg.artifact_expires_at - and eval_job_msg.artifact_expires_at <= datetime.now(timezone.utc) - ): - logger.warning( - "Received expired artifact URL for job %s (expires_at=%s)", - eval_job_msg.job_id, - eval_job_msg.artifact_expires_at, + This provides backward compatibility with the existing message format. + """ + # Extract config from benchmark_spec or config field + if eval_job_msg.benchmark_spec: + config_payload = copy.deepcopy(eval_job_msg.benchmark_spec) + else: + config_payload = ( + copy.deepcopy(eval_job_msg.config) if eval_job_msg.config else {} ) - job_timeout_seconds = self._resolve_job_timeout( + # Determine timeout + timeout_seconds = self._resolve_job_timeout( getattr(eval_job_msg, "timeout", None).total_seconds() if getattr(eval_job_msg, "timeout", None) else None, self.default_job_timeout, ) - job_config_payload = self._config_payload_for_storage(eval_job_msg) + return TaskSpec( + task_type=TaskType.RL_ROLLOUT, # Default to RL for backward compatibility + task_id=str(eval_job_msg.job_id), + config=config_payload, + timeout=timedelta(seconds=timeout_seconds), + resources=ResourceSpec(), # Use defaults + submission_id=eval_job_msg.submission_id, + competition_id=eval_job_msg.competition_id, + miner_hotkey=eval_job_msg.miner_hotkey, + artifact_url=eval_job_msg.artifact_url or "", + artifact_sha256=eval_job_msg.artifact_sha256, + artifact_size_bytes=eval_job_msg.artifact_size_bytes, + artifact_expires_at=eval_job_msg.artifact_expires_at, + job_id=eval_job_msg.job_id, + hf_repo_id=eval_job_msg.hf_repo_id, + env_provider=eval_job_msg.env_provider, + benchmark_name=eval_job_msg.benchmark_name, + ) + + def _create_evaluation_job_record( + self, eval_job_msg: EvalJobMessage, timeout_seconds: int + ) -> EvaluationJob: + """Create an EvaluationJob database record from an EvalJobMessage.""" + config_payload = ( + copy.deepcopy(eval_job_msg.benchmark_spec) + if eval_job_msg.benchmark_spec + else {} + ) - evaluation_job = EvaluationJob( + return EvaluationJob( id=eval_job_msg.job_id, competition_id=eval_job_msg.competition_id, submission_id=eval_job_msg.submission_id, @@ -288,8 +264,8 @@ async def setup_job(self, job: Job) -> Optional[Dict]: hf_repo_id=eval_job_msg.hf_repo_id, env_provider=eval_job_msg.env_provider, benchmark_name=eval_job_msg.benchmark_name, - config=job_config_payload, - timeout_seconds=job_timeout_seconds, + config=config_payload, + timeout_seconds=timeout_seconds, artifact_url=eval_job_msg.artifact_url, artifact_expires_at=eval_job_msg.artifact_expires_at, artifact_sha256=eval_job_msg.artifact_sha256, @@ -297,612 +273,441 @@ async def setup_job(self, job: Job) -> Optional[Dict]: created_at=datetime.now(timezone.utc), ) - existing_job = self.db.get_evaluation_job(eval_job_msg.job_id) - - # Create job entry in the database with QUEUED status - if existing_job: - logger.info( - "Existing evaluation job record found for %s; skipping creation", - eval_job_msg.job_id, - ) - try: - self.db.update_evaluation_job( - eval_job_msg.job_id, - { - "config": job_config_payload, - "timeout_seconds": job_timeout_seconds, - }, - ) - except Exception as e: - logger.warning( - "Failed to refresh stored timeout for job %s: %s", - eval_job_msg.job_id, - e, - ) - else: - try: - self.db.create_evaluation_job(evaluation_job) - except Exception as e: - logger.error( - f"Failed to create evaluation job {eval_job_msg.job_id} in DB: {e}" - ) - return None - - # Update status to STARTING - try: - self.db.update_evaluation_job( - eval_job_msg.job_id, - { - "status": EvaluationStatus.STARTING, - "started_at": datetime.now(timezone.utc), - "error_message": None, - "completed_at": None, - }, - ) - except Exception as e: - logger.error(f"Failed to update job status to STARTING: {e}") - else: - try: - status_msg = JobStatusUpdateMessage( - job_id=eval_job_msg.job_id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.STARTING, - detail="Evaluator is preparing the environment", - ) - await self.db.queue_job_status_update_msg(status_msg) - except Exception as e: - logger.error( - f"Failed to queue job status update for STARTING state: {e}" - ) - - if not eval_job_msg.artifact_url: - raise RuntimeError( - f"Job {eval_job_msg.job_id} missing artifact URL; cannot start container" - ) + async def _handle_job_from_backend(self, eval_job_msg: EvalJobMessage) -> None: + """ + Handle a job received via WebSocket from the backend. + This method is called by the BackendClient when a new job is received. + """ logger.info( - "Creating container for job %s using artifact %s", + "Received job %s from backend via WebSocket for competition %s", eval_job_msg.job_id, - eval_job_msg.artifact_url, + eval_job_msg.competition_id, ) - containers = Containers() - cluster: Optional[RolloutCluster] = None - worker: Optional[ray.actor.ActorHandle] = None - worker_to_rpc_queue: Optional[Queue] = None - rpc_to_worker_queue: Optional[Queue] = None - container_ready = False + # Process the job using existing infrastructure + await self.process_job(eval_job_msg) - try: - pod = containers.create_container( - eval_job_msg.submission_id, - eval_job_msg.job_id, - archive_url=eval_job_msg.artifact_url, - archive_sha256=eval_job_msg.artifact_sha256, - ) - container_ready = True - logger.info(f"Created pod: {pod}") - - # Get NodePort and Node IP for direct TCP connection - config.load_kube_config() - k8v1api = client.CoreV1Api() - v1 = client.CoreV1Api() - service_name = pod - svc = k8v1api.read_namespaced_service(service_name, "default") - node_port = None - for port in svc.spec.ports: - if port.node_port: - node_port = port.node_port - break - if not node_port: - raise RuntimeError(f"No nodePort found for service {service_name}") - - # Get the first node's external IP (or internal if not available) - nodes = v1.list_node().items - node_ip = None - for node in nodes: - for addr in node.status.addresses: - if addr.type == "ExternalIP": - node_ip = addr.address - break - if not node_ip: - for addr in node.status.addresses: - if addr.type == "InternalIP": - node_ip = addr.address - break - if node_ip: - break - if not node_ip: - raise RuntimeError("No node IP found in cluster") - - # Wait for container to be ready - await asyncio.sleep(WAIT_TIME.total_seconds()) - - # Create a benchmark spec for the job, honoring backend-provided settings - benchmark_spec = self._build_benchmark_spec_from_job(eval_job_msg) - - worker_to_rpc_queue = Queue(maxsize=QUEUE_MAXSIZE) - rpc_to_worker_queue = Queue(maxsize=QUEUE_MAXSIZE) + async def setup_job(self, eval_job_msg: EvalJobMessage) -> Optional[Dict[str, Any]]: + """Set up job infrastructure using the appropriate executor. - logger.info( - f"Creating rollout cluster with config: {self.config.worker_remote_options}" - ) - cluster = RolloutCluster( - "eval-cluster", - worker_remote_options=self.config.worker_remote_options, - ) - worker = cluster.create_worker( + Returns a job context dict for monitoring, or None if setup fails. + """ + logger.info("Setting up job: %s", eval_job_msg.job_id) + + # Check for expired artifact + if ( + eval_job_msg.artifact_url + and eval_job_msg.artifact_expires_at + and eval_job_msg.artifact_expires_at <= datetime.now(timezone.utc) + ): + logger.warning( + "Received expired artifact URL for job %s (expires_at=%s)", eval_job_msg.job_id, - [benchmark_spec], - node_ip, - node_port, - eval_job_msg.submission_id, - s3_config=self.config.s3_config, - episode_log_interval=self.config.episode_log_interval, - step_log_interval=self.config.step_log_interval, - database_url=self.config.pg_database, + eval_job_msg.artifact_expires_at, ) - rpc_thread = threading.Thread( - target=RPCProcess, - args=(node_ip, node_port, rpc_to_worker_queue, worker_to_rpc_queue), - daemon=True, - ) - rpc_thread.start() + # Convert to TaskSpec + task_spec = self._task_spec_from_eval_job(eval_job_msg) - await asyncio.sleep(PROCESS_JOB_WAIT_TIME.total_seconds()) + # Get the appropriate executor + try: + executor = ExecutorRegistry.get(task_spec.task_type) + except ExecutorNotFoundError as e: + logger.error("No executor for task type %s: %s", task_spec.task_type, e) + return None - await self._wait_for_rpc_handshake( - job_id=eval_job_msg.job_id, - worker=worker, - worker_to_rpc_queue=worker_to_rpc_queue, - rpc_to_worker_queue=rpc_to_worker_queue, - ) - except Exception: - logger.exception( - "Failed to prepare evaluation runtime for job %s", + # Validate the task spec + validation_errors = await executor.validate_spec(task_spec) + if validation_errors: + logger.error( + "Task spec validation failed for job %s: %s", eval_job_msg.job_id, + validation_errors, ) - self._teardown_failed_job_setup( - job_id=eval_job_msg.job_id, - submission_id=eval_job_msg.submission_id, - cluster=cluster, - worker=worker, - worker_to_rpc_queue=worker_to_rpc_queue, - rpc_to_worker_queue=rpc_to_worker_queue, - containers=containers, - container_ready=container_ready, - ) - raise + await self._handle_validation_failure(eval_job_msg, validation_errors) + return None - assert worker is not None - assert cluster is not None - assert worker_to_rpc_queue is not None - assert rpc_to_worker_queue is not None + # Create/update database record + timeout_seconds = int(task_spec.timeout.total_seconds()) + evaluation_job = self._create_evaluation_job_record( + eval_job_msg, timeout_seconds + ) - # Update status to RUNNING - try: + existing_job = self.db.get_evaluation_job(eval_job_msg.job_id) + if existing_job: + logger.info( + "Existing evaluation job record found for %s; updating config", + eval_job_msg.job_id, + ) self.db.update_evaluation_job( - eval_job_msg.job_id, {"status": EvaluationStatus.RUNNING} + eval_job_msg.job_id, + { + "config": evaluation_job.config, + "timeout_seconds": timeout_seconds, + }, ) - except Exception as e: - logger.error(f"Failed to update job status to RUNNING: {e}") else: try: - status_msg = JobStatusUpdateMessage( - job_id=eval_job_msg.job_id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.RUNNING, - detail="Evaluator started processing the job", - ) - await self.db.queue_job_status_update_msg(status_msg) + self.db.create_evaluation_job(evaluation_job) except Exception as e: logger.error( - f"Failed to queue job status update for RUNNING state: {e}" + "Failed to create evaluation job %s in DB: %s", + eval_job_msg.job_id, + e, ) + return None + + # Update status to STARTING + await self._update_job_status( + eval_job_msg, + EvaluationStatus.STARTING, + "Evaluator is preparing the environment", + ) + + # Delegate setup to executor + try: + task_context = await executor.setup(task_spec) + except Exception as e: + logger.exception("Executor setup failed for job %s", eval_job_msg.job_id) + await self._handle_setup_failure(eval_job_msg, str(e)) + return None - # Start the evaluation (non-blocking) - logger.info(f"Starting evaluation for job {eval_job_msg.job_id}") - evaluation_future = worker.run_all_benchmark_tasks.remote( - worker_to_rpc_queue, rpc_to_worker_queue + # Update status to RUNNING + await self._update_job_status( + eval_job_msg, + EvaluationStatus.RUNNING, + "Evaluator started processing the job", ) - job_context = { + # Return job context for monitoring + return { "job_id": eval_job_msg.job_id, "submission_id": eval_job_msg.submission_id, "eval_job_msg": eval_job_msg, - "worker": worker, - "cluster": cluster, - "evaluation_future": evaluation_future, - "worker_to_rpc_queue": worker_to_rpc_queue, - "rpc_to_worker_queue": rpc_to_worker_queue, + "task_spec": task_spec, + "task_context": task_context, + "executor": executor, "start_time": datetime.now(timezone.utc), - "timeout_seconds": job_timeout_seconds, + "timeout_seconds": timeout_seconds, } - logger.info( - f"Created job context for {eval_job_msg.job_id} with queues: worker_to_rpc={worker_to_rpc_queue is not None}, rpc_to_worker={rpc_to_worker_queue is not None}" - ) + async def monitor_job(self, job_context: Dict[str, Any]) -> bool: + """Monitor a running job and handle completion. - return job_context + Returns True if the job is complete (success, failure, or timeout). + """ + job_id = job_context["job_id"] + task_context = job_context["task_context"] + executor = job_context["executor"] + start_time = job_context["start_time"] + timeout_seconds = job_context["timeout_seconds"] + + # Check for timeout + elapsed = (datetime.now(timezone.utc) - start_time).total_seconds() + if elapsed > timeout_seconds: + logger.error("Job %s timed out after %.1f seconds", job_id, elapsed) + await self._handle_timeout(job_context, elapsed, timeout_seconds) + return True - async def _wait_for_rpc_handshake( - self, - *, - job_id: Any, - worker, - worker_to_rpc_queue: Queue, - rpc_to_worker_queue: Queue, - ) -> None: - """Ensure the RPC client can reach the submission container before proceeding.""" - max_attempts = max(1, getattr(self.config, "rpc_handshake_max_attempts", 5)) - retry_seconds = max( - 0.0, getattr(self.config, "rpc_handshake_retry_seconds", 2.0) - ) + try: + # Execute the task (this may return immediately if already complete) + result = await executor.execute(task_context) - last_error = "no response received from RPC process" - for attempt in range(1, max_attempts + 1): - try: - response = await worker.test_rpc.remote( - worker_to_rpc_queue, rpc_to_worker_queue - ) - except Exception as exc: - last_error = str(exc) - logger.warning( - "RPC handshake attempt %d/%d for job %s failed: %s", - attempt, - max_attempts, - job_id, - exc, - ) + if result.success: + await self._handle_success(job_context, result) else: - if response and getattr(response, "success", False): - logger.info( - "RPC handshake succeeded for job %s on attempt %d", - job_id, - attempt, - ) - return - - response_error = getattr(response, "error_message", None) - last_error = response_error or "RPC response reported failure" - logger.warning( - "RPC handshake attempt %d/%d for job %s reported error: %s", - attempt, - max_attempts, - job_id, - last_error, - ) + await self._handle_failure(job_context, result) - if attempt < max_attempts and retry_seconds > 0: - delay = retry_seconds * attempt - logger.info("Retrying RPC handshake for job %s in %.1fs", job_id, delay) - await asyncio.sleep(delay) + return True - raise RuntimeError( - f"Unable to establish RPC connection for job {job_id} after {max_attempts} attempts: {last_error}" - ) + except Exception as e: + logger.exception("Error monitoring job %s", job_id) + await self._handle_error(job_context, str(e)) + return True - def _teardown_failed_job_setup( - self, - *, - job_id: Any, - submission_id: Optional[Any], - cluster: Optional[RolloutCluster], - worker: Optional[Any], - worker_to_rpc_queue: Optional[Queue], - rpc_to_worker_queue: Optional[Queue], - containers: Containers, - container_ready: bool, + async def _handle_validation_failure( + self, eval_job_msg: EvalJobMessage, errors: list[str] ) -> None: - """Best-effort cleanup for failures that occur before a job starts running.""" - if worker_to_rpc_queue is not None or rpc_to_worker_queue is not None: - temp_context = { - "job_id": job_id, - "worker_to_rpc_queue": worker_to_rpc_queue, - "rpc_to_worker_queue": rpc_to_worker_queue, - } - try: - self._cleanup_queues(temp_context) - except Exception as exc: - logger.warning( - "Queue cleanup failed during setup teardown for job %s: %s", - job_id, - exc, - ) + """Handle task spec validation failure.""" + error_message = f"Task validation failed: {'; '.join(errors)}" + completed_at = datetime.now(timezone.utc) - if cluster and worker: - try: - cluster.delete_worker(worker) - except Exception as exc: - logger.warning( - "Failed to delete rollout worker during setup teardown for job %s: %s", - job_id, - exc, - ) + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.FAILED, + "error_message": error_message, + "completed_at": completed_at, + }, + ) - if submission_id is not None and container_ready: - try: - containers.cleanup_container(submission_id, job_id) - except Exception as exc: - logger.warning( - "Failed to cleanup container for job %s during setup teardown: %s", - job_id, - exc, - ) + await self._publish_failure_result( + eval_job_msg=eval_job_msg, + status=EvaluationStatus.FAILED, + completed_at=completed_at, + error_message=error_message, + ) - gc.collect() + async def _handle_setup_failure( + self, eval_job_msg: EvalJobMessage, error: str + ) -> None: + """Handle executor setup failure.""" + completed_at = datetime.now(timezone.utc) + error_message = f"Setup failed: {error}" - def _build_log_payload( - self, - *, - job_context: Dict[str, Any], - eval_job_msg: EvalJobMessage, - status: EvaluationStatus, - summary: Dict[str, Any], - completed_at: datetime, - error: Optional[str] = None, - pod_logs: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """Create a structured payload describing an evaluation job.""" - start_time = job_context.get("start_time") - if isinstance(start_time, datetime): - duration_seconds = (completed_at - start_time).total_seconds() - started_at = start_time - else: - duration_seconds = None - started_at = None - - spec_payload, base_config = self._extract_job_spec_payloads(eval_job_msg) - - payload: Dict[str, Any] = { - "schema_version": 1, - "generated_at": datetime.now(timezone.utc).isoformat(), - "job": { - # NOTE: str these ids for JS compat. - "id": str(job_context.get("job_id")), - "submission_id": str(eval_job_msg.submission_id), - "competition_id": eval_job_msg.competition_id, - "miner_hotkey": eval_job_msg.miner_hotkey, - "validator_hotkey": self.keypair.ss58_address, - "hf_repo_id": eval_job_msg.hf_repo_id, - "benchmark_name": eval_job_msg.benchmark_name, - "env_provider": eval_job_msg.env_provider, - "config": _sanitize_for_json(base_config), - "status": status.value - if isinstance(status, EvaluationStatus) - else status, - "started_at": _sanitize_for_json(started_at), - "completed_at": completed_at.astimezone(timezone.utc).isoformat(), + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.FAILED, + "error_message": error_message, + "completed_at": completed_at, }, - "summary": _sanitize_for_json(summary), - } - - payload["job"]["benchmark_spec"] = _sanitize_for_json(spec_payload) - - if duration_seconds is not None: - payload["job"]["duration_seconds"] = duration_seconds + ) - if pod_logs: - payload["pod_logs"] = _sanitize_for_json(pod_logs) + await self._publish_failure_result( + eval_job_msg=eval_job_msg, + status=EvaluationStatus.FAILED, + completed_at=completed_at, + error_message=error_message, + ) - if error: - payload["error"] = error + async def _handle_success( + self, job_context: Dict[str, Any], result: TaskResult + ) -> None: + """Handle successful task completion.""" + eval_job_msg = job_context["eval_job_msg"] + executor = job_context["executor"] + task_context = job_context["task_context"] + completed_at = datetime.now(timezone.utc) - return payload + logger.info( + "Job %s completed successfully: %s", + job_context["job_id"], + result.metrics, + ) - async def _upload_job_log_bundle( - self, - *, - job_context: Dict[str, Any], - eval_job_msg: EvalJobMessage, - status: EvaluationStatus, - summary: Dict[str, Any], - completed_at: datetime, - error: Optional[str] = None, - pod_logs: Optional[Dict[str, Any]] = None, - ) -> Optional[Dict[str, Any]]: - """Serialize and upload the job log bundle, returning storage metadata.""" - if not self.log_uploader: - return None + # Update database + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.COMPLETED, + "completed_at": completed_at, + }, + ) - payload = self._build_log_payload( - job_context=job_context, + # Publish result + await self._publish_success_result( eval_job_msg=eval_job_msg, - status=status, - summary=summary, + result=result, completed_at=completed_at, - error=error, - pod_logs=pod_logs, ) - loop = asyncio.get_running_loop() + # Teardown try: - metadata = await loop.run_in_executor( - None, - functools.partial( - self.log_uploader.upload_log_bundle, - submission_id=eval_job_msg.submission_id, - job_id=eval_job_msg.job_id, - payload=payload, - ), - ) - return metadata - except Exception: - logger.exception( - "Failed to upload evaluation log bundle for job %s", - eval_job_msg.job_id, - ) - return None + await executor.teardown(task_context) + except Exception as e: + logger.warning("Teardown failed for job %s: %s", job_context["job_id"], e) + + async def _handle_failure( + self, job_context: Dict[str, Any], result: TaskResult + ) -> None: + """Handle task execution failure.""" + eval_job_msg = job_context["eval_job_msg"] + executor = job_context["executor"] + task_context = job_context["task_context"] + completed_at = datetime.now(timezone.utc) - async def _enqueue_job_for_processing(self, eval_job_msg: EvalJobMessage) -> None: - """Requeue an evaluation job onto PgQueuer for processing.""" - conn = await asyncpg.connect(dsn=self.config.pg_database) + logger.error("Job %s failed: %s", job_context["job_id"], result.error) + + # Update database + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.FAILED, + "error_message": result.error, + "completed_at": completed_at, + }, + ) + + # Publish result + await self._publish_failure_result( + eval_job_msg=eval_job_msg, + status=EvaluationStatus.FAILED, + completed_at=completed_at, + error_message=result.error or "Unknown error", + ) + + # Teardown try: - driver = AsyncpgDriver(conn) - q = Queries(driver) - await q.enqueue(["add_job"], [eval_job_msg.to_bytes()], [0]) - logger.info("Requeued job %s onto add_job queue", eval_job_msg.job_id) - except Exception: - logger.exception( - "Failed to enqueue job %s for restart", eval_job_msg.job_id - ) - raise - finally: - await conn.close() + await executor.teardown(task_context) + except Exception as e: + logger.warning("Teardown failed for job %s: %s", job_context["job_id"], e) - def _cleanup_queues(self, job_context: Dict): - """Clean up Ray Queue actors for a job.""" - job_id = job_context.get("job_id") + async def _handle_timeout( + self, + job_context: Dict[str, Any], + elapsed: float, + timeout_seconds: int, + ) -> None: + """Handle task timeout.""" + eval_job_msg = job_context["eval_job_msg"] + executor = job_context["executor"] + task_context = job_context["task_context"] + completed_at = datetime.now(timezone.utc) - # Track cleanup calls - if not hasattr(self, "_cleanup_calls"): - self._cleanup_calls = {} - self._cleanup_calls[job_id] = self._cleanup_calls.get(job_id, 0) + 1 + error_message = ( + f"Job timed out after {elapsed:.1f} seconds (limit {timeout_seconds}s)" + ) - logger.info(f"Cleanup call #{self._cleanup_calls[job_id]} for job {job_id}") + # Update database + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.TIMEOUT, + "error_message": error_message, + "completed_at": completed_at, + }, + ) - # Debug: Log what's in job_context - logger.info(f"Job context keys for job {job_id}: {list(job_context.keys())}") + # Publish result + await self._publish_failure_result( + eval_job_msg=eval_job_msg, + status=EvaluationStatus.TIMEOUT, + completed_at=completed_at, + error_message=error_message, + ) - logger.info(f"Starting queue cleanup for job {job_id}") + # Teardown try: - # Get queue actors - worker_to_rpc = job_context.get("worker_to_rpc_queue") - rpc_to_worker = job_context.get("rpc_to_worker_queue") - - logger.info( - f"Retrieved from context - worker_to_rpc type: {type(worker_to_rpc)}, rpc_to_worker type: {type(rpc_to_worker)}" - ) + await executor.teardown(task_context) + except Exception as e: + logger.warning("Teardown failed for job %s: %s", job_context["job_id"], e) - # Shutdown queue actors - if worker_to_rpc is not None: - try: - if ( - hasattr(worker_to_rpc, "actor") - and worker_to_rpc.actor is not None - ): - logger.info( - f"Shutting down worker_to_rpc queue for job {job_id}" - ) - worker_to_rpc.shutdown(force=True) - logger.info( - f"Successfully shutdown worker_to_rpc queue for job {job_id}" - ) - else: - logger.info( - f"worker_to_rpc queue already shutdown for job {job_id}" - ) - except Exception as e: - logger.warning(f"Failed to shutdown worker_to_rpc queue: {e}") - else: - logger.warning(f"worker_to_rpc queue is None for job {job_id}") + async def _handle_error(self, job_context: Dict[str, Any], error: str) -> None: + """Handle unexpected error during monitoring.""" + eval_job_msg = job_context["eval_job_msg"] + executor = job_context["executor"] + task_context = job_context["task_context"] + completed_at = datetime.now(timezone.utc) - if rpc_to_worker is not None: - try: - if ( - hasattr(rpc_to_worker, "actor") - and rpc_to_worker.actor is not None - ): - logger.info( - f"Shutting down rpc_to_worker queue for job {job_id}" - ) - rpc_to_worker.shutdown(force=True) - logger.info( - f"Successfully shutdown rpc_to_worker queue for job {job_id}" - ) - else: - logger.info( - f"rpc_to_worker queue already shutdown for job {job_id}" - ) - except Exception as e: - logger.warning(f"Failed to shutdown rpc_to_worker queue: {e}") - else: - logger.warning(f"rpc_to_worker queue is None for job {job_id}") + # Update database + self.db.update_evaluation_job( + eval_job_msg.job_id, + { + "status": EvaluationStatus.FAILED, + "error_message": error, + "completed_at": completed_at, + }, + ) - logger.info(f"Completed Ray Queue actors cleanup for job {job_id}") + # Publish result + await self._publish_failure_result( + eval_job_msg=eval_job_msg, + status=EvaluationStatus.FAILED, + completed_at=completed_at, + error_message=error, + ) + # Teardown + try: + await executor.teardown(task_context) except Exception as e: - logger.warning(f"Failed to cleanup queues: {e}") + logger.warning("Teardown failed for job %s: %s", job_context["job_id"], e) - def _log_cached_runner_output( - self, job_id: Any, pod_logs: Optional[Dict[str, Any]] + async def _update_job_status( + self, + eval_job_msg: EvalJobMessage, + status: EvaluationStatus, + detail: str, ) -> None: - """Log runner pod output for setup failures to aid debugging.""" + """Update job status in database and notify backend via WebSocket.""" + try: + update_fields: Dict[str, Any] = {"status": status} + if status == EvaluationStatus.STARTING: + update_fields["started_at"] = datetime.now(timezone.utc) + update_fields["error_message"] = None + update_fields["completed_at"] = None - if not pod_logs: + self.db.update_evaluation_job(eval_job_msg.job_id, update_fields) + except Exception as e: + logger.error("Failed to update job status to %s: %s", status, e) return - containers = pod_logs.get("containers") or {} - if not containers: - error_detail = pod_logs.get("error") or pod_logs.get("warning") - if error_detail: - logger.error( - "Runner logs unavailable for job %s: %s", - job_id, - error_detail, - ) - return + # Create and send status message via WebSocket + status_msg = JobStatusUpdateMessage( + job_id=eval_job_msg.job_id, + validator_hotkey=self.keypair.ss58_address, + status=status, + detail=detail, + ) + + try: + if self.backend_client: + await self.backend_client.send_status_update(status_msg) + except Exception as e: + logger.error("Failed to send job status update: %s", e) - source_note = ( - "cached before deletion" - if pod_logs.get("cached_before_deletion") - else "live fetch" + async def _publish_success_result( + self, + eval_job_msg: EvalJobMessage, + result: TaskResult, + completed_at: datetime, + ) -> None: + """Publish a success result to the backend via WebSocket.""" + metrics = result.metrics + + # Extract benchmark spec + spec_payload = ( + copy.deepcopy(eval_job_msg.benchmark_spec) + if eval_job_msg.benchmark_spec + else {} ) + base_config = spec_payload.get("config", {}) - for container_name, entry in containers.items(): - log_text = entry.get("log") - if log_text: - logger.error( - "Runner logs for job %s (%s, %s):\n%s", - job_id, - container_name, - source_note, - log_text, - ) - elif entry.get("error"): - logger.error( - "Runner logs for job %s (%s, %s) unavailable: %s", - job_id, - container_name, - source_note, - entry.get("error"), - ) + eval_result_msg = EvalResultMessage( + job_id=eval_job_msg.job_id, + status=EvaluationStatus.COMPLETED, + validator_hotkey=self.keypair.ss58_address, + miner_hotkey=eval_job_msg.miner_hotkey, + competition_id=eval_job_msg.competition_id, + env_provider=eval_job_msg.env_provider, + benchmark_name=eval_job_msg.benchmark_name, + config=base_config, + benchmark_spec=spec_payload, + score=metrics.get("avg_reward", 0.0), + success_rate=metrics.get("success_rate"), + avg_reward=metrics.get("avg_reward"), + total_episodes=result.total_episodes, + logs=result.logs or "Evaluation completed successfully", + error=None, + extra_data={ + "summary": { + "metrics": metrics, + "duration_seconds": result.duration_seconds, + "completed_at": completed_at.isoformat(), + } + }, + ) + + # Send via WebSocket + if self.backend_client: + await self.backend_client.send_result(eval_result_msg) async def _publish_failure_result( self, - *, eval_job_msg: EvalJobMessage, status: EvaluationStatus, completed_at: datetime, - summary: Dict[str, Any], error_message: str, - log_artifact: Optional[Dict[str, Any]], - pod_logs: Optional[Dict[str, Any]], ) -> None: - """Send an evaluation result message for failed or timed-out jobs.""" - - spec_payload, base_config = self._extract_job_spec_payloads(eval_job_msg) - env_specs_payload = self._build_env_specs_payload( - eval_job_msg=eval_job_msg, env_results=None + """Publish a failure result to the backend via WebSocket.""" + spec_payload = ( + copy.deepcopy(eval_job_msg.benchmark_spec) + if eval_job_msg.benchmark_spec + else {} ) - extra_data: Dict[str, Any] = {"summary": summary} - extra_data.setdefault("benchmark_spec", copy.deepcopy(spec_payload)) - - logs_message = f"Evaluation {status.value.lower()}" - if error_message: - logs_message = f"{logs_message}: {error_message}" - - if log_artifact: - extra_data["log_artifact"] = log_artifact - artifact_ref = log_artifact.get("public_url") or log_artifact.get( - "object_key" - ) - if artifact_ref: - logs_message = f"{logs_message}. Log bundle: {artifact_ref}" - elif pod_logs: - extra_data["pod_logs"] = pod_logs - logs_message = f"{logs_message}. Container logs attached to result payload." + base_config = spec_payload.get("config", {}) eval_result_msg = EvalResultMessage( job_id=eval_job_msg.job_id, @@ -918,602 +723,58 @@ async def _publish_failure_result( success_rate=None, avg_reward=None, total_episodes=None, - env_specs=env_specs_payload or None, - logs=logs_message, + logs=f"Evaluation {status.value.lower()}: {error_message}", error=error_message, - extra_data=extra_data, ) - await self.db.queue_evaluation_result_msg(eval_result_msg) - - async def monitor_job(self, job_context: Dict): - """Monitor a running job and handle completion.""" - job_id = job_context["job_id"] - submission_id = job_context["submission_id"] - eval_job_msg = job_context["eval_job_msg"] - evaluation_future = job_context["evaluation_future"] - - try: - # Use ray.wait with timeout to check if job is done without blocking - ready, not_ready = ray.wait( - [evaluation_future], timeout=RAY_WAIT_TIMEOUT.total_seconds() - ) - - if ready: - # Job completed, get results - results: List[EnvResult] = ray.get(evaluation_future) - - logger.info( - f"Evaluation completed for job {job_id} with {len(results)} results" - ) - - # Calculate metrics - if results: - total_episodes = sum(len(result.episodes) for result in results) - if total_episodes == 0: - total_episodes = None - avg_success_rate = sum( - result.success_rate for result in results - ) / len(results) - avg_reward = sum(result.mean_reward for result in results) / len( - results - ) - - logger.info(f"Job {job_id} - Total episodes: {total_episodes or 0}") - logger.info( - f"Job {job_id} - Average success rate: {avg_success_rate:.3f}" - ) - logger.info(f"Job {job_id} - Average reward: {avg_reward:.3f}") - else: - total_episodes = None - avg_success_rate = 0.0 - avg_reward = 0.0 - - completed_at = datetime.now(timezone.utc) - summary_data: Dict[str, Any] = { - "results_count": len(results), - "total_episodes": total_episodes, - "avg_success_rate": avg_success_rate, - "avg_reward": avg_reward, - "completed_at": completed_at.isoformat(), - } - start_time = job_context.get("start_time") - if isinstance(start_time, datetime): - summary_data["started_at"] = start_time.astimezone( - timezone.utc - ).isoformat() - summary_data["duration_seconds"] = ( - completed_at - start_time - ).total_seconds() - - containers = Containers() - pod_logs: Optional[Dict[str, Any]] = None - try: - pod_logs = containers.collect_container_logs( - submission_id, job_id, tail_lines=POD_LOG_TAIL_LINES - ) - except Exception as log_exc: - logger.warning( - "Failed to collect pod logs for submission %s (job %s): %s", - submission_id, - job_id, - log_exc, - ) - - log_artifact = await self._upload_job_log_bundle( - job_context=job_context, - eval_job_msg=eval_job_msg, - status=EvaluationStatus.COMPLETED, - summary=summary_data, - completed_at=completed_at, - error=None, - pod_logs=pod_logs, - ) - - # Update database - try: - self.db.update_evaluation_job( - job_id, - { - "status": EvaluationStatus.COMPLETED, - "completed_at": completed_at, - }, - ) - except Exception as e: - logger.error(f"Failed to update job status for job {job_id}: {e}") - - # Queue result message - logs_message = "Evaluation completed successfully" - if log_artifact: - artifact_ref = log_artifact.get("public_url") or log_artifact.get( - "object_key" - ) - if artifact_ref: - logs_message = f"{logs_message}. Log bundle: {artifact_ref}" - extra_data: Dict[str, Any] = {"summary": summary_data} - if log_artifact: - extra_data["log_artifact"] = log_artifact - elif pod_logs: - extra_data["pod_logs"] = pod_logs - logs_message = ( - f"{logs_message}. Container logs attached to result payload." - ) - - spec_payload, base_config = self._extract_job_spec_payloads( - eval_job_msg - ) - extra_data.setdefault("benchmark_spec", copy.deepcopy(spec_payload)) - env_specs_payload = self._build_env_specs_payload( - eval_job_msg=eval_job_msg, env_results=results - ) - - eval_result_msg = EvalResultMessage( - job_id=job_id, - status=EvaluationStatus.COMPLETED, - validator_hotkey=self.keypair.ss58_address, - miner_hotkey=eval_job_msg.miner_hotkey, - competition_id=eval_job_msg.competition_id, - env_provider=eval_job_msg.env_provider, - benchmark_name=eval_job_msg.benchmark_name, - config=base_config, - benchmark_spec=spec_payload, - score=avg_reward, - success_rate=avg_success_rate, - avg_reward=avg_reward, - total_episodes=total_episodes, - env_specs=env_specs_payload or None, - logs=logs_message, - error=None, - extra_data=extra_data, - ) - await self.db.queue_evaluation_result_msg(eval_result_msg) - - # Clean up Ray worker and container resources - try: - # Clean up queues first - self._cleanup_queues(job_context) - - # Clean up Ray worker - cluster = job_context.get("cluster") - worker = job_context.get("worker") - if cluster and worker: - # Call cleanup on the worker before killing it - try: - ray.get(worker.cleanup.remote(), timeout=5) - except Exception as e: - logger.warning(f"Worker cleanup failed: {e}") - cluster.delete_worker(worker) - logger.info(f"Cleaned up Ray worker for job {job_id}") - - # Then clean up container - containers.cleanup_container(submission_id, job_id) - logger.info( - "Cleaned up container resources for submission %s (job %s)", - submission_id, - job_id, - ) - - # Clear references - del results - gc.collect() - except Exception as e: - logger.error( - f"Failed to clean up resources for submission {submission_id}: {e}" - ) - - return True # Job completed - - # Check for timeout - elapsed = ( - datetime.now(timezone.utc) - job_context["start_time"] - ).total_seconds() - timeout_seconds = job_context.get( - "timeout_seconds", self.default_job_timeout - ) - if elapsed > timeout_seconds: - logger.error(f"Job {job_id} timed out after {elapsed} seconds") - ray.cancel(evaluation_future) - completed_at = datetime.now(timezone.utc) - timeout_detail = ( - f"Job timed out after {elapsed:.1f} seconds " - f"(limit {timeout_seconds}s)" - ) - summary_data: Dict[str, Any] = { - "elapsed_seconds": elapsed, - "timeout_seconds": timeout_seconds, - "completed_at": completed_at.isoformat(), - } - start_time = job_context.get("start_time") - if isinstance(start_time, datetime): - summary_data["started_at"] = start_time.astimezone( - timezone.utc - ).isoformat() - summary_data["duration_seconds"] = ( - completed_at - start_time - ).total_seconds() - - containers = Containers() - pod_logs: Optional[Dict[str, Any]] = None - try: - pod_logs = containers.collect_container_logs( - submission_id, job_id, tail_lines=POD_LOG_TAIL_LINES - ) - except Exception as log_exc: - logger.warning( - "Failed to collect pod logs for timed out submission %s (job %s): %s", - submission_id, - job_id, - log_exc, - ) - - log_artifact = await self._upload_job_log_bundle( - job_context=job_context, - eval_job_msg=eval_job_msg, - status=EvaluationStatus.TIMEOUT, - summary=summary_data, - completed_at=completed_at, - error=timeout_detail, - pod_logs=pod_logs, - ) - - error_message = timeout_detail - if log_artifact: - artifact_ref = log_artifact.get("public_url") or log_artifact.get( - "object_key" - ) - if artifact_ref: - error_message = f"{error_message}. Log bundle: {artifact_ref}" - elif pod_logs: - error_message = ( - f"{error_message}. Container logs attached to result payload." - ) - - await self._publish_failure_result( - eval_job_msg=eval_job_msg, - status=EvaluationStatus.TIMEOUT, - completed_at=completed_at, - summary=summary_data, - error_message=timeout_detail, - log_artifact=log_artifact, - pod_logs=pod_logs, - ) - - self.db.update_evaluation_job( - job_id, - { - "status": EvaluationStatus.TIMEOUT, - "error_message": error_message, - "completed_at": completed_at, - }, - ) - try: - status_msg = JobStatusUpdateMessage( - job_id=eval_job_msg.job_id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.TIMEOUT, - detail=error_message, - ) - await self.db.queue_job_status_update_msg(status_msg) - except Exception as status_err: - logger.warning( - "Failed to queue timeout status update for job %s: %s", - eval_job_msg.job_id, - status_err, - ) - - # Clean up Ray worker and container resources on timeout - try: - # Clean up queues first - self._cleanup_queues(job_context) - - # Clean up Ray worker - cluster = job_context.get("cluster") - worker = job_context.get("worker") - if cluster and worker: - # Try to call cleanup on the worker before killing it - try: - ray.get(worker.cleanup.remote(), timeout=2) - except Exception as e: - logger.warning(f"Worker cleanup failed on timeout: {e}") - cluster.delete_worker(worker) - logger.info(f"Cleaned up Ray worker for timed out job {job_id}") - - # Then clean up container - containers.cleanup_container(submission_id, job_id) - logger.info( - "Cleaned up container resources for timed out submission %s (job %s)", - submission_id, - job_id, - ) - - gc.collect() - except Exception as e: - logger.error( - f"Failed to clean up resources for timed out submission {submission_id}: {e}" - ) - - return True # Job timed out - - return False # Job still running - - except Exception as e: - logger.error(f"Error monitoring job {job_id}: {e}") - completed_at = datetime.now(timezone.utc) - error_detail = str(e) - summary_data: Dict[str, Any] = { - "exception_type": type(e).__name__, - "message": error_detail, - "completed_at": completed_at.isoformat(), - } - start_time = job_context.get("start_time") - if isinstance(start_time, datetime): - summary_data["started_at"] = start_time.astimezone( - timezone.utc - ).isoformat() - summary_data["duration_seconds"] = ( - completed_at - start_time - ).total_seconds() - - containers = Containers() - pod_logs: Optional[Dict[str, Any]] = None - try: - pod_logs = containers.collect_container_logs( - submission_id, job_id, tail_lines=POD_LOG_TAIL_LINES - ) - except Exception as log_exc: - logger.warning( - "Failed to collect pod logs for failed submission %s (job %s): %s", - submission_id, - job_id, - log_exc, - ) - - log_artifact = await self._upload_job_log_bundle( - job_context=job_context, - eval_job_msg=eval_job_msg, - status=EvaluationStatus.FAILED, - summary=summary_data, - completed_at=completed_at, - error=error_detail, - pod_logs=pod_logs, - ) - - error_message = error_detail - if log_artifact: - artifact_ref = log_artifact.get("public_url") or log_artifact.get( - "object_key" - ) - if artifact_ref: - error_message = f"{error_message}. Log bundle: {artifact_ref}" - elif pod_logs: - error_message = ( - f"{error_message}. Container logs attached to result payload." - ) - - await self._publish_failure_result( - eval_job_msg=eval_job_msg, - status=EvaluationStatus.FAILED, - completed_at=completed_at, - summary=summary_data, - error_message=error_detail, - log_artifact=log_artifact, - pod_logs=pod_logs, - ) - - self.db.update_evaluation_job( - job_id, - { - "status": EvaluationStatus.FAILED, - "error_message": error_message, - "completed_at": completed_at, - }, - ) - - # Clean up Ray worker and container resources on error - try: - # Clean up queues first - self._cleanup_queues(job_context) - - # Clean up Ray worker - cluster = job_context.get("cluster") - worker = job_context.get("worker") - if cluster and worker: - # Try to call cleanup on the worker before killing it - try: - ray.get(worker.cleanup.remote(), timeout=2) - except Exception as e: - logger.warning(f"Worker cleanup failed on error: {e}") - cluster.delete_worker(worker) - logger.info(f"Cleaned up Ray worker for failed job {job_id}") - - # Then clean up container - containers.cleanup_container(submission_id, job_id) - logger.info( - "Cleaned up container resources for failed submission %s (job %s)", - submission_id, - job_id, - ) - - gc.collect() - except Exception as ex: - logger.error( - f"Failed to clean up resources for failed submission {submission_id}: {ex}" - ) + # Send via WebSocket + if self.backend_client: + await self.backend_client.send_result(eval_result_msg) - return True # Job failed + async def process_job(self, eval_job_msg: EvalJobMessage) -> None: + """Process a job asynchronously.""" + job_id = eval_job_msg.job_id - async def process_job(self, job: Job): - """Process a job asynchronously without blocking.""" if self.concurrent_slots.locked(): logger.warning( - f"Max concurrent jobs ({self.max_concurrent_jobs}) reached. Job {getattr(job, 'id', 'unknown')} waiting for a free slot." + "Max concurrent jobs (%s) reached. Job %s waiting for a free slot.", + self.max_concurrent_jobs, + job_id, ) await self.concurrent_slots.acquire() try: - job_context = await self.setup_job(job) + job_context = await self.setup_job(eval_job_msg) except PodSchedulingError as e: - job_id = getattr(job, "id", "unknown") - logger.warning("Insufficient cluster resources for job %s: %s", job_id, e) - - if job.payload: - try: - eval_job_msg = EvalJobMessage.from_bytes(job.payload) - backoff_detail = ( - f"Waiting for cluster capacity: {e}" if str(e) else None - ) - self.db.update_evaluation_job( - eval_job_msg.job_id, - { - "status": EvaluationStatus.QUEUED, - "error_message": backoff_detail, - "started_at": None, - "completed_at": None, - }, - ) - - status_msg = JobStatusUpdateMessage( - job_id=eval_job_msg.job_id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.QUEUED, - detail=backoff_detail, - ) - await self.db.queue_job_status_update_msg(status_msg) - - logger.info( - "Requeueing job %s after resource backoff", eval_job_msg.job_id - ) - await asyncio.sleep(RESOURCE_BACKOFF_SECONDS.total_seconds()) - await self._enqueue_job_for_processing(eval_job_msg) - except Exception as requeue_err: - logger.error( - "Failed to requeue job %s after resource shortfall: %s", - job_id, - requeue_err, - ) - + logger.warning( + "Insufficient cluster resources for job %s: %s", + job_id, + e, + ) self.concurrent_slots.release() return except Exception as e: - job_id = getattr(job, "id", "unknown") - logger.error(f"Failed to process job {job_id}: {e}") - - completed_at = datetime.now(timezone.utc) - eval_job_msg: Optional[EvalJobMessage] = None - pod_logs: Optional[Dict[str, Any]] = None - log_artifact: Optional[Dict[str, Any]] = None - error_detail = str(e) - - if job.payload: - try: - eval_job_msg = EvalJobMessage.from_bytes(job.payload) - except Exception as decode_err: - logger.error( - "Failed to decode job payload for job %s: %s", - job_id, - decode_err, - ) - - if eval_job_msg is not None: - containers = Containers() - try: - pod_logs = containers.collect_container_logs( - eval_job_msg.submission_id, - eval_job_msg.job_id, - tail_lines=POD_LOG_TAIL_LINES, - ) - except Exception as log_exc: - logger.warning( - "Failed to collect pod logs for setup failure %s: %s", - eval_job_msg.job_id, - log_exc, - ) - else: - self._log_cached_runner_output(eval_job_msg.job_id, pod_logs) - - summary_data: Dict[str, Any] = { - "exception_type": type(e).__name__, - "message": error_detail, - "stage": "setup_job", - "completed_at": completed_at.isoformat(), - } - stub_context = {"job_id": eval_job_msg.job_id, "start_time": None} - log_artifact = await self._upload_job_log_bundle( - job_context=stub_context, - eval_job_msg=eval_job_msg, - status=EvaluationStatus.FAILED, - summary=summary_data, - completed_at=completed_at, - error=error_detail, - pod_logs=pod_logs, - ) - - await self._publish_failure_result( - eval_job_msg=eval_job_msg, - status=EvaluationStatus.FAILED, - completed_at=completed_at, - summary=summary_data, - error_message=error_detail, - log_artifact=log_artifact, - pod_logs=pod_logs, - ) - - # Attempt to mark the evaluation job as failed and notify backend - try: - if eval_job_msg is not None: - failure_detail = f"Container setup failed: {e}" - if log_artifact: - artifact_ref = log_artifact.get( - "public_url" - ) or log_artifact.get("object_key") - if artifact_ref: - failure_detail = ( - f"{failure_detail}. Log bundle: {artifact_ref}" - ) - elif pod_logs: - failure_detail = f"{failure_detail}. Container logs attached to result payload." - - self.db.update_evaluation_job( - eval_job_msg.job_id, - { - "status": EvaluationStatus.FAILED, - "error_message": failure_detail, - "completed_at": completed_at, - }, - ) - - status_msg = JobStatusUpdateMessage( - job_id=eval_job_msg.job_id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.FAILED, - detail=failure_detail, - ) - await self.db.queue_job_status_update_msg(status_msg) - except Exception as status_err: - logger.error( - "Failed to publish failure status for job %s: %s", - job_id, - status_err, - ) - + logger.error("Failed to process job %s: %s", job_id, e) self.concurrent_slots.release() return if not job_context: logger.warning( - f"Setup for job {getattr(job, 'id', 'unknown')} returned no context; releasing slot." + "Setup for job %s returned no context; releasing slot.", + job_id, ) self.concurrent_slots.release() return - job_id = job_context["job_id"] self.running_jobs[job_id] = job_context logger.info( - f"Job {job_id} added to running jobs. Total running: {len(self.running_jobs)}" + "Job %s added to running jobs. Total running: %s", + job_id, + len(self.running_jobs), ) - async def monitor_running_jobs(self): + async def monitor_running_jobs(self) -> None: """Background task to monitor all running jobs.""" while True: if self.running_jobs: @@ -1528,23 +789,24 @@ async def monitor_running_jobs(self): job_context = self.running_jobs.pop(job_id, None) if job_context is not None: self.concurrent_slots.release() - # Clear references to heavy objects job_context.clear() logger.info( - f"Job {job_id} removed from running jobs. Remaining: {len(self.running_jobs)}" + "Job %s removed from running jobs. Remaining: %s", + job_id, + len(self.running_jobs), ) - await asyncio.sleep(1) # Check every second + await asyncio.sleep(1) - async def recover_stale_jobs(self): + async def recover_stale_jobs(self) -> None: """Recover jobs that were running when orchestrator crashed.""" - logger.info("Checking for stale jobs from previous orchestrator run...") + logger.info("Checking for stale jobs from previous run...") - # Find jobs that are stuck in STARTING or RUNNING state stale_statuses = [EvaluationStatus.STARTING, EvaluationStatus.RUNNING] for status in stale_statuses: stale_jobs = self.db.get_evaluation_jobs_by_status(status) - logger.info(f"Found {len(stale_jobs)} jobs in {status.value} state") + logger.info("Found %s jobs in %s state", len(stale_jobs), status.value) + for job in stale_jobs: logger.info( "Resetting stale job %s (previous status %s) for restart", @@ -1552,25 +814,19 @@ async def recover_stale_jobs(self): status.value, ) try: + # Clean up any orphaned containers if job.submission_id: try: containers = Containers() containers.cleanup_container( job.submission_id, job.id, wait=True ) - logger.info( - "Cleaned up orphaned container for submission %s (job %s)", - job.submission_id, - job.id, - ) - except Exception as cleanup_err: + except Exception as e: logger.warning( - "Failed to cleanup container for submission %s (job %s): %s", - job.submission_id, - job.id, - cleanup_err, + "Failed to cleanup container for job %s: %s", job.id, e ) + # Reset job status self.db.update_evaluation_job( job.id, { @@ -1581,68 +837,22 @@ async def recover_stale_jobs(self): }, ) - status_msg = JobStatusUpdateMessage( - job_id=job.id, - validator_hotkey=self.keypair.ss58_address, - status=EvaluationStatus.QUEUED, - detail="Evaluator reset job after orchestrator restart", - ) - try: - await self.db.queue_job_status_update_msg(status_msg) - except Exception as status_err: - logger.warning( - "Failed to queue status update for job %s: %s", - job.id, - status_err, - ) - - job_config_payload = ( - copy.deepcopy(job.config) - if isinstance(job.config, dict) - else {} - ) - timeout_seconds = self._resolve_job_timeout( - getattr(job, "timeout_seconds", None), - self.default_job_timeout, - ) - - spec_payload, base_config = self._split_config_data( - job_config_payload - ) - - eval_job_msg = EvalJobMessage( - job_id=job.id, - competition_id=job.competition_id, - submission_id=job.submission_id, - miner_hotkey=job.miner_hotkey, - hf_repo_id=job.hf_repo_id, - env_provider=job.env_provider, - benchmark_name=job.benchmark_name, - config=base_config, - benchmark_spec=spec_payload, - artifact_url=job.artifact_url, - artifact_expires_at=job.artifact_expires_at, - artifact_sha256=job.artifact_sha256, - artifact_size_bytes=job.artifact_size_bytes, - timeout_seconds=timeout_seconds, - ) + logger.info("Stale job %s reset to QUEUED", job.id) - await self._enqueue_job_for_processing(eval_job_msg) - logger.info("Stale job %s successfully reset and requeued", job.id) - except Exception as exc: - logger.error("Failed to reset stale job %s: %s", job.id, exc) + except Exception as e: + logger.error("Failed to reset stale job %s: %s", job.id, e) self.db.update_evaluation_job( job.id, { "status": EvaluationStatus.FAILED, - "error_message": f"Failed to restart after orchestrator crash: {exc}", + "error_message": f"Failed to restart: {e}", "completed_at": datetime.now(timezone.utc), }, ) logger.info("Stale job recovery completed") - async def periodic_cleanup(self): + async def periodic_cleanup(self) -> None: """Periodic cleanup of orphaned containers and resources.""" while True: try: @@ -1651,127 +861,113 @@ async def periodic_cleanup(self): containers = Containers() - # Get all completed/failed jobs from the last 24 hours - # that might have orphaned containers - failed_jobs = self.db.get_evaluation_jobs_by_status( - EvaluationStatus.FAILED - ) - timeout_jobs = self.db.get_evaluation_jobs_by_status( - EvaluationStatus.TIMEOUT - ) - completed_jobs = self.db.get_evaluation_jobs_by_status( - EvaluationStatus.COMPLETED - ) - - # Clean up containers for old completed/failed jobs - for job_list in [failed_jobs, timeout_jobs, completed_jobs]: - for job in job_list: - if job.completed_at: - job_timeout_seconds = self._resolve_job_timeout( + # Get completed/failed jobs and clean up old containers + for status in [ + EvaluationStatus.FAILED, + EvaluationStatus.TIMEOUT, + EvaluationStatus.COMPLETED, + ]: + jobs = self.db.get_evaluation_jobs_by_status(status) + for job in jobs: + if job.completed_at and job.submission_id: + job_timeout = self._resolve_job_timeout( getattr(job, "timeout_seconds", None), self.default_job_timeout, ) - - # Ensure both datetimes have timezone info for comparison - current_time = datetime.now(timezone.utc) completed_time = job.completed_at if completed_time.tzinfo is None: completed_time = completed_time.replace( tzinfo=timezone.utc ) - if ( - current_time - completed_time - ).total_seconds() > job_timeout_seconds: + age = ( + datetime.now(timezone.utc) - completed_time + ).total_seconds() + if age > job_timeout: try: - if job.submission_id: - containers.cleanup_container( - job.submission_id, job.id - ) - logger.debug( - "Cleaned up container for old job %s (submission %s)", - job.id, - job.submission_id, - ) - except Exception as e: - logger.debug( - f"Container cleanup failed for job {job.id}: {e}" + containers.cleanup_container( + job.submission_id, job.id ) + except Exception: + pass gc.collect() - logger.info("Periodic cleanup completed") + except Exception as e: - logger.error(f"Error during periodic cleanup: {e}") + logger.error("Error during periodic cleanup: %s", e) - async def start(self): - logger.info("Starting orchestrator...") + async def start(self) -> None: + """Start the orchestrator with direct backend connection.""" + logger.info("Starting Orchestrator...") - # Recover any stale jobs from previous run + # Recover stale jobs await self.recover_stale_jobs() - conn = await asyncpg.connect(dsn=self.config.pg_database) + # Initialize the backend client + self.backend_client = self._init_backend_client() - driver = AsyncpgDriver(conn) - pgq = PgQueuer(driver) - - # Start the job monitoring task + # Start background tasks monitor_task = asyncio.create_task(self.monitor_running_jobs()) logger.info("Started job monitoring task") - # Start periodic cleanup task cleanup_task = asyncio.create_task(self.periodic_cleanup()) logger.info("Started periodic cleanup task") - @pgq.entrypoint("add_job") - async def process(job: Job) -> None: - asyncio.create_task(self.process_job(job)) - logger.info(f"Job {job.id} added to processing queue.") - logger.info( - f"Orchestrator is now listening for jobs (max concurrent: {self.max_concurrent_jobs})..." + "Orchestrator connecting to backend (max concurrent: %s)...", + self.max_concurrent_jobs, ) try: - await pgq.run() + # Connect and run - this blocks until stop() is called + await self.backend_client.connect_and_run() finally: monitor_task.cancel() cleanup_task.cancel() + try: + await monitor_task + except asyncio.CancelledError: + pass + try: + await cleanup_task + except asyncio.CancelledError: + pass - await asyncio.Future() + async def stop(self) -> None: + """Stop the orchestrator.""" + logger.info("Stopping Orchestrator...") + + # Stop backend client + if self.backend_client: + try: + await self.backend_client.stop() + logger.info("Backend client stopped") + except Exception as e: + logger.error("Error stopping backend client: %s", e) + self.backend_client = None - def stop(self): - logger.info("Stopping orchestrator...") # Clean up all running jobs for job_id in list(self.running_jobs.keys()): job_context = self.running_jobs.pop(job_id, None) if job_context is None: continue - try: - # Clean up queues first - self._cleanup_queues(job_context) - cluster = job_context.get("cluster") - worker = job_context.get("worker") - if cluster and worker: - # Try to call cleanup on the worker before killing it + try: + executor = job_context.get("executor") + task_context = job_context.get("task_context") + if executor and task_context: try: - ray.get(worker.cleanup.remote(), timeout=2) + await executor.teardown(task_context) except Exception as e: - logger.warning(f"Worker cleanup failed during shutdown: {e}") - cluster.delete_worker(worker) - submission_id = job_context.get("submission_id") - if submission_id: - containers = Containers() - containers.cleanup_container(submission_id, job_id) + logger.error("Error tearing down job %s: %s", job_id, e) except Exception as e: - logger.error(f"Error cleaning up job {job_id}: {e}") + logger.error("Error cleaning up job %s: %s", job_id, e) finally: self.concurrent_slots.release() self.running_jobs.clear() - # Shutdown Ray if initialized if ray.is_initialized(): ray.shutdown() logger.info("Ray shutdown complete") diff --git a/src/evaluator/providers/metaworld_provider.py b/src/evaluator/providers/metaworld_provider.py new file mode 100644 index 0000000..2d3fee3 --- /dev/null +++ b/src/evaluator/providers/metaworld_provider.py @@ -0,0 +1,177 @@ +""" +Metaworld environment provider implementation. + +Provides MetaWorld MT1/MT10/MT50 and ML10/ML45 benchmarks. +""" + +import logging +import random +from typing import Any, Dict, List, Type + +import gymnasium as gym +from gymnasium import ObservationWrapper +from metaworld.wrappers import OneHotWrapper + +from .metaworld import load_benchmark_definition +from .registry import BenchmarkSpec, EnvironmentProvider, EnvSpec + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_EPISODE_STEPS = 10 +DEFAULT_EPISODES_PER_TASK = 3 +DEFAULT_TASKS_PER_ENV = 5 + + +class MetaworldProvider(EnvironmentProvider): + """ + Metaworld environment provider. + + Provides access to MetaWorld benchmarks: + - MT1, MT10, MT25, MT50 (multi-task) + - ML10, ML25, ML45 (meta-learning) + """ + + @property + def name(self) -> str: + return "metaworld" + + def get_benchmark_specs(self, config: Dict[str, Any]) -> List[BenchmarkSpec]: + """Get benchmark specifications from config.""" + benchmark_name = config.get("benchmark_name", "MT1") + return [ + BenchmarkSpec( + provider=self.name, + benchmark_name=benchmark_name, + config=config, + render_mode=config.get("render_mode", "rgb_array"), + camera_names=tuple(config.get("camera_names", ["corner"])), + camera_attribute=config.get("camera_attribute", "camera_name"), + ) + ] + + def get_env_specs(self, benchmark_spec: BenchmarkSpec) -> List[EnvSpec]: + """Get all MetaWorld test environments and tasks for the specified benchmark.""" + tasks_per_env = int( + benchmark_spec.config.get("tasks_per_env", DEFAULT_TASKS_PER_ENV) + ) + if tasks_per_env < 1: + raise ValueError("tasks_per_env must be >= 1") + + task_seed = benchmark_spec.config.get("task_seed") + if task_seed is None: + task_seed = random.randint(0, 2**31 - 1) + benchmark_spec.config["task_seed"] = task_seed + + env_name_override = benchmark_spec.config.get("env_name") + + benchmark_data = load_benchmark_definition( + benchmark_spec.benchmark_name, + tasks_per_env=tasks_per_env, + env_name=env_name_override, + seed=task_seed, + ) + + if benchmark_spec.benchmark_name in {"MT1", "MT10", "MT25", "MT50"}: + class_lookup = benchmark_data.train_classes + task_source = benchmark_data.train_tasks + elif benchmark_spec.benchmark_name in {"ML10", "ML25", "ML45"}: + class_lookup = benchmark_data.test_classes + task_source = benchmark_data.test_tasks + else: + raise ValueError( + f"Unsupported MetaWorld benchmark: {benchmark_spec.benchmark_name}" + ) + + env_specs: List[EnvSpec] = [] + class_order = tuple(class_lookup.keys()) + + if env_name_override and env_name_override not in class_lookup: + raise ValueError( + f"Environment '{env_name_override}' not found in benchmark {benchmark_spec.benchmark_name}" + ) + + for env_id, env_name in enumerate(class_order): + if env_name_override and env_name != env_name_override: + continue + + env_tasks = [task for task in task_source if task.env_name == env_name] + + for task_idx, task in enumerate(env_tasks): + env_spec = EnvSpec( + env_name=env_name, + benchmark_name=benchmark_spec.benchmark_name, + provider=self.name, + config={ + "task_idx": task_idx, + "task_data": task, + "env_cls": class_lookup[env_name], + "env_id": env_id, + "class_order": class_order, + "task_seed": task_seed, + }, + episodes_per_task=benchmark_spec.config.get( + "episodes_per_task", DEFAULT_EPISODES_PER_TASK + ), + max_episode_steps=benchmark_spec.config.get( + "max_episode_steps", DEFAULT_MAX_EPISODE_STEPS + ), + render_mode=benchmark_spec.render_mode, + camera_attribute=benchmark_spec.camera_attribute, + camera_names=benchmark_spec.camera_names, + ) + env_specs.append(env_spec) + + logger.info( + "Found %d test tasks across all environments for %s", + len(env_specs), + benchmark_spec, + ) + return env_specs + + def make_env( + self, + spec: EnvSpec, + submission_id: str | None = None, + save_images: bool = False, + ) -> gym.Env: + """Create a MetaWorld environment for the specified environment spec.""" + # Import here to avoid circular imports + from ..rollout.envs import MetaworldObsWrapper # noqa: PLC0415 + + config = spec.config + env_cls = config.get("env_cls") + if env_cls is None: + raise ValueError("EnvSpec config missing 'env_cls' for MetaWorld env") + + render_mode = spec.render_mode + env = env_cls(render_mode=render_mode, camera_name="corner") + + task = config.get("task_data") + if task is None: + raise ValueError("EnvSpec config missing 'task_data' for MetaWorld env") + env.set_task(task) + + class_order: tuple[str, ...] | tuple[()] = config.get("class_order", tuple()) + env_id = config.get("env_id") + if class_order and env_id is not None: + env = OneHotWrapper(env, env_id, len(class_order)) + + logger.debug("Applying MetaworldObsWrapper") + env = MetaworldObsWrapper( + env, + camera_attribute=spec.camera_attribute, + camera_names=spec.camera_names, + save_images=save_images, + image_save_dir="data" if save_images else None, + submission_id=submission_id, + ) + + env = gym.wrappers.TimeLimit(env, max_episode_steps=spec.max_episode_steps) + + return env + + def get_observation_wrapper(self) -> Type[ObservationWrapper] | None: + """Get the observation wrapper class for MetaWorld.""" + from ..rollout.envs import MetaworldObsWrapper # noqa: PLC0415 + + return MetaworldObsWrapper diff --git a/src/evaluator/providers/registry.py b/src/evaluator/providers/registry.py new file mode 100644 index 0000000..34c8de5 --- /dev/null +++ b/src/evaluator/providers/registry.py @@ -0,0 +1,265 @@ +""" +Provider registry for Kinitro evaluator. + +Provides a plugin architecture for environment providers (metaworld, swarm, etc.). +Extracted from EnvManager for better separation of concerns. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Protocol, Type, runtime_checkable + +import gymnasium as gym +from gymnasium import ObservationWrapper + +from core.log import get_logger + +logger = get_logger(__name__) + + +@dataclass +class BenchmarkSpec: + """Specification for a benchmark and its environments.""" + + provider: str # "metaworld", "swarm", etc. + benchmark_name: str # "MT1", "MT10", etc. + config: Dict[str, Any] = field(default_factory=dict) + render_mode: str | None = "rgb_array" + camera_names: tuple[str, ...] = ("corner",) + camera_attribute: str | None = "camera_name" + + def __str__(self) -> str: + return f"{self.provider}/{self.benchmark_name}" + + +@dataclass +class EnvSpec: + """Specification for a single environment instance.""" + + env_name: str + benchmark_name: str + provider: str + config: Dict[str, Any] = field(default_factory=dict) + + # Runtime controls + episodes_per_task: int = 3 + max_episode_steps: int = 10 + render_mode: str | None = "rgb_array" + + # Observation capture options + camera_attribute: str | None = "camera_name" + camera_names: tuple[str, ...] = ("corner",) + + def __str__(self) -> str: + return f"{self.provider}/{self.benchmark_name}/{self.env_name}" + + +@runtime_checkable +class EnvironmentProvider(Protocol): + """ + Interface for environment providers. + + Environment providers are responsible for: + - Providing benchmark specifications for a given config + - Creating gymnasium environments from EnvSpec + - Optionally providing observation wrappers + """ + + @property + def name(self) -> str: + """Return the unique name of this provider (e.g., 'metaworld', 'swarm').""" + ... + + def get_benchmark_specs(self, config: Dict[str, Any]) -> List[BenchmarkSpec]: + """ + Get benchmark specifications from a configuration. + + Args: + config: Provider-specific configuration + + Returns: + List of BenchmarkSpec objects + """ + ... + + def get_env_specs(self, benchmark_spec: BenchmarkSpec) -> List[EnvSpec]: + """ + Get environment specifications for a benchmark. + + Args: + benchmark_spec: The benchmark specification + + Returns: + List of EnvSpec objects for environments in this benchmark + """ + ... + + def make_env( + self, + spec: EnvSpec, + submission_id: str | None = None, + save_images: bool = False, + ) -> gym.Env: + """ + Create a gymnasium environment from an EnvSpec. + + Args: + spec: The environment specification + submission_id: Optional submission ID for logging/debugging + save_images: Whether to save rendered images + + Returns: + A gymnasium environment + """ + ... + + def get_observation_wrapper(self) -> Type[ObservationWrapper] | None: + """ + Get the observation wrapper class for this provider. + + Returns: + The observation wrapper class, or None if no wrapper is needed + """ + ... + + +class ProviderRegistry: + """ + Registry for environment providers. + + This class provides: + - Registration of provider implementations + - Lookup of providers by name + - Listing of available providers + """ + + _providers: Dict[str, EnvironmentProvider] = {} + _initialized: bool = False + + @classmethod + def register(cls, provider: EnvironmentProvider) -> None: + """ + Register an environment provider. + + Args: + provider: The provider implementation to register + + Raises: + ValueError: If a provider with this name is already registered + """ + name = provider.name + if name in cls._providers: + logger.warning(f"Provider '{name}' is already registered, overwriting") + cls._providers[name] = provider + logger.info(f"Registered environment provider: {name}") + + @classmethod + def unregister(cls, name: str) -> bool: + """ + Unregister an environment provider. + + Args: + name: The provider name to unregister + + Returns: + True if the provider was unregistered, False if not found + """ + if name in cls._providers: + del cls._providers[name] + logger.info(f"Unregistered environment provider: {name}") + return True + return False + + @classmethod + def get(cls, name: str) -> EnvironmentProvider: + """ + Get a provider by name. + + Args: + name: The provider name + + Returns: + The provider implementation + + Raises: + KeyError: If the provider is not registered + """ + if name not in cls._providers: + available = ", ".join(cls._providers.keys()) or "none" + raise KeyError( + f"Provider '{name}' not found. Available providers: {available}" + ) + return cls._providers[name] + + @classmethod + def get_optional(cls, name: str) -> Optional[EnvironmentProvider]: + """ + Get a provider by name, returning None if not found. + + Args: + name: The provider name + + Returns: + The provider implementation, or None if not found + """ + return cls._providers.get(name) + + @classmethod + def list_providers(cls) -> List[str]: + """ + List all registered provider names. + + Returns: + List of provider names + """ + return list(cls._providers.keys()) + + @classmethod + def has_provider(cls, name: str) -> bool: + """ + Check if a provider is registered. + + Args: + name: The provider name + + Returns: + True if the provider is registered + """ + return name in cls._providers + + @classmethod + def clear(cls) -> None: + """Clear all registered providers (useful for testing).""" + cls._providers.clear() + cls._initialized = False + logger.info("Cleared all environment providers") + + @classmethod + def initialize_default_providers(cls) -> None: + """ + Initialize and register the default providers (metaworld, swarm). + + This should be called once at application startup. + """ + if cls._initialized: + return + + # Import and register default providers + try: + from .metaworld_provider import MetaworldProvider # noqa: PLC0415 + + cls.register(MetaworldProvider()) + except ImportError as e: + logger.warning(f"Failed to load metaworld provider: {e}") + + try: + from .swarm_provider import SwarmProvider # noqa: PLC0415 + + cls.register(SwarmProvider()) + except ImportError as e: + logger.warning(f"Failed to load swarm provider: {e}") + + cls._initialized = True + logger.info( + f"Initialized {len(cls._providers)} default providers: " + f"{', '.join(cls.list_providers())}" + ) diff --git a/src/evaluator/providers/swarm_provider.py b/src/evaluator/providers/swarm_provider.py new file mode 100644 index 0000000..ca7b342 --- /dev/null +++ b/src/evaluator/providers/swarm_provider.py @@ -0,0 +1,166 @@ +""" +Swarm environment provider implementation. + +Provides PyBullet drone simulation environments. +""" + +import logging +import random +from typing import Any, Dict, List, Type + +import gymnasium as gym +from gymnasium import ObservationWrapper + +from . import swarm as swarm_module +from .registry import BenchmarkSpec, EnvironmentProvider, EnvSpec + +logger = logging.getLogger(__name__) + +DEFAULT_EPISODES_PER_TASK = 3 +DEFAULT_TASKS_PER_ENV = 5 + + +class SwarmProvider(EnvironmentProvider): + """ + Swarm PyBullet environment provider. + + Provides access to drone simulation environments using PyBullet. + """ + + @property + def name(self) -> str: + return "swarm" + + def get_benchmark_specs(self, config: Dict[str, Any]) -> List[BenchmarkSpec]: + """Get benchmark specifications from config.""" + benchmark_name = config.get("benchmark_name", "swarm-default") + return [ + BenchmarkSpec( + provider=self.name, + benchmark_name=benchmark_name, + config=config, + render_mode=None, # Swarm doesn't use standard render mode + camera_names=tuple(), + camera_attribute=None, + ) + ] + + def get_env_specs(self, benchmark_spec: BenchmarkSpec) -> List[EnvSpec]: + """Generate MapTask-based environments for the Swarm PyBullet provider.""" + tasks_per_env = int( + benchmark_spec.config.get("tasks_per_env", DEFAULT_TASKS_PER_ENV) + ) + if tasks_per_env < 1: + raise ValueError("tasks_per_env must be >= 1") + + task_seed = benchmark_spec.config.get("task_seed") + if task_seed is not None: + task_seed = int(task_seed) + else: + task_seed = random.randint(0, 2**31 - 1) + benchmark_spec.config["task_seed"] = task_seed + + sim_dt = float(benchmark_spec.config.get("sim_dt", swarm_module.SIM_DT)) + horizon = float(benchmark_spec.config.get("horizon", swarm_module.HORIZON_SEC)) + gui = bool(benchmark_spec.config.get("gui", False)) + env_name = benchmark_spec.config.get("env_name", "swarm-moving-drone") + episodes_per_task = int( + benchmark_spec.config.get("episodes_per_task", DEFAULT_EPISODES_PER_TASK) + ) + max_episode_steps_override = benchmark_spec.config.get("max_episode_steps") + payload_mode = bool(benchmark_spec.config.get("payload_mode", False)) + challenge_type = benchmark_spec.config.get("challenge_type") + + if sim_dt <= 0: + raise ValueError("sim_dt must be > 0 for Swarm provider") + if horizon <= 0: + raise ValueError("horizon must be > 0 for Swarm provider") + if max_episode_steps_override is not None: + if int(max_episode_steps_override) < 1: + raise ValueError("max_episode_steps must be >= 1 for Swarm provider") + + env_specs: List[EnvSpec] = [] + for task_idx in range(tasks_per_env): + seed = task_seed + task_idx if task_seed is not None else None + task = swarm_module.random_task( + sim_dt, + horizon, + seed=seed, + payload=payload_mode, + challenge_type=challenge_type, + ) + + max_episode_steps = ( + int(max_episode_steps_override) + if max_episode_steps_override is not None + else int(max(1, round(task.horizon / task.sim_dt))) + ) + + env_specs.append( + EnvSpec( + env_name=env_name, + benchmark_name=benchmark_spec.benchmark_name, + provider=self.name, + config={ + "task": task, + "task_idx": task_idx, + "task_seed": seed, + "gui": gui, + "payload_mode": payload_mode, + "challenge_type": challenge_type, + }, + episodes_per_task=episodes_per_task, + max_episode_steps=max_episode_steps, + render_mode=None, + camera_attribute=None, + camera_names=tuple(), + ) + ) + + logger.info( + "Generated %d Swarm tasks (sim_dt=%.4f, horizon=%.2f) for %s", + len(env_specs), + sim_dt, + horizon, + benchmark_spec, + ) + return env_specs + + def make_env( + self, + spec: EnvSpec, + submission_id: str | None = None, + save_images: bool = False, + ) -> gym.Env: + """Create a Swarm PyBullet environment from the provided task config.""" + # Import here to avoid circular imports + from ..rollout.envs import DroneObsWrapper # noqa: PLC0415 + + config = spec.config + task = config.get("task") + seed = config.get("task_seed") + challenge_type = config.get("challenge_type") + + if task is None: + task = swarm_module.random_task( + sim_dt=swarm_module.SIM_DT, + horizon=swarm_module.HORIZON_SEC, + seed=seed, + payload=bool(config.get("payload_mode", False)), + challenge_type=challenge_type, + ) + + gui = bool(config.get("gui", False)) + env = swarm_module.make_env(task, gui=gui) + env = DroneObsWrapper(env) + + if spec.max_episode_steps: + env = gym.wrappers.TimeLimit(env, max_episode_steps=spec.max_episode_steps) + + return env + + def get_observation_wrapper(self) -> Type[ObservationWrapper] | None: + """Get the observation wrapper class for Swarm.""" + from ..rollout.envs import DroneObsWrapper # noqa: PLC0415 + + return DroneObsWrapper diff --git a/src/evaluator/providers/tests/__init__.py b/src/evaluator/providers/tests/__init__.py new file mode 100644 index 0000000..063e680 --- /dev/null +++ b/src/evaluator/providers/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for provider registry and providers.""" diff --git a/src/evaluator/providers/tests/test_registry.py b/src/evaluator/providers/tests/test_registry.py new file mode 100644 index 0000000..3ced958 --- /dev/null +++ b/src/evaluator/providers/tests/test_registry.py @@ -0,0 +1,178 @@ +"""Tests for ProviderRegistry component.""" + +from unittest.mock import MagicMock + +import pytest + +from evaluator.providers.registry import ( + BenchmarkSpec, + EnvSpec, + ProviderRegistry, +) + + +class MockProvider: + """Mock provider for testing.""" + + def __init__(self, name: str = "mock"): + self._name = name + + @property + def name(self) -> str: + return self._name + + def get_benchmark_specs(self, config): + return [BenchmarkSpec(provider=self._name, benchmark_name="test")] + + def get_env_specs(self, benchmark_spec): + return [ + EnvSpec(env_name="test_env", benchmark_name="test", provider=self._name) + ] + + def make_env(self, spec, submission_id=None, save_images=False): + return MagicMock() + + def get_observation_wrapper(self): + return None + + +class TestProviderRegistry: + """Tests for ProviderRegistry.""" + + def setup_method(self): + """Clear registry before each test.""" + ProviderRegistry.clear() + + def teardown_method(self): + """Clean up after each test.""" + ProviderRegistry.clear() + + def test_register_provider(self): + """Test registering a provider.""" + provider = MockProvider("test_provider") + ProviderRegistry.register(provider) + + assert ProviderRegistry.has_provider("test_provider") + assert "test_provider" in ProviderRegistry.list_providers() + + def test_register_duplicate_overwrites(self): + """Test that registering a duplicate provider overwrites.""" + provider1 = MockProvider("test") + provider2 = MockProvider("test") + + ProviderRegistry.register(provider1) + ProviderRegistry.register(provider2) + + assert ProviderRegistry.get("test") == provider2 + + def test_get_provider(self): + """Test getting a provider by name.""" + provider = MockProvider("test") + ProviderRegistry.register(provider) + + result = ProviderRegistry.get("test") + assert result == provider + + def test_get_nonexistent_provider(self): + """Test that getting a nonexistent provider raises KeyError.""" + with pytest.raises(KeyError) as excinfo: + ProviderRegistry.get("nonexistent") + + assert "nonexistent" in str(excinfo.value) + + def test_get_optional_provider(self): + """Test getting a provider optionally.""" + provider = MockProvider("test") + ProviderRegistry.register(provider) + + assert ProviderRegistry.get_optional("test") == provider + assert ProviderRegistry.get_optional("nonexistent") is None + + def test_unregister_provider(self): + """Test unregistering a provider.""" + provider = MockProvider("test") + ProviderRegistry.register(provider) + + result = ProviderRegistry.unregister("test") + + assert result is True + assert not ProviderRegistry.has_provider("test") + + def test_unregister_nonexistent_provider(self): + """Test unregistering a provider that doesn't exist.""" + result = ProviderRegistry.unregister("nonexistent") + assert result is False + + def test_list_providers(self): + """Test listing all registered providers.""" + provider1 = MockProvider("provider1") + provider2 = MockProvider("provider2") + + ProviderRegistry.register(provider1) + ProviderRegistry.register(provider2) + + providers = ProviderRegistry.list_providers() + assert len(providers) == 2 + assert "provider1" in providers + assert "provider2" in providers + + def test_has_provider(self): + """Test checking if a provider exists.""" + provider = MockProvider("test") + ProviderRegistry.register(provider) + + assert ProviderRegistry.has_provider("test") is True + assert ProviderRegistry.has_provider("nonexistent") is False + + def test_clear_providers(self): + """Test clearing all providers.""" + provider = MockProvider("test") + ProviderRegistry.register(provider) + + ProviderRegistry.clear() + + assert len(ProviderRegistry.list_providers()) == 0 + assert not ProviderRegistry.has_provider("test") + + +class TestBenchmarkSpec: + """Tests for BenchmarkSpec dataclass.""" + + def test_default_values(self): + """Test default values for BenchmarkSpec.""" + spec = BenchmarkSpec(provider="test", benchmark_name="bench1") + + assert spec.provider == "test" + assert spec.benchmark_name == "bench1" + assert spec.config == {} + assert spec.render_mode == "rgb_array" + assert spec.camera_names == ("corner",) + assert spec.camera_attribute == "camera_name" + + def test_str_representation(self): + """Test string representation.""" + spec = BenchmarkSpec(provider="metaworld", benchmark_name="MT10") + assert str(spec) == "metaworld/MT10" + + +class TestEnvSpec: + """Tests for EnvSpec dataclass.""" + + def test_default_values(self): + """Test default values for EnvSpec.""" + spec = EnvSpec(env_name="test_env", benchmark_name="bench1", provider="test") + + assert spec.env_name == "test_env" + assert spec.benchmark_name == "bench1" + assert spec.provider == "test" + assert spec.config == {} + assert spec.episodes_per_task == 3 + assert spec.max_episode_steps == 10 + assert spec.render_mode == "rgb_array" + + def test_str_representation(self): + """Test string representation.""" + spec = EnvSpec( + env_name="door-open-v2", benchmark_name="MT10", provider="metaworld" + ) + assert str(spec) == "metaworld/MT10/door-open-v2" diff --git a/src/evaluator/rollout/__init__.py b/src/evaluator/rollout/__init__.py index e40bbc0..18ee755 100644 --- a/src/evaluator/rollout/__init__.py +++ b/src/evaluator/rollout/__init__.py @@ -60,7 +60,6 @@ def create_worker( s3_config: Optional[S3Config] = None, episode_log_interval: int = 1, step_log_interval: int = 1, - database_url: Optional[str] = None, ) -> ray.actor.ActorHandle: logger.info( f"Creating worker: {rollout_worker_id}, {benchmark_specs}, {submission_container_host}, {submission_container_port}, {submission_id}" @@ -75,7 +74,6 @@ def create_worker( s3_config, episode_log_interval, step_log_interval, - database_url, ) self.workers.append(worker) return worker @@ -117,7 +115,6 @@ def __init__( s3_config: Optional[S3Config] = None, episode_log_interval: int = 1, step_log_interval: int = 1, - database_url: Optional[str] = None, ) -> None: logger.info( f"RolloutWorker init: {cluster_name}, {rollout_worker_id}, {benchmark_specs}, " @@ -140,7 +137,6 @@ def __init__( self.episode_log_interval = episode_log_interval self.step_log_interval = step_log_interval self.s3_config = s3_config - self.database_url = database_url self.episode_loggers: Dict[str, EpisodeLogger] = {} self._global_episode_counter = 1 @@ -299,7 +295,6 @@ async def run_env( step_log_interval=self.step_log_interval, enable_s3_upload=self.s3_config is not None, s3_config=self.s3_config, - database_url=self.database_url, ) # Generate a unique task ID combining benchmark and env names diff --git a/src/evaluator/rollout/episode_logger.py b/src/evaluator/rollout/episode_logger.py index 9f13ddd..17ef2f8 100644 --- a/src/evaluator/rollout/episode_logger.py +++ b/src/evaluator/rollout/episode_logger.py @@ -1,55 +1,26 @@ """ -Episode and step logging system with S3 storage integration. +Episode and step logging system. -This module provides utilities for logging episode data and step-level -observations to both database and S3 storage with configurable intervals. +This module provides utilities for tracking episode data during evaluation. +Real-time streaming of episode/step data and S3 uploads are not yet implemented +- only final aggregated results are sent after evaluation completes. """ -import asyncio import logging -from concurrent.futures import ( - ThreadPoolExecutor, - wait, -) -from concurrent.futures import ( - TimeoutError as FutureTimeoutError, -) from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -import asyncpg import numpy as np -from pgqueuer import Queries -from pgqueuer.db import AsyncpgDriver from core.constants import ImageFormat -from core.messages import EpisodeDataMessage, EpisodeStepDataMessage -from core.storage import S3Config, S3StorageClient +from core.storage import S3Config from .worker_utils import extract_success_flag logger = logging.getLogger(__name__) -# Upload configuration -OBS_UPLOAD_CLEANUP_TIMEOUT = 5.0 -OBS_UPLOAD_TIMEOUT = 20.0 -MAX_UPLOAD_WORKERS = 4 - -# Database connection pool configuration -DB_POOL_MIN_SIZE = 1 -DB_POOL_MAX_SIZE = 5 -DB_POOL_MAX_QUERIES = 50000 -DB_POOL_MAX_INACTIVE_LIFETIME = 300 # seconds -DB_POOL_COMMAND_TIMEOUT = 60 # seconds - -# Retry configuration -DB_CONNECTION_MAX_RETRIES = 3 -DB_CONNECTION_INITIAL_RETRY_DELAY = 1.0 # seconds -DB_CONNECTION_RETRY_BACKOFF = 2.0 # exponential backoff multiplier -DB_QUEUE_INITIAL_RETRY_DELAY = 0.5 # seconds - @dataclass class LoggingConfig: @@ -59,24 +30,27 @@ class LoggingConfig: episode_log_interval: int = 1 # Log every N episodes step_log_interval: int = 10 # Log every N steps within an episode - # Storage configuration - enable_s3_upload: bool = True + # Storage configuration (not currently used) + enable_s3_upload: bool = False s3_config: Optional[S3Config] = None - # Local storage fallback + # Local storage fallback (not currently used) local_save_dir: Optional[Path] = None - # Image settings + # Image settings (not currently used) image_format: ImageFormat = ImageFormat.PNG image_quality: int = 95 # For JPEG - # Database URL for pgqueuer - database_url: Optional[str] = None - @dataclass class EpisodeLogger: - """Logger for episode and step data with S3 integration.""" + """Logger for episode and step data. + + This logger tracks episode progress for aggregation into final results. + Real-time streaming of episode/step data to the backend and S3 uploads + are not implemented yet - only final aggregated results are sent after + evaluation completes. + """ config: LoggingConfig submission_id: str @@ -92,34 +66,6 @@ class EpisodeLogger: default_factory=list, init=False ) _current_episode_start: Optional[datetime] = field(default=None, init=False) - _storage_client: Optional[S3StorageClient] = field(default=None, init=False) - - # Background upload system - _upload_executor: Optional[ThreadPoolExecutor] = field(default=None, init=False) - _upload_futures: List[Any] = field(default_factory=list, init=False) - _pending_uploads: Dict[str, Any] = field(default_factory=dict, init=False) - - # Database connection pool - _db_pool: Optional[asyncpg.Pool] = field(default=None, init=False) - - def __post_init__(self): - """Initialize storage client and upload executor if S3 is enabled.""" - if self.config.enable_s3_upload and self.config.s3_config: - try: - self._storage_client = S3StorageClient(self.config.s3_config) - # Create thread pool for background uploads - self._upload_executor = ThreadPoolExecutor( - max_workers=MAX_UPLOAD_WORKERS, thread_name_prefix="s3_upload" - ) - logger.info("S3 storage client initialized with background upload pool") - except Exception as e: - logger.error(f"Failed to initialize S3 storage: {e}") - logger.warning("Falling back to local storage only") - self._storage_client = None - self._upload_executor = None - - # Initialize database pool asynchronously when needed - # Pool will be created on first use def start_episode(self, episode_id: int) -> None: """Start tracking a new episode. @@ -164,7 +110,7 @@ async def log_step( reward: Reward received done: Whether episode terminated truncated: Whether episode was truncated - observations: List of (image, camera_name) tuples + observations: List of (image, camera_name) tuples (not currently used) info: Additional info from environment """ logger.info("Logging step data") @@ -172,52 +118,17 @@ async def log_step( logger.warning("Attempted to log step without active episode") return - # Check if we should log this step based on interval - should_log_step = ( - (step % self.config.step_log_interval == 0) or done or truncated - ) - - if not should_log_step: - # Still accumulate basic data for episode summary - self._current_episode_steps.append( - { - "step": step, - "reward": reward, - "done": done, - "truncated": truncated, - "should_log": False, # Mark as not to be logged - } - ) - return - + # Accumulate step data for episode summary step_data = { "step": step, - "action": self._serialize_action(action), "reward": reward, "done": done, "truncated": truncated, - "timestamp": datetime.now(timezone.utc), "info": info or {}, } - # Schedule background upload of observations if available - if observations and self._storage_client: - upload_key = self._upload_observations_async( - observations, self._current_episode_id, step - ) - # Store placeholder for now, will be resolved before sending to backend - step_data["upload_key"] = upload_key - step_data["observation_refs"] = {} # Will be filled in later - else: - step_data["observation_refs"] = {} - step_data["upload_key"] = None - - # Mark step for later queuing if it should be logged - step_data["should_log"] = should_log_step self._current_episode_steps.append(step_data) - # Don't queue immediately - queue at end of episode to avoid race condition - async def end_episode( self, final_reward: float, @@ -235,471 +146,23 @@ async def end_episode( logger.warning("Attempted to end episode without active episode") return - # Coerce success to bool and fall back to logged step info if needed - final_success = bool(success) or self._infer_success_from_steps() - - episode_data = { - "job_id": self.job_id, - "submission_id": self.submission_id, - "task_id": self.task_id, - "episode_id": self._current_episode_id, - "env_name": self.env_name, - "benchmark_name": self.benchmark_name, - "final_reward": final_reward, - "success": final_success, - "steps": len(self._current_episode_steps), - "start_time": self._current_episode_start, - "end_time": datetime.now(timezone.utc), - "extra_metrics": extra_metrics or {}, - } - - # Wait for all background uploads to complete before sending data - self._wait_for_uploads(timeout=OBS_UPLOAD_TIMEOUT) - - # Queue episode to database if configured - if self.config.database_url: - logger.info( - "Queuing episode summary submission=%s task=%s episode=%s job=%s", - self.submission_id, - self.task_id, - self._current_episode_id, - self.job_id, - ) - enqueued = await self._queue_episode_data(episode_data) - if not enqueued: - logger.error( - "Failed to queue episode summary submission=%s task=%s episode=%s job=%s", - self.submission_id, - self.task_id, - self._current_episode_id, - self.job_id, - ) - - # Queue all step data after episode is queued to avoid race condition - for step_data in self._current_episode_steps: - if step_data.get("should_log", False): - await self._queue_step_data(step_data) + # Log episode completion (data is aggregated in final results, not streamed) + logger.debug( + "Episode %s completed: reward=%.3f success=%s steps=%d", + self._current_episode_id, + final_reward, + success, + len(self._current_episode_steps), + ) # Reset for next episode self._current_episode_id = None self._current_episode_steps = [] self._current_episode_start = None - def _serialize_action(self, action: Any) -> Dict: - """Serialize action to JSON-compatible format. - - Args: - action: Action in various formats - - Returns: - JSON-serializable dictionary - """ - if isinstance(action, np.ndarray): - return {"type": "array", "value": action.tolist(), "shape": action.shape} - elif isinstance(action, (list, tuple)): - return {"type": "list", "value": list(action)} - elif isinstance(action, dict): - return {"type": "dict", "value": action} - else: - return {"type": "scalar", "value": action} - - def _to_serializable(self, value: Any) -> Any: - """Recursively normalize common numpy/pydantic types for JSON.""" - - if isinstance(value, (str, int, float, bool)) or value is None: - return value - if isinstance(value, np.generic): - return value.item() - if isinstance(value, np.ndarray): - return value.tolist() - if isinstance(value, datetime): - return value.isoformat() - if isinstance(value, dict): - return {str(key): self._to_serializable(val) for key, val in value.items()} - if isinstance(value, (list, tuple, set)): - return [self._to_serializable(item) for item in value] - if hasattr(value, "model_dump"): - return self._to_serializable(value.model_dump()) - if hasattr(value, "__dict__"): - return self._to_serializable(vars(value)) - return str(value) - - def _upload_observations_sync( - self, - observations: List[Tuple[np.ndarray, str]], - episode_id: int, - step: int, - ) -> Dict[str, Dict[str, str]]: - """Synchronously upload observations to S3. - - This is called in a background thread by _upload_observations_async. - - Args: - observations: List of (image, camera_name) tuples - episode_id: Episode identifier - step: Step number - - Returns: - Dictionary mapping camera names to storage references - """ - refs = {} - - for image, camera_name in observations: - try: - result = self._storage_client.upload_observation_image( - image=image, - submission_id=self.submission_id, - task_id=self.task_id, - episode_id=episode_id, - step=step, - camera_name=camera_name, - fmt=self.config.image_format, - ) - - refs[camera_name] = { - "bucket": result["bucket"], - "key": result["key"], - "url": result["url"], - } - except Exception as e: - logger.error( - f"Failed to upload observation {camera_name} for step {step}: {e}" - ) - # Continue with other uploads even if one fails - - return refs - - def _upload_observations_async( - self, - observations: List[Tuple[np.ndarray, str]], - episode_id: int, - step: int, - ) -> str: - """Schedule observation uploads in background and return a placeholder key. - - Args: - observations: List of (image, camera_name) tuples - episode_id: Episode identifier - step: Step number - - Returns: - Placeholder key for the upload future - """ - if not self._upload_executor: - # Fallback to empty refs if executor not available - return None - - # Create a unique key for this upload - upload_key = f"{episode_id}_{step}" - - # Submit upload task to background thread pool - future = self._upload_executor.submit( - self._upload_observations_sync, observations, episode_id, step - ) - - # Store the future so we can retrieve results later - self._upload_futures.append(future) - self._pending_uploads[upload_key] = future - - logger.debug( - f"Scheduled background upload for episode {episode_id}, step {step}" - ) - - return upload_key - - def _wait_for_uploads(self, timeout: float = 30.0) -> None: - """Wait for all pending uploads to complete and update step data with results. - - Args: - timeout: Maximum time to wait for uploads in seconds - """ - if not self._pending_uploads: - return - - logger.info(f"Waiting for {len(self._pending_uploads)} uploads to complete") - - # Wait for all uploads to complete with timeout - try: - # Get all pending futures - futures = list(self._pending_uploads.values()) - - # Wait for completion with timeout - done, not_done = wait(futures, timeout=timeout) - - if not_done: - logger.warning( - f"{len(not_done)} uploads did not complete within {timeout}s timeout" - ) - - # Update step data with upload results - for step_data in self._current_episode_steps: - upload_key = step_data.get("upload_key") - if upload_key and upload_key in self._pending_uploads: - future = self._pending_uploads[upload_key] - if future.done(): - try: - # Get the upload results - refs = future.result(timeout=0.1) - step_data["observation_refs"] = refs - logger.debug(f"Retrieved upload results for {upload_key}") - except Exception as e: - logger.error( - f"Failed to get upload results for {upload_key}: {e}" - ) - step_data["observation_refs"] = {} - else: - logger.warning( - f"Upload {upload_key} not completed, using empty refs" - ) - step_data["observation_refs"] = {} - - # Clear pending uploads - self._pending_uploads.clear() - - except FutureTimeoutError: - logger.error(f"Upload wait timed out after {timeout}s") - # Set empty refs for all incomplete uploads - for step_data in self._current_episode_steps: - if step_data.get("upload_key") and not step_data.get( - "observation_refs" - ): - step_data["observation_refs"] = {} - - async def _ensure_db_pool(self) -> Optional[asyncpg.Pool]: - """Ensure database pool is created with retry logic. - - Returns: - Database connection pool or None if creation fails - """ - if self._db_pool is None and self.config.database_url: - retry_delay = DB_CONNECTION_INITIAL_RETRY_DELAY - - for attempt in range(DB_CONNECTION_MAX_RETRIES): - try: - self._db_pool = await asyncpg.create_pool( - dsn=self.config.database_url, - min_size=DB_POOL_MIN_SIZE, - max_size=DB_POOL_MAX_SIZE, - max_queries=DB_POOL_MAX_QUERIES, - max_inactive_connection_lifetime=DB_POOL_MAX_INACTIVE_LIFETIME, - command_timeout=DB_POOL_COMMAND_TIMEOUT, - server_settings={ - "jit": "off" # Disable JIT to avoid potential issues - }, - ) - logger.info("Database connection pool created successfully") - return self._db_pool - except Exception as e: - logger.warning( - f"Failed to create database pool (attempt {attempt + 1}/{DB_CONNECTION_MAX_RETRIES}): {e}" - ) - if attempt < DB_CONNECTION_MAX_RETRIES - 1: - await asyncio.sleep(retry_delay) - retry_delay *= DB_CONNECTION_RETRY_BACKOFF - else: - logger.error("Failed to create database pool after all retries") - return None - - # Check if existing pool is healthy - if self._db_pool: - try: - async with self._db_pool.acquire() as conn: - await conn.fetchval("SELECT 1") - return self._db_pool - except Exception as e: - logger.warning( - f"Database pool health check failed: {e}. Recreating pool..." - ) - await self._close_db_pool() - self._db_pool = None - return await self._ensure_db_pool() - - return self._db_pool - - async def _close_db_pool(self) -> None: - """Close the database connection pool.""" - if self._db_pool: - try: - await self._db_pool.close() - logger.info("Database connection pool closed") - except Exception as e: - logger.error(f"Error closing database pool: {e}") - finally: - self._db_pool = None - - async def _queue_with_retry( - self, queue_name: str, message_json: str, max_retries: int = None - ) -> bool: - """Queue a message with retry logic. - - Args: - queue_name: Name of the queue - message_json: JSON message to queue - max_retries: Maximum number of retry attempts (defaults to DB_CONNECTION_MAX_RETRIES) - - Returns: - True if queued successfully, False otherwise - """ - if max_retries is None: - max_retries = DB_CONNECTION_MAX_RETRIES - - pool = await self._ensure_db_pool() - if not pool: - logger.error("No database pool available for queuing") - return False - - retry_delay = DB_QUEUE_INITIAL_RETRY_DELAY - - for attempt in range(max_retries): - try: - async with pool.acquire() as conn: - driver = AsyncpgDriver(conn) - q = Queries(driver) - await q.enqueue([queue_name], [message_json.encode("utf-8")], [0]) - return True - - except asyncpg.exceptions.InterfaceError as e: - logger.warning( - f"Connection error on attempt {attempt + 1}/{max_retries}: {e}" - ) - # Connection is closed, recreate pool - await self._close_db_pool() - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - retry_delay *= DB_CONNECTION_RETRY_BACKOFF - - except Exception as e: - logger.error( - f"Unexpected error queuing message (attempt {attempt + 1}/{max_retries}): {e}" - ) - if attempt < max_retries - 1: - await asyncio.sleep(retry_delay) - retry_delay *= DB_CONNECTION_RETRY_BACKOFF - else: - return False - - return False - - async def _queue_episode_data(self, episode_data: Dict[str, Any]) -> bool: - """Queue episode data to be sent to backend via pgqueuer. - - Args: - episode_data: Episode data dictionary - """ - try: - # Convert datetime objects to ISO format for JSON serialization - episode_data_copy = episode_data.copy() - episode_data_copy["start_time"] = episode_data_copy[ - "start_time" - ].isoformat() - episode_data_copy["end_time"] = episode_data_copy["end_time"].isoformat() - - episode_data_copy = { - key: self._to_serializable(value) - for key, value in episode_data_copy.items() - } - - # Create message - episode_msg = EpisodeDataMessage(**episode_data_copy) - message_json = episode_msg.model_dump_json() - - # Queue with retry logic - success = await self._queue_with_retry("episode_data", message_json) - - if success: - logger.info( - "Queued episode summary submission=%s task=%s episode=%s job=%s", - episode_data["submission_id"], - episode_data["task_id"], - episode_data["episode_id"], - episode_data["job_id"], - ) - return True - - logger.error( - "Failed to queue episode %s after all retries", - episode_data["episode_id"], - ) - return False - - except Exception as e: - logger.error(f"Failed to prepare episode data for queuing: {e}") - return False - - async def _queue_step_data(self, step_data: Dict[str, Any]) -> None: - """Queue step data to be sent to backend via pgqueuer. - - Args: - step_data: Step data dictionary - """ - try: - # Convert datetime to ISO format - step_data_copy = step_data.copy() - - # Remove internal fields that shouldn't be sent - step_data_copy.pop("upload_key", None) - step_data_copy.pop("should_log", None) - - # Only process if this step data has full information (not just basic summary) - if "timestamp" not in step_data_copy: - logger.error("Step data missing timestamp - skipping queue") - return - - step_data_copy["step_timestamp"] = step_data_copy["timestamp"].isoformat() - step_data_copy["submission_id"] = self.submission_id - step_data_copy["task_id"] = self.task_id - step_data_copy["episode_id"] = self._current_episode_id - step_data_copy["job_id"] = self.job_id - step_data_copy["env_name"] = self.env_name - step_data_copy["benchmark_name"] = self.benchmark_name - - # Remove the original timestamp key since we renamed it - del step_data_copy["timestamp"] - - step_data_copy = { - key: self._to_serializable(value) - for key, value in step_data_copy.items() - } - - # Create message - step_msg = EpisodeStepDataMessage(**step_data_copy) - message_json = step_msg.model_dump_json() - - # Queue with retry logic - success = await self._queue_with_retry("episode_step_data", message_json) - - if success: - logger.debug(f"Queued step {step_data['step']} for backend") - else: - logger.error( - f"Failed to queue step {step_data['step']} after all retries" - ) - - except Exception as e: - logger.error(f"Failed to prepare step data for queuing: {e}") - def cleanup(self) -> None: - """Clean up resources, shutdown upload executor.""" - if self._upload_executor: - # Wait for any remaining uploads - self._wait_for_uploads(timeout=OBS_UPLOAD_CLEANUP_TIMEOUT) - - # Shutdown the executor - self._upload_executor.shutdown(wait=True, cancel_futures=True) - self._upload_executor = None - logger.info("Upload executor shutdown complete") - - # Close database pool - if self._db_pool: - # Create a new event loop if needed for async cleanup - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - loop.run_until_complete(self._close_db_pool()) - except Exception as e: - logger.error(f"Error during database pool cleanup: {e}") + """Clean up resources (no-op for now).""" + pass def get_statistics(self) -> Dict[str, Any]: """Get current logging statistics. @@ -711,5 +174,5 @@ def get_statistics(self) -> Dict[str, Any]: "episodes_tracked": self._episode_count, "current_episode_id": self._current_episode_id, "current_episode_steps": len(self._current_episode_steps), - "storage_enabled": self._storage_client is not None, + "storage_enabled": False, } diff --git a/src/evaluator/test.py b/src/evaluator/test.py deleted file mode 100644 index d3bad61..0000000 --- a/src/evaluator/test.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -from datetime import datetime - -import asyncpg -from pgqueuer import AsyncpgDriver, Queries -from snowflake import SnowflakeGenerator - -from validator.db.models import EvaluationJob, EvaluationStatus - - -async def main(): - gen = SnowflakeGenerator(42) - job_id = next(gen) - sub_id = next(gen) - job = EvaluationJob( - created_at=datetime.now(), - updated_at=datetime.now(), - status=EvaluationStatus.QUEUED, - submission_id=sub_id, # type: ignore - miner_hotkey="5CyY97KCfwRC5UZN58A1cLpZnMgSZAKWtqaaggUfzYiJ6B8d", - hf_repo_id="rishiad/default_submission", - env_provider="metaworld", - env_name="MT10", - id=job_id, # type: ignore - container_id=None, - ray_worker_id=None, - retry_count=0, - max_retries=3, - random_seed=None, - eval_start=None, - eval_end=None, - ) - - conn = await asyncpg.connect( - dsn="postgresql://myuser:mypassword@localhost:5432/kinitrodb" - ) - driver = AsyncpgDriver(conn) - q = Queries(driver) - job_bytes = job.to_bytes() - print(f"Enqueuing job: {job!r}") - await q.enqueue(["add_job"], [job_bytes], [0]) - print("Job enqueued successfully.") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/validator/__init__.py b/src/validator/__init__.py index 176dac6..8ba2b1a 100644 --- a/src/validator/__init__.py +++ b/src/validator/__init__.py @@ -1,3 +1,10 @@ """ -Validator +Kinitro Validator - Polling-based weight setter for Bittensor. + +The validator periodically fetches weights from the backend and sets them on chain. """ + +from .config import ValidatorConfig +from .lite_validator import LiteValidator + +__all__ = ["ValidatorConfig", "LiteValidator"] diff --git a/src/validator/__main__.py b/src/validator/__main__.py index 82a587d..6a3e85a 100644 --- a/src/validator/__main__.py +++ b/src/validator/__main__.py @@ -1,5 +1,5 @@ """ -Main entry point for WebSocket-based validator. +Main entry point for the polling-based lite validator. """ import asyncio @@ -8,24 +8,19 @@ from core.log import get_logger -from .config import ValidatorConfig, ValidatorMode +from .config import ValidatorConfig from .lite_validator import LiteValidator -from .websocket_validator import WebSocketValidator logger = get_logger(__name__) class ValidatorService: - """Service wrapper for the WebSocket validator.""" + """Service wrapper for the lite validator.""" def __init__(self): self.config = ValidatorConfig() self.validator = None self._shutdown_event = asyncio.Event() - mode_value = self.config.settings.get( - "validator_mode", ValidatorMode.FULL.value - ) - self.mode = ValidatorMode(mode_value) def setup_signal_handlers(self): """Setup signal handlers for graceful shutdown.""" @@ -40,15 +35,9 @@ def signal_handler(signum, frame): async def run(self): """Run the validator service.""" try: - logger.info("Starting Kinitro Validator Service (mode=%s)", self.mode.value) - - match self.mode: - case ValidatorMode.LITE: - self.validator = LiteValidator(self.config) - case ValidatorMode.FULL: - self.validator = WebSocketValidator(self.config) - case _: - raise ValueError(f"Unknown validator mode: {self.mode}") + logger.info("Starting Kinitro Lite Validator") + + self.validator = LiteValidator(self.config) # Start validator in background validator_task = asyncio.create_task(self.validator.start()) diff --git a/src/validator/config.py b/src/validator/config.py index 76e2cd7..a1be0a8 100644 --- a/src/validator/config.py +++ b/src/validator/config.py @@ -1,14 +1,7 @@ -from enum import Enum - from core.config import Config, ConfigOpts from core.constants import NeuronType -class ValidatorMode(str, Enum): - FULL = "full" - LITE = "lite" - - class ValidatorConfig(Config): def __init__(self): opts = ConfigOpts( @@ -17,93 +10,49 @@ def __init__(self): settings_files=["validator.toml"], ) super().__init__(opts) - self._normalize_validator_mode() def add_args(self): """Add command line arguments""" super().add_args() - # pg database - self._parser.add_argument( - "--pg-database", - type=str, - help="PostgreSQL database URL", - default=self.settings.get( - "pg_database", "postgresql://user:password@localhost/dbname" - ), # type: ignore - ) - self._parser.add_argument( "--backend-url", type=str, - help="Backend WebSocket URL for validator connections", - default=self.settings.get( - "backend_url", "ws://localhost:8080/ws/validator" - ), + help="Backend HTTP URL for fetching weights", + default=self.settings.get("backend_url", "http://localhost:8080"), ) self._parser.add_argument( - "--reconnect-interval", + "--weight-poll-interval", type=int, - help="Seconds to wait before reconnecting to backend", - default=self.settings.get("reconnect_interval", 5), - ) - - self._parser.add_argument( - "--heartbeat-interval", - type=int, - help="Seconds between heartbeat messages to backend", - default=self.settings.get("heartbeat_interval", 30), - ) - - self._parser.add_argument( - "--validator-mode", - type=str, - choices=tuple(mode.value for mode in ValidatorMode), - help="Validator service mode", - default=self.settings.get("validator_mode", ValidatorMode.FULL.value), + help="Seconds between weight polling cycles", + default=self.settings.get("weight_poll_interval", 300), ) self._parser.add_argument( "--weights-url", type=str, - help="HTTP endpoint that exposes weight snapshots", + help="HTTP endpoint that exposes weight snapshots (deprecated, use backend-url)", default=self.settings.get("weights_url", "https://api.kinitro.ai/weights"), ) self._parser.add_argument( "--weights-poll-interval", type=float, - help="Seconds between weight snapshot polls in lite mode", - default=self.settings.get("weights_poll_interval", 30.0), + help="Seconds between weight snapshot polls (deprecated, use weight-poll-interval)", + default=self.settings.get("weights_poll_interval", 300.0), ) self._parser.add_argument( "--weights-request-timeout", type=float, - help="Timeout (seconds) for weight snapshot HTTP requests in lite mode", - default=self.settings.get("weights_request_timeout", 10.0), + help="Timeout (seconds) for weight snapshot HTTP requests", + default=self.settings.get("weights_request_timeout", 30.0), ) self._parser.add_argument( "--weights-stale-threshold", type=float, help="Warn if backend weight snapshot is older than this many seconds", - default=self.settings.get("weights_stale_threshold", 180.0), + default=self.settings.get("weights_stale_threshold", 900.0), ) - - def _normalize_validator_mode(self) -> None: - """Ensure validator_mode is set to a supported value.""" - - raw_mode = self.settings.get("validator_mode", ValidatorMode.FULL.value) - if isinstance(raw_mode, ValidatorMode): - normalized = raw_mode.value - else: - normalized = str(raw_mode).lower() - - try: - ValidatorMode(normalized) - except ValueError as exc: - raise ValueError(f"Invalid validator_mode '{raw_mode}'") from exc - - self.settings["validator_mode"] = normalized diff --git a/src/validator/lite_validator.py b/src/validator/lite_validator.py index 2a317fc..cfbc62f 100644 --- a/src/validator/lite_validator.py +++ b/src/validator/lite_validator.py @@ -1,17 +1,19 @@ """ -Lite validator that polls the backend HTTP weights endpoint and writes weights on-chain. -""" +Lite validator for Kinitro. -from __future__ import annotations +This validator periodically fetches weights from the backend and sets them +on the Bittensor chain. This is a simple polling approach that avoids the +complexity of maintaining WebSocket connections. +""" import asyncio -from datetime import datetime, timezone -from typing import Optional +from typing import Dict, Optional import httpx -from pydantic import ValidationError +from fiber.chain.fetch_nodes import _get_nodes_for_uid +from fiber.chain.models import Node -from backend.models import WeightsSnapshot +from backend.models import SS58Address from core.chain import set_node_weights from core.log import get_logger from core.neuron import Neuron @@ -23,274 +25,189 @@ class LiteValidator(Neuron): """ - Minimal validator implementation that periodically polls the backend weights endpoint - and calls `set_node_weights` on the Bittensor chain. + Polling-based validator that fetches weights and sets them on chain. + + This validator: + 1. Periodically polls the backend /weights endpoint + 2. Sets the received weights on the Bittensor chain + + Much simpler than maintaining a WebSocket connection. """ def __init__(self, config: ValidatorConfig): super().__init__(config) self.hotkey = self.keypair.ss58_address - self.weights_url = config.settings.get( - "weights_url", "https://api.kinitro.ai/weights" - ) - self.poll_interval = float(config.settings.get("weights_poll_interval", 30.0)) - self.request_timeout = float( - config.settings.get("weights_request_timeout", 10.0) - ) - self.stale_threshold = float( - config.settings.get("weights_stale_threshold", 180.0) - ) + + # Backend settings + self.backend_url = config.settings.get( + "backend_url", "http://localhost:8080" + ).rstrip("/") + self.poll_interval = config.settings.get( + "weight_poll_interval", 300 + ) # 5 min default + + # Chain state + self.nodes: Optional[Dict[SS58Address, Node]] = None + self.validator_node_id: Optional[int] = None + + # Track last weights to avoid redundant chain calls + self._last_weights_hash: Optional[int] = None + self._running = False - self._stop_event = asyncio.Event() - self._last_snapshot_timestamp: Optional[datetime] = None - self._last_weights_signature: Optional[tuple[tuple[int, float], ...]] = None - self._last_backend_timestamp: Optional[datetime] = None - self._last_success_at: Optional[datetime] = None - self._max_backoff = max(self.poll_interval * 10.0, 300.0) - self._node_resync_interval = max(self.stale_threshold, 300.0) - self._last_resync_at: Optional[datetime] = None - self._http_client: Optional[httpx.AsyncClient] = None logger.info( - "Lite validator initialized (hotkey=%s, weights_url=%s, poll_interval=%.1fs, stale_threshold=%.1fs)", + "LiteValidator initialized for hotkey: %s, polling every %ds", self.hotkey, - self.weights_url, self.poll_interval, - self.stale_threshold, ) - async def start(self) -> None: - """Begin the polling loop.""" - if self._running: - logger.warning("Lite validator already running") - return - - logger.info("Starting lite validator polling loop") + async def start(self): + """Start the validator service.""" + logger.info("Starting LiteValidator") self._running = True - self._stop_event.clear() - backoff = self.poll_interval - self._http_client = httpx.AsyncClient( - timeout=httpx.Timeout(self.request_timeout) - ) - - try: - while self._running: - success = await self._poll_once() - backoff = ( - self.poll_interval - if success - else min(backoff * 2.0, self._max_backoff) - ) - try: - await asyncio.wait_for(self._stop_event.wait(), timeout=backoff) - except asyncio.TimeoutError: - continue - finally: - self._running = False - self._stop_event.set() - await self._close_http_client() - logger.info("Lite validator loop exited") - - async def stop(self) -> None: - """Stop the polling loop.""" - if not self._running: - return - - logger.info("Stopping lite validator") - self._running = False - self._stop_event.set() - await self._close_http_client() + # Initialize chain connection + await self._init_chain() - async def _close_http_client(self) -> None: - client = self._http_client - if client is None: - return - self._http_client = None - try: - await client.aclose() - except Exception as exc: # pragma: no cover - logger.debug("Error closing HTTP client: %s", exc) - - async def _poll_once(self) -> bool: - """Fetch and process a single weight snapshot.""" - try: - snapshot = await self._fetch_weights_snapshot() - except Exception as exc: # pragma: no cover - defensive logging - logger.error("Failed to fetch weights snapshot: %s", exc) - return False - - if snapshot is None: - logger.debug("No weights snapshot available yet") - if self._last_backend_timestamp: - age = ( - datetime.now(timezone.utc) - self._last_backend_timestamp - ).total_seconds() - if age > self.stale_threshold: - logger.warning( - "No fresh weight snapshots received for %.1fs (threshold=%.1fs)", - age, - self.stale_threshold, - ) - return False - - if self._stop_event.is_set(): - logger.debug("Shutdown requested; skipping snapshot processing") - return False - - now = datetime.now(timezone.utc) - - try: - await self._handle_snapshot(snapshot, now) - except Exception as exc: # pragma: no cover - defensive logging - logger.error("Error handling weights snapshot: %s", exc) - return False + # Main polling loop + while self._running: + try: + await self._poll_and_set_weights() + except Exception as e: + logger.error(f"Error in weight polling cycle: {e}") - return True - - async def _fetch_weights_snapshot(self) -> Optional[WeightsSnapshot]: - """Retrieve the latest weights via HTTP.""" - - if self._stop_event.is_set(): - return None - - client = self._http_client - if client is None: - raise RuntimeError("HTTP client not initialized") - - try: - response = await client.get(self.weights_url) - except httpx.RequestError as exc: if self._running: - logger.warning("Weight endpoint request failed: %s", exc) - return None - - if response.status_code == 404: - return None + await asyncio.sleep(self.poll_interval) - try: - response.raise_for_status() - except httpx.HTTPStatusError as exc: - logger.error("Unexpected response from weight endpoint: %s", exc) - return None - - try: - payload = response.json() - except ValueError as exc: - logger.error("Failed to decode weights JSON payload: %s", exc) - return None + async def stop(self): + """Stop the validator service.""" + logger.info("Stopping LiteValidator") + self._running = False + async def _poll_and_set_weights(self): + """Fetch weights from backend and set on chain.""" try: - return WeightsSnapshot.model_validate(payload) - except ValidationError as exc: - logger.error("Failed to validate weights snapshot payload: %s", exc) - return None - - async def _handle_snapshot(self, snapshot: WeightsSnapshot, now: datetime) -> bool: - """Validate and, if new, apply the weight snapshot.""" - - timestamp = snapshot.updated_at - if timestamp.tzinfo is None: - timestamp = timestamp.replace(tzinfo=timezone.utc) - - snapshot_label = timestamp.isoformat() - self._last_backend_timestamp = timestamp - age = (now - timestamp).total_seconds() - if age > self.stale_threshold: - logger.warning( - "Weights snapshot (updated_at=%s) is stale by %.1fs (threshold=%.1fs)", - snapshot_label, - age, - self.stale_threshold, + # Fetch weights from backend + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(f"{self.backend_url}/weights") + + if response.status_code == 404: + logger.debug("Weights not available yet from backend") + return + + response.raise_for_status() + data = response.json() + + weights = data.get("weights", {}) + if not weights: + logger.debug("Empty weights received from backend") + return + + # Convert string keys to int (JSON serializes int keys as strings) + weights = {int(k): float(v) for k, v in weights.items()} + + # Check if weights changed + weights_hash = hash(frozenset(weights.items())) + if weights_hash == self._last_weights_hash: + logger.debug("Weights unchanged, skipping chain update") + return + + logger.info( + "Received new weights from backend: %d UIDs, total=%.4f", + len(weights), + sum(weights.values()), ) - weights_payload = snapshot.weights - if not weights_payload: - logger.error("No valid weights found in snapshot (%s)", snapshot_label) - return False - - node_ids = [] - node_weights = [] - for node_id, weight in weights_payload.items(): - node_ids.append(int(node_id)) - node_weights.append(float(weight)) - - weights_signature = tuple( - sorted( - (int(node_id), float(weight)) - for node_id, weight in weights_payload.items() - ) - ) - if self._last_snapshot_timestamp: - if timestamp <= self._last_snapshot_timestamp and ( - self._last_weights_signature == weights_signature - ): - logger.debug( - "Skipping snapshot updated_at=%s (duplicate or older than last applied)", - snapshot_label, - ) - return False - elif self._last_weights_signature == weights_signature: - logger.debug( - "Skipping snapshot (%s) with identical weights payload", - snapshot_label, - ) - return False + # Set weights on chain + await self._set_weights_on_chain(weights) + self._last_weights_hash = weights_hash + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error fetching weights: {e.response.status_code}") + except httpx.RequestError as e: + logger.error(f"Request error fetching weights: {e}") + except Exception as e: + logger.error(f"Error polling weights: {e}") + + async def _set_weights_on_chain(self, weights: Dict[int, float]): + """Set weights on the Bittensor chain.""" + if not self.substrate: + logger.error("Chain connection not initialized, cannot set weights") + return - # Periodically refresh node metadata to stay aligned with on-chain state. - if self._should_resync_nodes(now): - logger.info("Refreshing node metadata from chain before setting weights") - self.sync_nodes() - self._last_resync_at = now + # Sync nodes to get latest state + logger.info("Syncing nodes before setting weights...") + await self._sync_nodes() + + # Get validator node_id if not already set + if self.validator_node_id is None: + validator_node = self.nodes.get(self.hotkey) if self.nodes else None + if validator_node: + self.validator_node_id = validator_node.node_id + else: + logger.error( + f"Validator hotkey {self.hotkey} not found in nodes, cannot set weights" + ) + return - total_weight = sum(node_weights) - logger.info( - "Applying weights snapshot updated_at=%s to chain (%s miners, total_weight=%.6f)", - snapshot_label, - len(node_ids), - total_weight, - ) - logger.debug( - "Weight payload for updated_at=%s: %s", - snapshot_label, - ", ".join( - f"{uid}:{weight:.6f}" for uid, weight in zip(node_ids, node_weights) - ), - ) + # Extract node_ids and weights as parallel lists + node_ids = list(weights.keys()) + node_weights = list(weights.values()) + # Set weights on chain + logger.info(f"Setting weights on chain for {len(node_ids)} miners") success = set_node_weights( substrate=self.substrate, keypair=self.keypair, node_ids=node_ids, node_weights=node_weights, - netuid=self.netuid, - validator_node_id=self.uid, + netuid=self.config.settings["subtensor"]["netuid"], + validator_node_id=self.validator_node_id, version_key=0, wait_for_inclusion=True, wait_for_finalization=False, ) - if not success: - logger.error( - "Failed to set weights on-chain for snapshot updated_at=%s", - snapshot_label, - ) - return False + if success: + logger.info(f"Successfully set weights on chain for {len(node_ids)} miners") + else: + logger.error("Failed to set weights on chain") - self._last_snapshot_timestamp = timestamp - self._last_weights_signature = weights_signature - logger.info( - "Successfully set weights on-chain for snapshot updated_at=%s (processed %s miners)", - snapshot_label, - len(node_ids), - ) - return True - - def _should_resync_nodes(self, now: datetime) -> bool: - if self.nodes is None: - return True - if self._last_resync_at is None: - return True - return ( - now - self._last_resync_at - ).total_seconds() >= self._node_resync_interval + async def _init_chain(self) -> None: + """Initialize blockchain info.""" + try: + logger.info("Getting nodes from chain...") + + # Sync nodes from chain + await self._sync_nodes() + + # Get our validator node_id from the nodes + validator_node = self.nodes.get(self.hotkey) if self.nodes else None + if validator_node: + self.validator_node_id = validator_node.node_id + logger.info(f"Validator node_id: {self.validator_node_id}") + else: + logger.warning(f"Validator hotkey {self.hotkey} not found in nodes") + self.validator_node_id = None + + logger.info("Blockchain connection initialized") + except Exception as e: + logger.error(f"Failed to initialize blockchain connection: {e}") + logger.warning("Continuing without blockchain connection") + + async def _sync_nodes(self) -> None: + """Sync nodes from the chain.""" + try: + loop = asyncio.get_event_loop() + node_list = await loop.run_in_executor( + None, + _get_nodes_for_uid, + self.substrate, + self.config.settings["subtensor"]["netuid"], + ) + self.nodes = {node.hotkey: node for node in node_list} + logger.info(f"Synced {len(self.nodes)} nodes") + except Exception as e: + logger.error(f"Failed to sync nodes: {e}") + if not self.nodes: + self.nodes = {} diff --git a/src/validator/websocket_validator.py b/src/validator/websocket_validator.py deleted file mode 100644 index b6c2041..0000000 --- a/src/validator/websocket_validator.py +++ /dev/null @@ -1,640 +0,0 @@ -""" -WebSocket-based validator service for Kinitro. - -This new validator architecture connects directly to the Kinitro Backend -via WebSocket and receives evaluation jobs from there -""" - -import asyncio -import json -import os -from typing import Dict, Optional - -import asyncpg -import dotenv -import websockets -from fiber.chain.fetch_nodes import _get_nodes_for_uid -from fiber.chain.models import Node -from pgqueuer import Job, PgQueuer, Queries -from pgqueuer.db import AsyncpgDriver -from websockets.exceptions import ConnectionClosed, WebSocketException - -from backend.models import SS58Address -from core.chain import set_node_weights -from core.log import get_logger -from core.messages import ( - EpisodeDataMessage, - EpisodeStepDataMessage, - EvalJobMessage, - EvalResultMessage, - HeartbeatMessage, - JobStatusUpdateMessage, - MessageType, - SetWeightsMessage, - ValidatorRegisterMessage, -) -from core.neuron import Neuron - -from .config import ValidatorConfig - -dotenv.load_dotenv() - -logger = get_logger(__name__) - -VALIDATOR_SEND_QUEUE_MAXSIZE = 1000 -VALIDATOR_SEND_QUEUE_WARN_FRACTION = 0.8 - - -class WebSocketValidator(Neuron): - """ - WebSocket-based validator that connects to the Kinitro backend. - """ - - def __init__(self, config: ValidatorConfig): - super().__init__(config) - self.hotkey = self.keypair.ss58_address - - # Backend connection settings - self.backend_url = config.settings.get( - "backend_url", "ws://localhost:8080/ws/validator" - ) - self.reconnect_interval = config.settings.get("reconnect_interval", 5) - self.heartbeat_interval = config.settings.get("heartbeat_interval", 30) - - # Get API key from environment variable only - self.api_key = os.environ.get("KINITRO_API_KEY") - if not self.api_key: - raise ValueError( - "Backend API key not provided. Set KINITRO_API_KEY environment variable" - ) - - # Connection state - self.websocket: Optional[websockets.ServerConnection] = None - self.connected = False - self._running = False - self._heartbeat_task = None - self._result_processor_task = None - self._send_queue: Optional[asyncio.Queue[dict]] = None - self._sender_task: Optional[asyncio.Task] = None - - # Database and pgqueue - self.database_url = config.settings.get( - "pg_database", "postgresql://myuser:mypassword@localhost/validatordb" - ) - - if self.database_url is None: - raise Exception("Database URL not provided") - - self.nodes: Optional[Dict[SS58Address, Node]] = None - self.validator_node_id: int = None # Our node ID on the chain - - logger.info(f"WebSocket Validator initialized for hotkey: {self.hotkey}") - - async def start(self): - """Start the validator service.""" - logger.info("Starting WebSocket Validator") - self._running = True - - # Initialize chain connection - await self._init_chain() - - # Start the result processor task - self._result_processor_task = asyncio.create_task(self._process_results()) - - # Connect to backend with auto-reconnect - while self._running: - try: - await self._connect_to_backend() - # Connection lost, wait before retry - if self._running: - logger.warning( - f"Connection lost, reconnecting in {self.reconnect_interval}s" - ) - await asyncio.sleep(self.reconnect_interval) - except Exception as e: - logger.error(f"Failed to connect to backend: {e}") - if self._running: - await asyncio.sleep(self.reconnect_interval) - - async def stop(self): - """Stop the validator service.""" - logger.info("Stopping WebSocket Validator") - self._running = False - - # Cancel heartbeat - if self._heartbeat_task: - self._heartbeat_task.cancel() - try: - await self._heartbeat_task - except asyncio.CancelledError: - pass - self._heartbeat_task = None - - if self._send_queue: - try: - self._send_queue.put_nowait(None) - except asyncio.QueueFull: - pass - - if self._sender_task: - self._sender_task.cancel() - try: - await self._sender_task - except asyncio.CancelledError: - pass - self._sender_task = None - - self._send_queue = None - - # Cancel result processor - if self._result_processor_task: - self._result_processor_task.cancel() - try: - await self._result_processor_task - except asyncio.CancelledError: - pass - - # Close WebSocket connection - if self.websocket: - await self.websocket.close() - self.connected = False - - logger.info("WebSocket Validator stopped") - - async def _connect_to_backend(self): - """Connect to backend and handle messages.""" - try: - logger.info(f"Connecting to backend: {self.backend_url}") - - async with websockets.connect( - self.backend_url, - # NOTE: Disabled the library keepalive pings so we rely on our - # application-level heartbeat rather than closing the - # connection when the event loop is busy. - ping_interval=None, - ping_timeout=None, - close_timeout=10, - ) as websocket: - self.websocket = websocket - self._send_queue = asyncio.Queue(maxsize=VALIDATOR_SEND_QUEUE_MAXSIZE) - self._sender_task = asyncio.create_task(self._sender_loop()) - - # Register with backend - await self._register() - - # Start heartbeat - self._heartbeat_task = asyncio.create_task(self._heartbeat_loop()) - - # Handle messages - await self._message_loop() - - except ConnectionClosed: - logger.warning("Backend connection closed") - except WebSocketException as e: - logger.error(f"WebSocket error: {e}") - except Exception as e: - logger.error(f"Connection error: {e}") - finally: - self.connected = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - try: - await self._heartbeat_task - except asyncio.CancelledError: - pass - self._heartbeat_task = None - - if self._send_queue: - try: - self._send_queue.put_nowait(None) - except asyncio.QueueFull: - pass - - if self._sender_task: - self._sender_task.cancel() - try: - await self._sender_task - except asyncio.CancelledError: - pass - self._sender_task = None - - self._send_queue = None - - async def _register(self): - """Register validator with backend.""" - register_msg = ValidatorRegisterMessage( - hotkey=self.hotkey, api_key=self.api_key - ) - await self._send_message(register_msg.model_dump()) - - # Wait for acknowledgment - response = await self.websocket.recv() - ack = json.loads(response) - - if ( - ack.get("message_type") == MessageType.REGISTRATION_ACK - and ack.get("status") == "registered" - ): - self.connected = True - logger.info("Successfully registered with backend") - else: - raise Exception(f"Registration failed: {ack}") - - async def _heartbeat_loop(self): - """Send periodic heartbeats to backend.""" - try: - while self.connected and not self._heartbeat_task.cancelled(): - heartbeat = HeartbeatMessage() - await self._send_message(heartbeat.model_dump()) - - # Wait for heartbeat interval - await asyncio.sleep(self.heartbeat_interval) - - except asyncio.CancelledError: - logger.debug("Heartbeat task cancelled") - except Exception as e: - logger.error(f"Heartbeat error: {e}") - - async def _message_loop(self): - """Handle incoming messages from backend.""" - try: - async for message in self.websocket: - try: - data = json.loads(message) - message_type = data.get("message_type") - - if message_type == MessageType.EVAL_JOB: - await self._handle_eval_job(EvalJobMessage(**data)) - elif message_type == MessageType.SET_WEIGHTS: - await self._handle_set_weights(SetWeightsMessage(**data)) - elif message_type == MessageType.HEARTBEAT_ACK: - logger.debug("Received heartbeat ack") - elif message_type == MessageType.ERROR: - logger.error(f"Backend error: {data.get('error')}") - else: - logger.warning(f"Unknown message type: {message_type}") - - except json.JSONDecodeError as e: - logger.error(f"Failed to decode message: {e}") - except Exception as e: - logger.error(f"Error handling message: {e}") - - except ConnectionClosed: - logger.info("Message loop ended - connection closed") - except Exception as e: - logger.error(f"Message loop error: {e}") - - async def _handle_eval_job(self, job: EvalJobMessage): - """Handle evaluation job from backend.""" - logger.info( - f"Received evaluation job: {job.job_id} for miner {job.miner_hotkey}" - ) - - # Track active job - # self.active_jobs[job.job_id] = job - - # Queue the job with pgqueuer to the database - job_bytes = job.to_bytes() - logger.info(f"Queueing job {job.job_id} to database") - # Connect to the postgres database - conn = await asyncpg.connect(dsn=self.database_url) - driver = AsyncpgDriver(conn) - q = Queries(driver) - await q.enqueue(["add_job"], [job_bytes], [0]) - logger.info(f"Job {job.job_id} queued successfully") - - async def _process_results(self): - """Process evaluation results from pgqueue and send them to the backend.""" - logger.info("Starting result processor task") - - try: - # Connect to the postgres database - conn = await asyncpg.connect(dsn=self.database_url) - driver = AsyncpgDriver(conn) - pgq = PgQueuer(driver) - - @pgq.entrypoint("eval_result") - async def process_result(job: Job) -> None: - """Process an evaluation result from the queue.""" - try: - # Parse the result from the job payload - result_data = json.loads(job.payload.decode("utf-8")) - eval_result = EvalResultMessage(**result_data) - - logger.info( - f"Processing evaluation result for job {eval_result.job_id}" - ) - - # Send the result to the backend if connected - if self.connected and self.websocket: - await self._send_eval_result(eval_result) - logger.info( - f"Sent evaluation result for job {eval_result.job_id} to backend" - ) - else: - # If not connected, the job will remain in queue and be retried - logger.warning( - f"Not connected to backend, result for job {eval_result.job_id} will be retried" - ) - raise Exception("Not connected to backend") - - except Exception as e: - logger.error(f"Failed to process evaluation result: {e}") - # Re-raise to let pgqueue handle retry - raise - - @pgq.entrypoint("episode_data") - async def process_episode_data(job: Job) -> None: - """Process episode data from the queue.""" - try: - # Parse the episode data from the job payload - episode_data = json.loads(job.payload.decode("utf-8")) - episode_msg = EpisodeDataMessage(**episode_data) - - logger.info( - f"Processing episode data for submission {episode_msg.submission_id}, episode {episode_msg.episode_id}" - ) - - # Send the episode data to the backend if connected - if self.connected and self.websocket: - await self._send_episode_data(episode_msg) - logger.info( - f"Sent episode data for episode {episode_msg.episode_id} to backend" - ) - else: - # If not connected, the job will remain in queue and be retried - logger.warning( - f"Not connected to backend, episode data for episode {episode_msg.episode_id} will be retried" - ) - raise Exception("Not connected to backend") - - except Exception as e: - logger.error(f"Failed to process episode data: {e}") - # Re-raise to let pgqueue handle retry - raise - - @pgq.entrypoint("episode_step_data") - async def process_episode_step_data(job: Job) -> None: - """Process episode step data from the queue.""" - try: - # Parse the step data from the job payload - step_data = json.loads(job.payload.decode("utf-8")) - step_msg = EpisodeStepDataMessage(**step_data) - - logger.info( - f"Processing step data for submission {step_msg.submission_id}, episode {step_msg.episode_id}, step {step_msg.step}" - ) - - # Send the step data to the backend if connected - if self.connected and self.websocket: - await self._send_episode_step_data(step_msg) - logger.info( - f"Sent step data for episode {step_msg.episode_id}, step {step_msg.step} to backend" - ) - else: - # If not connected, the job will remain in queue and be retried - logger.warning( - f"Not connected to backend, step data for episode {step_msg.episode_id}, step {step_msg.step} will be retried" - ) - raise Exception("Not connected to backend") - - except Exception as e: - logger.error(f"Failed to process episode step data: {e}") - # Re-raise to let pgqueue handle retry - raise - - @pgq.entrypoint("job_status_update") - async def process_job_status_update(job: Job) -> None: - """Process job status updates from the queue.""" - try: - status_data = json.loads(job.payload.decode("utf-8")) - status_msg = JobStatusUpdateMessage(**status_data) - - logger.info( - "Processing job status update for job %s: %s", - status_msg.job_id, - status_msg.status, - ) - - if self.connected and self.websocket: - await self._send_job_status_update(status_msg) - logger.info( - "Sent job status update for job %s to backend", - status_msg.job_id, - ) - else: - logger.warning( - "Not connected to backend, job status update for job %s will be retried", - status_msg.job_id, - ) - raise Exception("Not connected to backend") - - except Exception as e: - logger.error(f"Failed to process job status update: {e}") - raise - - logger.info( - "Result processor is now listening for evaluation data and status updates..." - ) - await pgq.run() - - except asyncio.CancelledError: - logger.info("Result processor task cancelled") - raise - except Exception as e: - logger.error(f"Result processor error: {e}") - # Restart the processor after a delay if still running - if self._running: - await asyncio.sleep(5) - self._result_processor_task = asyncio.create_task( - self._process_results() - ) - - async def _send_eval_result(self, result: EvalResultMessage): - """Send evaluation result to the backend.""" - await self._send_message(result.model_dump(mode="json")) - - async def _send_job_status_update(self, status_update: JobStatusUpdateMessage): - """Send job status update to the backend.""" - await self._send_message(status_update.model_dump(mode="json")) - - async def _send_episode_data(self, episode_data: EpisodeDataMessage): - """Send episode data to the backend.""" - await self._send_message(episode_data.model_dump()) - - async def _send_episode_step_data(self, step_data: EpisodeStepDataMessage): - """Send episode step data to the backend.""" - await self._send_message(step_data.model_dump()) - - async def _sender_loop(self) -> None: - """Continuously flush the outbound queue to the backend.""" - - if not self._send_queue: - return - - try: - while True: - message = await self._send_queue.get() - - if message is None: - self._send_queue.task_done() - break - - try: - if not self.websocket: - raise RuntimeError("WebSocket connection unavailable") - - message_json = json.dumps(message, default=str) - await self.websocket.send(message_json) - except Exception as exc: - logger.error(f"Outbound sender error: {exc}") - raise - finally: - self._send_queue.task_done() - - except asyncio.CancelledError: - logger.error("Outbound sender task cancelled") - raise - except Exception: - if self.websocket: - try: - await self.websocket.close() - except Exception: - logger.error("Error closing WebSocket connection") - finally: - if self._send_queue: - while not self._send_queue.empty(): - try: - self._send_queue.get_nowait() - self._send_queue.task_done() - except asyncio.QueueEmpty: - logger.warning("Outbound queue already empty") - break - - async def _init_chain(self) -> None: - """Initialize blockchain info.""" - try: - logger.info("Getting nodes from chain...") - - # Sync nodes from chain - await self._sync_nodes() - - # Get our validator node_id from the nodes - validator_node = self.nodes.get(self.hotkey) if self.nodes else None - if validator_node: - self.validator_node_id = validator_node.node_id - logger.info(f"Validator node_id: {self.validator_node_id}") - else: - logger.warning(f"Validator hotkey {self.hotkey} not found in nodes") - self.validator_node_id = None - - logger.info("Blockchain connection initialized") - except Exception as e: - logger.error(f"Failed to initialize blockchain connection: {e}") - # Continue without chain connection - validator can still process jobs - logger.warning("Continuing without blockchain connection") - - async def _sync_nodes(self) -> None: - """Sync nodes from the chain.""" - try: - # Run in thread pool to avoid blocking - loop = asyncio.get_event_loop() - node_list = await loop.run_in_executor( - None, - _get_nodes_for_uid, - self.substrate, - self.config.settings["subtensor"]["netuid"], - ) - self.nodes = {node.hotkey: node for node in node_list} - logger.info(f"Synced {len(self.nodes)} nodes") - except Exception as e: - logger.error(f"Failed to sync nodes: {e}") - if not self.nodes: - self.nodes = {} - - async def _handle_set_weights(self, weights_msg: SetWeightsMessage): - """Handle weight setting message from backend. - - This function: - 1. Receives weights dict (UID->weight mapping) from the backend - 2. Syncs the nodes to get latest chain state - 3. Sets the weights on chain using the validator's keypair - """ - try: - logger.info( - f"Received weight update: {len(weights_msg.weights)} weights for miners {list(weights_msg.weights.keys())[:5]}..." - ) - - if not self.substrate or not self.nodes: - logger.error("Chain connection not initialized, cannot set weights") - return - - # Sync nodes to get latest state - logger.info("Syncing nodes before setting weights...") - await self._sync_nodes() - - # Get validator node_id if not already set - if self.validator_node_id is None: - validator_node = self.nodes.get(self.hotkey) if self.nodes else None - if validator_node: - self.validator_node_id = validator_node.node_id - else: - logger.error( - f"Validator hotkey {self.hotkey} not found in nodes, cannot set weights" - ) - return - - # Extract node_ids and weights as parallel lists for the set_node_weights function - node_ids = list(weights_msg.weights.keys()) - node_weights = list(weights_msg.weights.values()) - - # Set weights on chain - logger.info(f"Setting weights on chain for {len(node_ids)} miners") - success = set_node_weights( - substrate=self.substrate, - keypair=self.keypair, - node_ids=node_ids, - node_weights=node_weights, - netuid=self.config.settings["subtensor"]["netuid"], - validator_node_id=self.validator_node_id, - version_key=0, - wait_for_inclusion=True, - wait_for_finalization=False, - ) - - if success: - logger.info( - f"Successfully set weights on chain for {len(node_ids)} miners" - ) - else: - logger.error("Failed to set weights on chain") - - except Exception as e: - logger.error(f"Error handling set weights message: {e}") - - async def _send_message(self, message: dict): - """Send message to backend.""" - if not self.websocket: - raise Exception("No WebSocket connection") - - try: - if not self._send_queue: - raise Exception("Outbound queue not initialized") - - queue_size = self._send_queue.qsize() - if queue_size > int( - VALIDATOR_SEND_QUEUE_MAXSIZE * VALIDATOR_SEND_QUEUE_WARN_FRACTION - ): - logger.warning( - "Outbound queue size high (%s/%s)", - queue_size, - VALIDATOR_SEND_QUEUE_MAXSIZE, - ) - - self._send_queue.put_nowait(message) - except asyncio.QueueFull: - logger.error("Outbound message queue full; dropping message") - raise - except Exception as e: - logger.error(f"Failed to enqueue message: {e}") - raise diff --git a/uv.lock b/uv.lock index c88c178..6932fdc 100644 --- a/uv.lock +++ b/uv.lock @@ -971,6 +971,7 @@ dependencies = [ { name = "pycapnp" }, { name = "pydantic" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ray" }, { name = "rich" }, { name = "snowflake-id" }, @@ -1017,6 +1018,7 @@ requires-dist = [ { name = "pycapnp", specifier = ">=2.0.0" }, { name = "pydantic", specifier = ">=2.0.0" }, { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "ray", specifier = ">=2.48.0" }, { name = "rich", specifier = ">=13.0.0" }, { name = "snowflake-id", specifier = "==1.0.2" }, @@ -2050,6 +2052,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0"