15 changes: 15 additions & 0 deletions database/unified_db/__init__.py
@@ -64,6 +64,14 @@
upload_job_and_trial_records,
upload_traces_to_hf,
register_benchmark_and_tasks_from_job,
# Utility functions
calculate_standard_error,
# Pending Job Status functions
create_job_entry_pending,
update_job_status_to_started,
get_job_by_model_benchmark,
get_latest_job_for_model_benchmark,
create_job_entry_started,
)
from .models import (
DatasetModel,
@@ -149,4 +157,11 @@
"upload_job_and_trial_records",
"upload_traces_to_hf",
"register_benchmark_and_tasks_from_job",
"calculate_standard_error",
# Pending Job Status exports
"create_job_entry_pending",
"update_job_status_to_started",
"get_job_by_model_benchmark",
"get_latest_job_for_model_benchmark",
"create_job_entry_started",
]
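
For context, a minimal sketch of the pending-job lifecycle these new exports enable. The diff shows only the function names, so the arguments below are assumptions about the signatures, not the actual API:

from uuid import uuid4
from unified_db import (
    create_job_entry_pending,
    update_job_status_to_started,
    get_latest_job_for_model_benchmark,
)

# Hypothetical IDs; in practice these come from the agents/models/benchmarks tables.
model_id, benchmark_id, agent_id = str(uuid4()), str(uuid4()), str(uuid4())

# 1. Record the job as "Pending" before submitting it to SLURM.
job = create_job_entry_pending(
    model_id=model_id,
    benchmark_id=benchmark_id,
    agent_id=agent_id,
    username="alice",
)

# 2. Once SLURM picks the job up, mark it "Started" and attach the SLURM job ID.
update_job_status_to_started(job["id"], slurm_job_id="4242")

# 3. Later, fetch the most recent job for this (model, benchmark) pair.
latest = get_latest_job_for_model_benchmark(model_id, benchmark_id)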
17 changes: 15 additions & 2 deletions database/unified_db/complete_schema.sql
@@ -1,4 +1,5 @@
-- Complete Schema for OT-Agents Registration System
-- Merged with DC-Agents additions (duplicate_of support)
-- Run this file to set up all required tables

-- Enable UUID extension
@@ -53,6 +54,7 @@ CREATE TABLE IF NOT EXISTS models (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
name TEXT NOT NULL,
base_model_id UUID REFERENCES models(id),
duplicate_of UUID REFERENCES models(id) ON DELETE RESTRICT,
created_by TEXT NOT NULL,
creation_location TEXT NOT NULL,
creation_time TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
@@ -70,7 +72,10 @@ CREATE TABLE IF NOT EXISTS models (
agent_id UUID REFERENCES agents(id) NOT NULL,
training_type TEXT CHECK (training_type IN ('SFT', 'RL')),
traces_location_s3 TEXT,
description TEXT
description TEXT,

-- Prevent self-reference for duplicate_of
CONSTRAINT models_no_self_duplicate CHECK (duplicate_of IS NULL OR duplicate_of != id)
);

-- Indexes for models
@@ -79,6 +84,7 @@ CREATE INDEX idx_models_created_by ON models(created_by);
CREATE INDEX idx_models_agent_id ON models(agent_id);
CREATE INDEX idx_models_dataset_id ON models(dataset_id);
CREATE INDEX idx_models_base_model_id ON models(base_model_id);
CREATE INDEX idx_models_duplicate_of ON models(duplicate_of);
CREATE INDEX idx_models_training_type ON models(training_type);
CREATE INDEX idx_models_creation_time ON models(creation_time DESC);
CREATE INDEX idx_models_training_start ON models(training_start DESC);
@@ -90,15 +96,20 @@ CREATE TABLE IF NOT EXISTS benchmarks (
name TEXT NOT NULL,
benchmark_version_hash CHAR(64),
is_external BOOLEAN NOT NULL DEFAULT false,
duplicate_of UUID REFERENCES benchmarks(id) ON DELETE RESTRICT,
external_link TEXT,
description TEXT,
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),

-- Prevent self-reference for duplicate_of
CONSTRAINT benchmarks_no_self_duplicate CHECK (duplicate_of IS NULL OR duplicate_of != id)
);

-- Indexes for benchmarks
CREATE INDEX idx_benchmarks_name ON benchmarks(name);
CREATE INDEX idx_benchmarks_benchmark_version_hash ON benchmarks(benchmark_version_hash);
CREATE INDEX idx_benchmarks_is_external ON benchmarks(is_external);
CREATE INDEX idx_benchmarks_duplicate_of ON benchmarks(duplicate_of);
CREATE INDEX idx_benchmarks_updated_at ON benchmarks(updated_at DESC);

-- ==================== UPDATE TRIGGERS ====================
@@ -182,12 +193,14 @@ COMMENT ON TABLE models IS 'Table storing ML model metadata and training informa
COMMENT ON COLUMN models.training_type IS 'Type of training: SFT (Supervised Fine-Tuning) or RL (Reinforcement Learning)';
COMMENT ON COLUMN models.is_external IS 'Whether this model is external (e.g., from HuggingFace)';
COMMENT ON COLUMN models.training_parameters IS 'JSON containing all training hyperparameters and configuration';
COMMENT ON COLUMN models.duplicate_of IS 'UUID of the canonical model this entry is a duplicate of (for deduplication tracking)';

-- Benchmarks table
COMMENT ON TABLE benchmarks IS 'Table storing evaluation benchmark metadata';
COMMENT ON COLUMN benchmarks.name IS 'Name of the benchmark';
COMMENT ON COLUMN benchmarks.benchmark_version_hash IS 'SHA-256 hash of the benchmark version (64 characters)';
COMMENT ON COLUMN benchmarks.is_external IS 'Whether this benchmark is external (not hosted internally)';
COMMENT ON COLUMN benchmarks.duplicate_of IS 'UUID of the canonical benchmark this entry is a duplicate of (for deduplication tracking)';
COMMENT ON COLUMN benchmarks.external_link IS 'Link to external benchmark if applicable';
COMMENT ON COLUMN benchmarks.description IS 'Description of the benchmark and its purpose';
COMMENT ON COLUMN benchmarks.updated_at IS 'Timestamp when the benchmark was last updated';
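
A quick sketch of the new self-reference guard in action, going through the Supabase client rather than raw SQL. The import path and the APIError behavior are assumptions; the benchmarks table is used because name is its only required column:

import uuid
from postgrest.exceptions import APIError
from unified_db.config import create_supabase_client  # assumed import path

client = create_supabase_client(use_admin=True)
row_id = str(uuid.uuid4())
try:
    # duplicate_of == id trips the benchmarks_no_self_duplicate CHECK.
    client.table("benchmarks").insert(
        {"id": row_id, "name": "demo-benchmark", "duplicate_of": row_id}
    ).execute()
except APIError as exc:
    print("insert rejected:", exc.message)

Note that ON DELETE RESTRICT means a canonical row cannot be deleted while any duplicate still points at it; duplicates must be repointed or removed first.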
9 changes: 7 additions & 2 deletions database/unified_db/config.py
@@ -86,8 +86,13 @@ def create_supabase_client(use_admin: bool = False) -> Client:
print("⚠️ Admin access requested but no service role key found")
print(" Some operations may fail due to RLS policies")

# Create client (v2 API doesn't use ClientOptions the same way)
return create_client(supabase_config.supabase_url, key)
# Create client with timeout options
options = ClientOptions(
postgrest_client_timeout=30,
storage_client_timeout=30
)

return create_client(supabase_config.supabase_url, key, options)


def get_default_client() -> Client:
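
In isolation, the new timeout wiring looks like the sketch below. The URL and key are placeholders, and the ClientOptions import path can vary across supabase-py 2.x releases:

from supabase import ClientOptions, create_client

client = create_client(
    "https://your-project.supabase.co",  # placeholder URL
    "your-anon-or-service-role-key",     # placeholder key
    ClientOptions(
        postgrest_client_timeout=30,  # seconds per PostgREST request
        storage_client_timeout=30,    # seconds per storage request
    ),
)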
28 changes: 24 additions & 4 deletions database/unified_db/models.py
@@ -98,6 +98,7 @@ class ModelModel(BaseModel):
id: Optional[UUID] = Field(default_factory=uuid4)
name: str
base_model_id: Optional[UUID] = None
duplicate_of: Optional[UUID] = None # Reference to canonical model this is a duplicate of
created_by: str
creation_location: str
creation_time: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -133,6 +134,10 @@ def serialize_id(self, value: Optional[UUID]) -> Optional[str]:
def serialize_base_model_id(self, value: Optional[UUID]) -> Optional[str]:
return str(value) if value else None

@field_serializer('duplicate_of')
def serialize_duplicate_of(self, value: Optional[UUID]) -> Optional[str]:
return str(value) if value else None

@field_serializer('dataset_id')
def serialize_dataset_id(self, value: Optional[UUID]) -> Optional[str]:
return str(value) if value else None
@@ -171,6 +176,7 @@ def clean_model_metadata(model_data: Dict[str, Any]) -> Dict[str, Any]:
'id': str(model_data.get('id')) if model_data.get('id') else None,
'name': model_data.get('name'),
'base_model_id': str(model_data.get('base_model_id')) if model_data.get('base_model_id') else None,
'duplicate_of': str(model_data.get('duplicate_of')) if model_data.get('duplicate_of') else None,
'created_by': model_data.get('created_by'),
'creation_location': model_data.get('creation_location'),
'creation_time': model_data.get('creation_time'),
@@ -241,6 +247,7 @@ class BenchmarkModel(BaseModel):
name: str
benchmark_version_hash: Optional[str] = Field(None, max_length=64)
is_external: bool = False
duplicate_of: Optional[UUID] = None # Reference to canonical benchmark this is a duplicate of
external_link: Optional[str] = None
description: Optional[str] = None
updated_at: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -256,6 +263,10 @@ def validate_benchmark_version_hash(cls, v: Optional[str]) -> Optional[str]:
def serialize_id(self, value: Optional[UUID]) -> Optional[str]:
return str(value) if value else None

@field_serializer('duplicate_of')
def serialize_duplicate_of(self, value: Optional[UUID]) -> Optional[str]:
return str(value) if value else None

@field_serializer('updated_at')
def serialize_updated_at(self, value: Optional[datetime]) -> Optional[str]:
return value.isoformat() if value else None
@@ -271,6 +282,7 @@ def clean_benchmark_metadata(benchmark_data: Dict[str, Any]) -> Dict[str, Any]:
'name': benchmark_data.get('name'),
'benchmark_version_hash': benchmark_data.get('benchmark_version_hash'),
'is_external': benchmark_data.get('is_external'),
'duplicate_of': str(benchmark_data.get('duplicate_of')) if benchmark_data.get('duplicate_of') else None,
'external_link': benchmark_data.get('external_link'),
'description': benchmark_data.get('description'),
'updated_at': benchmark_data.get('updated_at')
@@ -363,18 +375,20 @@ class SandboxJobModel(BaseModel):
username: str
started_at: Optional[datetime] = None
ended_at: Optional[datetime] = None
submitted_at: Optional[datetime] = None # When submitted to SLURM queue
slurm_job_id: Optional[str] = None # SLURM job ID for tracking
git_commit_id: Optional[str] = None
package_version: Optional[str] = None
n_trials: int
config: Dict[str, Any]
n_trials: Optional[int] = None # Made optional for Pending jobs
config: Optional[Dict[str, Any]] = None # Made optional for Pending jobs
metrics: Optional[Dict[str, Any]] = None
stats: Optional[Dict[str, Any]] = None
agent_id: UUID
model_id: UUID
benchmark_id: UUID
n_rep_eval: int
n_rep_eval: Optional[int] = None # Made optional for Pending jobs
hf_traces_link: Optional[str] = None
job_status: Optional[str] = None
job_status: Optional[str] = None # "Pending", "Started", "Finished"

@field_validator('git_commit_id', 'package_version')
@classmethod
@@ -410,6 +424,10 @@ def serialize_started_at(self, value: Optional[datetime]) -> Optional[str]:
def serialize_ended_at(self, value: Optional[datetime]) -> Optional[str]:
return value.isoformat() if value else None

@field_serializer('submitted_at')
def serialize_submitted_at(self, value: Optional[datetime]) -> Optional[str]:
return value.isoformat() if value else None


def clean_sandbox_job_metadata(job_data: Dict[str, Any]) -> Dict[str, Any]:
"""Clean sandbox job metadata for API responses."""
@@ -423,6 +441,8 @@ def clean_sandbox_job_metadata(job_data: Dict[str, Any]) -> Dict[str, Any]:
'username': job_data.get('username'),
'started_at': job_data.get('started_at'),
'ended_at': job_data.get('ended_at'),
'submitted_at': job_data.get('submitted_at'),
'slurm_job_id': job_data.get('slurm_job_id'),
'git_commit_id': job_data.get('git_commit_id'),
'package_version': job_data.get('package_version'),
'n_trials': job_data.get('n_trials'),
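
With n_trials, config, and n_rep_eval now optional, a "Pending" row can be validated before any evaluation details exist. A sketch, assuming the fields omitted from this diff all have defaults:

from uuid import uuid4
from unified_db.models import SandboxJobModel  # assumed import path

pending = SandboxJobModel(
    username="alice",
    agent_id=uuid4(),
    model_id=uuid4(),
    benchmark_id=uuid4(),
    job_status="Pending",  # later flipped to "Started", then "Finished"
)
# n_trials, config, n_rep_eval, submitted_at, slurm_job_id stay None until known.
print(pending.model_dump_json())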
2 changes: 1 addition & 1 deletion database/unified_db/requirements.txt
@@ -1,7 +1,7 @@
# Essential dependencies for DC-Agents Dataset Registration

# Database and API
supabase==2.22.3
supabase>=2.0.0,<3.0.0

# Data processing
pandas>=2.0.0