Skip to content

Commit b85ab4d

Browse files
committed
feat: improve typing throughout the codebase
1 parent ba1b778 commit b85ab4d

39 files changed

Lines changed: 826 additions & 564 deletions

AGENTS.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,13 @@ Keep it current when commands or conventions change.
9696

9797
### Typing
9898
- Python 3.12 syntax (`list[int]`, `dict[str, float]`, `str | None`).
99-
- Add return types to public functions and methods.
99+
- Add type hints to all function signatures (parameters and return types), not just public ones.
100100
- Prefer `BaseSettings` and `BaseModel` type annotations for config/DTOs.
101101
- `ty` runs in CI; avoid `Any` unless required and explain why in code.
102+
- Use `NewType` to distinguish domain identifiers that share an underlying primitive (e.g., `MinerUID`, `BlockNumber`, `EnvironmentId`, `Hotkey`). This prevents accidental misuse such as passing a `MinerUID` where a `BlockNumber` is expected.
103+
- Centralise shared newtypes, type aliases, enums, and `TypedDict` definitions in `kinitro/types.py`. Import from there rather than re-defining types locally.
104+
- When introducing a new domain concept that is fundamentally a `str`, `int`, or other primitive, create a `NewType` for it in `kinitro/types.py` and use it consistently across signatures, models, and data structures.
105+
- Prefer `TypedDict` or `dataclasses.dataclass` over plain `dict` for structured data with known keys.
102106

103107
### Naming
104108
- `snake_case` for functions, variables, modules.

environments/_template/env.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ async def list_environments(self) -> list[str]:
7575
async def evaluate(
7676
self,
7777
task_id: int,
78+
base_url: str,
7879
seed: int | None = None,
7980
model: str | None = None,
80-
base_url: str | None = None,
8181
env_id: str = "myenv/v0", # TODO: Change default env_id
8282
max_timesteps: int = 500,
8383
action_timeout: float = 0.5,
@@ -108,9 +108,6 @@ async def evaluate(
108108
error=f"Invalid env_id: {env_id}. Must start with 'myenv/'",
109109
)
110110

111-
if base_url is None:
112-
raise ValueError("base_url (miner endpoint) is required")
113-
114111
seed = seed if seed is not None else task_id
115112
start_time = time.time()
116113

environments/metaworld/env.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ async def list_environments(self) -> list[str]:
9898
async def evaluate(
9999
self,
100100
task_id: int,
101+
base_url: str,
101102
seed: int | None = None,
102103
model: str | None = None,
103-
base_url: str | None = None,
104104
env_id: str = "metaworld/pick-place-v3",
105105
max_timesteps: int = 500,
106106
action_timeout: float = 0.5,
@@ -148,9 +148,6 @@ async def evaluate(
148148
error=f"Invalid env_id for MetaWorld container: {env_id}. Must start with 'metaworld/'",
149149
)
150150

151-
if base_url is None:
152-
raise ValueError("base_url (miner endpoint) is required")
153-
154151
if seed is None:
155152
seed = task_id
156153

kinitro/api/routes/health.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
@router.get("/health", response_model=HealthResponse)
20-
async def health_check(session: AsyncSession = Depends(get_session)):
20+
async def health_check(session: AsyncSession = Depends(get_session)) -> HealthResponse:
2121
"""Health check endpoint."""
2222
try:
2323
await session.execute(text("SELECT 1"))
@@ -35,7 +35,7 @@ async def health_check(session: AsyncSession = Depends(get_session)):
3535
async def get_status(
3636
session: AsyncSession = Depends(get_session),
3737
storage: Storage = Depends(get_storage),
38-
):
38+
) -> StatusResponse:
3939
"""Get current backend status."""
4040
# Get current/latest cycles
4141
current_cycle = await storage.get_running_cycle(session)

kinitro/api/routes/miners.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from kinitro.backend.models import EnvironmentInfo, MinerInfo
88
from kinitro.backend.storage import Storage
99
from kinitro.environments import get_all_environment_ids
10+
from kinitro.types import EnvironmentId, EnvStatsEntry
1011

1112
router = APIRouter(prefix="/v1", tags=["Miners & Environments"])
1213

@@ -15,7 +16,7 @@
1516
async def list_miners(
1617
session: AsyncSession = Depends(get_session),
1718
storage: Storage = Depends(get_storage),
18-
):
19+
) -> list[MinerInfo]:
1920
"""List all miners that have been evaluated."""
2021
# Get latest cycle's scores to get miner info
2122
cycle = await storage.get_latest_cycle(session, completed_only=True)
@@ -50,29 +51,32 @@ async def list_miners(
5051
async def list_environments(
5152
session: AsyncSession = Depends(get_session),
5253
storage: Storage = Depends(get_storage),
53-
):
54+
) -> list[EnvironmentInfo]:
5455
"""List all evaluation environments."""
5556
env_ids = get_all_environment_ids()
5657

5758
# Get latest cycle for stats
5859
cycle = await storage.get_latest_cycle(session, completed_only=True)
5960

60-
env_stats: dict[str, dict] = {env_id: {"count": 0, "total_sr": 0.0} for env_id in env_ids}
61+
env_stats: dict[EnvironmentId, EnvStatsEntry] = {
62+
EnvironmentId(env_id): EnvStatsEntry(count=0, total_sr=0.0) for env_id in env_ids
63+
}
6164

6265
if cycle:
6366
scores = await storage.get_scores_for_cycle(session, cycle.id)
6467
for s in scores:
65-
if s.env_id in env_stats:
66-
env_stats[s.env_id]["count"] += 1
67-
env_stats[s.env_id]["total_sr"] += s.success_rate
68+
eid = EnvironmentId(s.env_id)
69+
if eid in env_stats:
70+
env_stats[eid]["count"] += 1
71+
env_stats[eid]["total_sr"] += s.success_rate
6872

6973
result = []
7074
for env_id in env_ids:
7175
parts = env_id.split("/")
7276
env_name = parts[0] if parts else env_id
7377
task_name = parts[1] if len(parts) > 1 else ""
7478

75-
stats = env_stats[env_id]
79+
stats = env_stats[EnvironmentId(env_id)]
7680
avg_sr = stats["total_sr"] / stats["count"] if stats["count"] > 0 else None
7781

7882
result.append(

kinitro/api/routes/scores.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ScoresResponse,
1313
)
1414
from kinitro.backend.storage import Storage
15+
from kinitro.types import EnvironmentId, Hotkey, MinerUID
1516

1617
router = APIRouter(prefix="/v1/scores", tags=["Scores"])
1718

@@ -22,9 +23,9 @@ def _build_scores_response(
2223
"""Build a ScoresResponse from a cycle ORM object and its scores."""
2324
scores = [
2425
MinerScore(
25-
uid=s.uid,
26-
hotkey=s.hotkey,
27-
env_id=s.env_id,
26+
uid=MinerUID(s.uid),
27+
hotkey=Hotkey(s.hotkey),
28+
env_id=EnvironmentId(s.env_id),
2829
success_rate=s.success_rate,
2930
mean_reward=s.mean_reward,
3031
episodes_completed=s.episodes_completed,
@@ -33,11 +34,12 @@ def _build_scores_response(
3334
for s in scores_orm
3435
]
3536

36-
miner_summary: dict[int, dict[str, float]] = {}
37+
miner_summary: dict[MinerUID, dict[EnvironmentId, float]] = {}
3738
for s in scores_orm:
38-
if s.uid not in miner_summary:
39-
miner_summary[s.uid] = {}
40-
miner_summary[s.uid][s.env_id] = s.success_rate
39+
uid = MinerUID(s.uid)
40+
if uid not in miner_summary:
41+
miner_summary[uid] = {}
42+
miner_summary[uid][EnvironmentId(s.env_id)] = s.success_rate
4143

4244
return ScoresResponse(
4345
cycle=EvaluationCycle.model_validate(cycle),
@@ -50,7 +52,7 @@ def _build_scores_response(
5052
async def get_latest_scores(
5153
session: AsyncSession = Depends(get_session),
5254
storage: Storage = Depends(get_storage),
53-
):
55+
) -> ScoresResponse:
5456
"""Get scores from the most recent completed evaluation cycle."""
5557
cycle = await storage.get_latest_cycle(session, completed_only=True)
5658
if cycle is None:
@@ -65,7 +67,7 @@ async def get_scores_for_cycle(
6567
cycle_id: int,
6668
session: AsyncSession = Depends(get_session),
6769
storage: Storage = Depends(get_storage),
68-
):
70+
) -> ScoresResponse:
6971
"""Get scores for a specific evaluation cycle."""
7072
cycle = await storage.get_cycle(session, cycle_id)
7173
if cycle is None:

kinitro/api/routes/tasks.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
TaskSubmitResponse,
1414
)
1515
from kinitro.backend.storage import Storage
16+
from kinitro.types import EnvironmentId, Hotkey, MinerUID, Seed, TaskUUID
1617

1718
router = APIRouter(prefix="/v1/tasks", tags=["Tasks"])
1819

@@ -23,7 +24,7 @@ async def fetch_tasks(
2324
session: AsyncSession = Depends(get_session),
2425
storage: Storage = Depends(get_storage),
2526
_auth: None = Depends(verify_api_key),
26-
):
27+
) -> TaskFetchResponse:
2728
"""
2829
Fetch tasks from the task pool.
2930
@@ -46,15 +47,15 @@ async def fetch_tasks(
4647

4748
tasks = [
4849
Task(
49-
task_uuid=t.task_uuid,
50+
task_uuid=TaskUUID(t.task_uuid),
5051
cycle_id=t.cycle_id,
51-
miner_uid=t.miner_uid,
52-
miner_hotkey=t.miner_hotkey,
52+
miner_uid=MinerUID(t.miner_uid),
53+
miner_hotkey=Hotkey(t.miner_hotkey),
5354
miner_endpoint=t.miner_endpoint,
5455
miner_repo=t.miner_repo,
5556
miner_revision=t.miner_revision,
56-
env_id=t.env_id,
57-
seed=t.seed,
57+
env_id=EnvironmentId(t.env_id),
58+
seed=Seed(t.seed),
5859
status=t.status,
5960
created_at=t.created_at,
6061
)
@@ -73,7 +74,7 @@ async def submit_tasks(
7374
session: AsyncSession = Depends(get_session),
7475
storage: Storage = Depends(get_storage),
7576
_auth: None = Depends(verify_api_key),
76-
):
77+
) -> TaskSubmitResponse:
7778
"""
7879
Submit results for completed tasks.
7980
@@ -117,7 +118,7 @@ async def get_task_stats(
117118
cycle_id: int | None = None,
118119
session: AsyncSession = Depends(get_session),
119120
storage: Storage = Depends(get_storage),
120-
):
121+
) -> TaskPoolStats:
121122
"""
122123
Get task pool statistics.
123124
@@ -129,14 +130,4 @@ async def get_task_stats(
129130
running_cycle = await storage.get_running_cycle(session)
130131
cycle_id = running_cycle.id if running_cycle else None
131132

132-
stats = await storage.get_task_pool_stats(session, cycle_id=cycle_id)
133-
134-
return TaskPoolStats(
135-
total_tasks=stats["total_tasks"],
136-
pending_tasks=stats["pending_tasks"],
137-
assigned_tasks=stats["assigned_tasks"],
138-
completed_tasks=stats["completed_tasks"],
139-
failed_tasks=stats["failed_tasks"],
140-
active_executors=stats["active_executors"],
141-
current_cycle_id=stats["current_cycle_id"],
142-
)
133+
return await storage.get_task_pool_stats(session, cycle_id=cycle_id)

kinitro/backend/models.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from sqlalchemy.dialects.postgresql import JSONB
1111
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
1212

13+
from kinitro.types import EnvironmentId, Hotkey, MinerUID, Seed, TaskUUID
14+
1315

1416
def generate_task_uuid() -> str:
1517
"""Generate a unique task UUID."""
@@ -198,9 +200,9 @@ class HealthResponse(BaseModel):
198200
class MinerScore(BaseModel):
199201
"""Score for one miner on one environment."""
200202

201-
uid: int
202-
hotkey: str
203-
env_id: str
203+
uid: MinerUID
204+
hotkey: Hotkey
205+
env_id: EnvironmentId
204206
success_rate: float
205207
mean_reward: float
206208
episodes_completed: int
@@ -214,7 +216,7 @@ class EvaluationCycle(BaseModel):
214216
block_number: int
215217
started_at: datetime
216218
completed_at: datetime | None
217-
status: str
219+
status: EvaluationCycleStatus
218220
n_miners: int | None
219221
n_environments: int | None
220222
duration_seconds: float | None
@@ -230,7 +232,7 @@ class ScoresResponse(BaseModel):
230232
scores: list[MinerScore]
231233

232234
# Aggregated by miner
233-
miner_summary: dict[int, dict[str, float]] = Field(
235+
miner_summary: dict[MinerUID, dict[EnvironmentId, float]] = Field(
234236
default_factory=dict,
235237
description="Aggregated scores per miner: {uid: {env_id: success_rate}}",
236238
)
@@ -239,7 +241,7 @@ class ScoresResponse(BaseModel):
239241
class WeightsU16(BaseModel):
240242
"""Weights in u16 format for chain submission."""
241243

242-
uids: list[int]
244+
uids: list[MinerUID]
243245
values: list[int]
244246

245247

@@ -249,7 +251,7 @@ class WeightsResponse(BaseModel):
249251
cycle_id: int
250252
block_number: int
251253
timestamp: datetime
252-
weights: dict[int, float] = Field(description="Normalized weights: {uid: weight}")
254+
weights: dict[MinerUID, float] = Field(description="Normalized weights: {uid: weight}")
253255
weights_u16: WeightsU16 = Field(description="Weights formatted for chain submission")
254256
metadata: dict[str, Any] = Field(default_factory=dict)
255257

@@ -261,24 +263,24 @@ class StatusResponse(BaseModel):
261263
last_completed_cycle: EvaluationCycle | None
262264
total_cycles: int
263265
total_miners_evaluated: int
264-
environments: list[str]
266+
environments: list[EnvironmentId]
265267
is_evaluating: bool
266268

267269

268270
class MinerInfo(BaseModel):
269271
"""Information about a miner."""
270272

271-
uid: int
272-
hotkey: str
273+
uid: MinerUID
274+
hotkey: Hotkey
273275
last_evaluated_block: int | None
274276
avg_success_rate: float | None
275-
environments_evaluated: list[str]
277+
environments_evaluated: list[EnvironmentId]
276278

277279

278280
class EnvironmentInfo(BaseModel):
279281
"""Information about an evaluation environment."""
280282

281-
env_id: str
283+
env_id: EnvironmentId
282284
env_name: str
283285
task_name: str
284286
n_evaluations: int
@@ -293,16 +295,16 @@ class EnvironmentInfo(BaseModel):
293295
class Task(BaseModel):
294296
"""A single evaluation task from the task pool."""
295297

296-
task_uuid: str # Unique identifier for API calls
298+
task_uuid: TaskUUID # Unique identifier for API calls
297299
cycle_id: int
298-
miner_uid: int
299-
miner_hotkey: str
300+
miner_uid: MinerUID
301+
miner_hotkey: Hotkey
300302
miner_endpoint: str
301303
miner_repo: str | None = None # HuggingFace repo for verification
302304
miner_revision: str | None = None # HuggingFace revision for verification
303-
env_id: str
304-
seed: int # Deterministic seed for reproducibility
305-
status: str
305+
env_id: EnvironmentId
306+
seed: Seed # Deterministic seed for reproducibility
307+
status: TaskStatus
306308
created_at: datetime
307309

308310
class Config:
@@ -314,7 +316,9 @@ class TaskFetchRequest(BaseModel):
314316

315317
executor_id: str = Field(description="Unique identifier for the executor")
316318
batch_size: int = Field(default=10, ge=1, le=100, description="Number of tasks to fetch")
317-
env_ids: list[str] | None = Field(default=None, description="Filter by environment IDs")
319+
env_ids: list[EnvironmentId] | None = Field(
320+
default=None, description="Filter by environment IDs"
321+
)
318322

319323

320324
class TaskFetchResponse(BaseModel):
@@ -327,7 +331,7 @@ class TaskFetchResponse(BaseModel):
327331
class TaskResult(BaseModel):
328332
"""Result of a single task execution."""
329333

330-
task_uuid: str = Field(description="UUID of the task")
334+
task_uuid: TaskUUID = Field(description="UUID of the task")
331335
success: bool
332336
score: float = Field(default=0.0)
333337
total_reward: float = Field(default=0.0)

0 commit comments

Comments
 (0)