1 change: 1 addition & 0 deletions .env.example
@@ -14,3 +14,4 @@ AMI_NUM_WORKERS=1
 AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
 AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
 AMI_ANTENNA_API_BATCH_SIZE=4
+AMI_ANTENNA_SERVICE_NAME=AMI Data Companion
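This new variable feeds the antenna_service_name field added to Settings (see trapdata/settings.py below). A minimal sketch of how the value is picked up, assuming the settings class reads environment variables with an AMI_ prefix, consistent with the other AMI_* variables above (the class here is illustrative, not the project's actual file):

    # Illustrative sketch; assumes pydantic-settings with an AMI_ env prefix.
    from pydantic_settings import BaseSettings

    class Settings(BaseSettings):
        class Config:
            env_prefix = "AMI_"

        # AMI_ANTENNA_SERVICE_NAME in the environment overrides this default
        antenna_service_name: str = "AMI Data Companion"

    print(Settings().antenna_service_name)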
25 changes: 23 additions & 2 deletions trapdata/antenna/client.py
@@ -1,25 +1,42 @@
 """Antenna API client for fetching jobs and posting results."""
 
+import socket
+
 import requests
 
 from trapdata.antenna.schemas import AntennaJobsListResponse, AntennaTaskResult
 from trapdata.api.utils import get_http_session
 from trapdata.common.logs import logger
 
 
+def get_full_service_name(service_name: str) -> str:
+    """Build full service name with hostname.
+
+    Args:
+        service_name: Base service name
+
+    Returns:
+        Full service name with hostname appended
+    """
+    hostname = socket.gethostname()
+    return f"{service_name} ({hostname})"
+
+
 def get_jobs(
     base_url: str,
     auth_token: str,
     pipeline_slug: str,
+    processing_service_name: str,
 ) -> list[int]:
     """Fetch job ids from the API for the given pipeline.
 
-    Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1
+    Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1&processing_service_name=<name>
 
     Args:
         base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
         auth_token: API authentication token
         pipeline_slug: Pipeline slug to filter jobs
+        processing_service_name: Name of the processing service
 
     Returns:
         List of job ids (possibly empty) on success or error.
@@ -31,6 +48,7 @@ def get_jobs(
             "pipeline__slug": pipeline_slug,
             "ids_only": 1,
             "incomplete_only": 1,
+            "processing_service_name": processing_service_name,
         }
 
         resp = session.get(url, params=params, timeout=30)
@@ -52,6 +70,7 @@ def post_batch_results(
     auth_token: str,
     job_id: int,
     results: list[AntennaTaskResult],
+    processing_service_name: str,
 ) -> bool:
     """
     Post batch results back to the API.
@@ -61,6 +80,7 @@
         auth_token: API authentication token
         job_id: Job ID
         results: List of AntennaTaskResult objects
+        processing_service_name: Name of the processing service
 
     Returns:
         True if successful, False otherwise
@@ -70,7 +90,8 @@ def post_batch_results(
 
     with get_http_session(auth_token) as session:
         try:
-            response = session.post(url, json=payload, timeout=60)
+            params = {"processing_service_name": processing_service_name}
+            response = session.post(url, json=payload, params=params, timeout=60)
             response.raise_for_status()
             logger.info(f"Successfully posted {len(results)} results to {url}")
             return True
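Taken together, a short usage sketch of the two updated entry points (URL, token, and hostname are placeholder values):

    # Sketch only; URL, token, and hostname are placeholders.
    full_name = get_full_service_name("AMI Data Companion")
    # e.g. "AMI Data Companion (gpu-03)" on a host named "gpu-03"

    job_ids = get_jobs(
        base_url="http://localhost:8000/api/v2",
        auth_token="your_antenna_auth_token_here",
        pipeline_slug="moth_binary",
        processing_service_name=full_name,
    )
    # Issues: GET /api/v2/jobs?pipeline__slug=moth_binary&ids_only=1
    #             &incomplete_only=1&processing_service_name=AMI+Data+Companion+(gpu-03)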
11 changes: 10 additions & 1 deletion trapdata/antenna/datasets.py
@@ -45,6 +45,7 @@ def __init__(
         job_id: int,
         batch_size: int = 1,
         image_transforms: torchvision.transforms.Compose | None = None,
+        processing_service_name: str = "",
     ):
         """
         Initialize the REST dataset.
@@ -55,12 +56,14 @@ def __init__(
             job_id: The job ID to fetch tasks for
             batch_size: Number of tasks to request per batch
             image_transforms: Optional transforms to apply to loaded images
+            processing_service_name: Name of the processing service
         """
         super().__init__()
         self.base_url = base_url
         self.job_id = job_id
         self.batch_size = batch_size
         self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
+        self.processing_service_name = processing_service_name
 
         # Create persistent sessions for connection pooling
         self.api_session = get_http_session(auth_token)
@@ -84,7 +87,10 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
             requests.RequestException: If the request fails (network error, etc.)
         """
         url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
-        params = {"batch": self.batch_size}
+        params = {
+            "batch": self.batch_size,
+            "processing_service_name": self.processing_service_name,
+        }
 
         response = self.api_session.get(url, params=params, timeout=30)
         response.raise_for_status()
@@ -251,6 +257,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
 def get_rest_dataloader(
     job_id: int,
     settings: "Settings",
+    processing_service_name: str,
 ) -> torch.utils.data.DataLoader:
     """
     Create a DataLoader that fetches tasks from Antenna API.
@@ -264,12 +271,14 @@ def get_rest_dataloader(
     Args:
         job_id: Job ID to fetch tasks for
         settings: Settings object with antenna_api_* configuration
+        processing_service_name: Name of the processing service
     """
     dataset = RESTDataset(
         base_url=settings.antenna_api_base_url,
         auth_token=settings.antenna_api_auth_token,
         job_id=job_id,
         batch_size=settings.antenna_api_batch_size,
+        processing_service_name=processing_service_name,
     )
 
     return torch.utils.data.DataLoader(
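With this change, every task fetch also identifies the worker. A sketch of the resulting call, with placeholder job id and service name:

    loader = get_rest_dataloader(
        job_id=123,
        settings=settings,
        processing_service_name="AMI Data Companion (gpu-03)",
    )
    # Each fetch now issues:
    # GET {base_url}/jobs/123/tasks?batch=<antenna_api_batch_size>&processing_service_name=...
    for batch in loader:
        ...  # run detection/classification on the batch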
6 changes: 2 additions & 4 deletions trapdata/antenna/registration.py
@@ -1,9 +1,8 @@
 """Pipeline registration with Antenna projects."""
 
-import socket
-
 import requests
 
+from trapdata.antenna.client import get_full_service_name
 from trapdata.antenna.schemas import (
     AsyncPipelineRegistrationRequest,
     AsyncPipelineRegistrationResponse,
@@ -101,8 +100,7 @@ def register_pipelines(
         return
 
     # Add hostname to service name
-    hostname = socket.gethostname()
-    full_service_name = f"{service_name} ({hostname})"
+    full_service_name = get_full_service_name(service_name)
 
     # Get projects to register for
     projects_to_process = []
17 changes: 15 additions & 2 deletions trapdata/antenna/worker.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
 
-from trapdata.antenna.client import get_jobs, post_batch_results
+from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
 from trapdata.antenna.datasets import get_rest_dataloader
 from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
 from trapdata.api.api import CLASSIFIER_CHOICES
@@ -34,6 +34,10 @@ def run_worker(pipelines: list[str]):
             "Get your auth token from your Antenna project settings."
         )
 
+    # Build full service name with hostname
+    full_service_name = get_full_service_name(settings.antenna_service_name)
+    logger.info(f"Running worker as: {full_service_name}")
+
     while True:
         # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing
         # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they
@@ -45,6 +49,7 @@
                     base_url=settings.antenna_api_base_url,
                     auth_token=settings.antenna_api_auth_token,
                     pipeline_slug=pipeline,
+                    processing_service_name=full_service_name,
                 )
                 for job_id in jobs:
                     logger.info(f"Processing job {job_id} with pipeline {pipeline}")
@@ -53,6 +58,7 @@
                         pipeline=pipeline,
                         job_id=job_id,
                         settings=settings,
+                        processing_service_name=full_service_name,
                     )
                     any_jobs = any_jobs or any_work_done
             except Exception as e:
@@ -72,18 +78,24 @@ def _process_job(
     pipeline: str,
     job_id: int,
     settings: Settings,
+    processing_service_name: str,
 ) -> bool:
     """Run the worker to process images from the REST API queue.
 
     Args:
         pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
         job_id: Job ID to process
         settings: Settings object with antenna_api_* configuration
+        processing_service_name: Name of the processing service
     Returns:
         True if any work was done, False otherwise
     """
     did_work = False
-    loader = get_rest_dataloader(job_id=job_id, settings=settings)
+    loader = get_rest_dataloader(
+        job_id=job_id,
+        settings=settings,
+        processing_service_name=processing_service_name,
+    )
     classifier = None
     detector = None
 
@@ -232,6 +244,7 @@ def _process_job(
             settings.antenna_api_auth_token,
             job_id,
             batch_results,
+            processing_service_name,
         )
         st, t = t("Finished posting results")
 
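A condensed sketch of how the service name now threads through the worker loop (abbreviated from run_worker above; the pipeline slug is illustrative):

    from trapdata.antenna.client import get_full_service_name, get_jobs
    from trapdata.antenna.worker import _process_job
    from trapdata.settings import read_settings

    settings = read_settings()
    full_service_name = get_full_service_name(settings.antenna_service_name)

    for pipeline in ["moth_binary"]:
        for job_id in get_jobs(
            base_url=settings.antenna_api_base_url,
            auth_token=settings.antenna_api_auth_token,
            pipeline_slug=pipeline,
            processing_service_name=full_service_name,  # job discovery
        ):
            _process_job(
                pipeline=pipeline,
                job_id=job_id,
                settings=settings,
                processing_service_name=full_service_name,  # task fetches and result posts
            )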
23 changes: 12 additions & 11 deletions trapdata/cli/worker.py
@@ -16,7 +16,8 @@ def run(
         list[str] | None,
         typer.Option(
             "--pipeline",
-            help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. Defaults to all pipelines if not specified."
+            help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. "
+            "Defaults to all pipelines if not specified.",
         ),
     ] = None,
 ):
@@ -49,13 +50,6 @@
 
 @cli.command("register")
 def register(
-    name: Annotated[
-        str,
-        typer.Argument(
-            help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). "
-            "Hostname will be added automatically.",
-        ),
-    ],
     project: Annotated[
         list[int] | None,
         typer.Option(
@@ -70,11 +64,18 @@
     This command registers all available pipeline configurations with the Antenna platform
     for the specified projects (or all accessible projects if none specified).
 
+    The service name is read from the AMI_ANTENNA_SERVICE_NAME configuration setting.
+    Hostname will be added automatically to the service name.
+
     Examples:
-        ami worker register "AMI Data Companion on DRAC gpu-03" --project 1 --project 2
-        ami worker register "My Processing Service"  # registers for all accessible projects
+        ami worker register --project 1 --project 2
+        ami worker register  # registers for all accessible projects
     """
     from trapdata.antenna.registration import register_pipelines
+    from trapdata.settings import read_settings
 
+    settings = read_settings()
     project_ids = project if project else []
-    register_pipelines(project_ids=project_ids, service_name=name)
+    register_pipelines(
+        project_ids=project_ids, service_name=settings.antenna_service_name
+    )
7 changes: 7 additions & 0 deletions trapdata/settings.py
@@ -41,6 +41,7 @@ class Settings(BaseSettings):
     antenna_api_base_url: str = "http://localhost:8000/api/v2"
     antenna_api_auth_token: str = ""
     antenna_api_batch_size: int = 4
+    antenna_service_name: str = "AMI Data Companion"
 
     @pydantic.field_validator("image_base_path", "user_data_path")
     def validate_path(cls, v):
@@ -166,6 +167,12 @@ class Config:
             "kivy_type": "numeric",
             "kivy_section": "antenna",
         },
+        "antenna_service_name": {
+            "title": "Antenna Service Name",
+            "description": "Name for the processing service registration (hostname will be added automatically)",
+            "kivy_type": "string",
+            "kivy_section": "antenna",
+        },
     }
 
     @classmethod
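End to end, both ami worker run and ami worker register now derive their identity from this single setting; a minimal sketch of the resolved name (the hostname is whatever socket.gethostname() returns on your machine):

    from trapdata.antenna.client import get_full_service_name
    from trapdata.settings import read_settings

    settings = read_settings()
    print(get_full_service_name(settings.antenna_service_name))
    # e.g. "AMI Data Companion (gpu-03)"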