Merged (changes from 4 commits)
1 change: 1 addition & 0 deletions .env.example
@@ -14,3 +14,4 @@ AMI_NUM_WORKERS=1
AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
AMI_ANTENNA_API_BATCH_SIZE=4
AMI_ANTENNA_SERVICE_NAME="AMI Data Companion"
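
For reference, the new variable is expected to surface on the settings object as antenna_service_name, the same field read by the worker and the register command further down in this diff. A minimal sketch of reading it, assuming the existing trapdata.settings.read_settings() helper:

from trapdata.settings import read_settings

settings = read_settings()
# Expected to resolve from AMI_ANTENNA_SERVICE_NAME (environment variable or .env file)
print(settings.antenna_service_name)  # e.g. "AMI Data Companion"
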
25 changes: 23 additions & 2 deletions trapdata/antenna/client.py
@@ -1,5 +1,7 @@
"""Antenna API client for fetching jobs and posting results."""

import socket

import requests

from trapdata.antenna.schemas import (
@@ -11,19 +13,34 @@
from trapdata.common.logs import logger


def get_full_service_name(service_name: str) -> str:
"""Build full service name with hostname.

Args:
service_name: Base service name

Returns:
Full service name with hostname appended
"""
hostname = socket.gethostname()
return f"{service_name} ({hostname})"


def get_jobs(
base_url: str,
auth_token: str,
pipeline_slug: str,
processing_service_name: str,
) -> list[int]:
"""Fetch job ids from the API for the given pipeline.

Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1
Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1&processing_service_name=<name>

Args:
base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
pipeline_slug: Pipeline slug to filter jobs
processing_service_name: Name of the processing service

Returns:
List of job ids (possibly empty) on success or error.
@@ -35,6 +52,7 @@ def get_jobs(
"pipeline__slug": pipeline_slug,
"ids_only": 1,
"incomplete_only": 1,
"processing_service_name": processing_service_name,
"dispatch_mode": JobDispatchMode.ASYNC_API, # Only fetch async_api jobs
}

@@ -57,6 +75,7 @@ def post_batch_results(
auth_token: str,
job_id: int,
results: list[AntennaTaskResult],
processing_service_name: str,
) -> bool:
"""
Post batch results back to the API.
@@ -66,6 +85,7 @@ def post_batch_results(
auth_token: API authentication token
job_id: Job ID
results: List of AntennaTaskResult objects
processing_service_name: Name of the processing service

Returns:
True if successful, False otherwise
@@ -75,7 +95,8 @@ def post_batch_results(

with get_http_session(auth_token) as session:
try:
response = session.post(url, json=payload, timeout=60)
params = {"processing_service_name": processing_service_name}
response = session.post(url, json=payload, params=params, timeout=60)
response.raise_for_status()
logger.info(f"Successfully posted {len(results)} results to {url}")
return True
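
A quick illustration of the new helper added in client.py; the hostname is whatever socket.gethostname() returns on the machine running the worker, so the value shown below is only an example:

from trapdata.antenna.client import get_full_service_name

full_name = get_full_service_name("AMI Data Companion")
# e.g. "AMI Data Companion (gpu-03)" when the host is named gpu-03
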
11 changes: 10 additions & 1 deletion trapdata/antenna/datasets.py
@@ -108,6 +108,7 @@ def __init__(
job_id: int,
batch_size: int = 1,
image_transforms: torchvision.transforms.Compose | None = None,
processing_service_name: str = "",
):
"""
Initialize the REST dataset.
@@ -118,13 +119,15 @@ def __init__(
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
processing_service_name: Name of the processing service
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.processing_service_name = processing_service_name

# These are created lazily in _ensure_sessions() because they contain
# unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
@@ -167,7 +170,10 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}
params = {
"batch": self.batch_size,
"processing_service_name": self.processing_service_name,
}

self._ensure_sessions()
assert self._api_session is not None
@@ -374,6 +380,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
def get_rest_dataloader(
job_id: int,
settings: "Settings",
processing_service_name: str,
) -> torch.utils.data.DataLoader:
"""Create a DataLoader that fetches tasks from Antenna API.

@@ -391,12 +398,14 @@
- antenna_api_batch_size (tasks per API call)
- localization_batch_size (images per GPU batch)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
processing_service_name=processing_service_name,
)

return torch.utils.data.DataLoader(
Expand Down
13 changes: 7 additions & 6 deletions trapdata/antenna/registration.py
@@ -1,9 +1,8 @@
"""Pipeline registration with Antenna projects."""

import socket

import requests

from trapdata.antenna.client import get_full_service_name
from trapdata.antenna.schemas import (
AsyncPipelineRegistrationRequest,
AsyncPipelineRegistrationResponse,
@@ -96,13 +95,15 @@ def register_pipelines(
logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
return

if service_name is None:
logger.error("Service name is required for registration")
if not service_name or not service_name.strip():
logger.error(
"Service name is required for registration. "
"Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
)
return

# Add hostname to service name
hostname = socket.gethostname()
full_service_name = f"{service_name} ({hostname})"
full_service_name = get_full_service_name(service_name)

# Get projects to register for
projects_to_process = []
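
The registration path now rejects blank names as well as None. A sketch of the behaviour when register_pipelines is called directly, assuming the API URL and auth token are already configured:

from trapdata.antenna.registration import register_pipelines

register_pipelines(project_ids=[], service_name="   ")  # blank name: logs an error and returns early
register_pipelines(project_ids=[1, 2], service_name="AMI Data Companion")  # registers as "AMI Data Companion (<hostname>)"
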
39 changes: 29 additions & 10 deletions trapdata/antenna/tests/test_worker.py
@@ -219,7 +219,9 @@ def test_returns_job_ids(self):
antenna_api_server.setup_job(30, [])

with patch_antenna_api_requests(self.antenna_client):
result = get_jobs("http://testserver/api/v2", "test-token", "moths_2024")
result = get_jobs(
"http://testserver/api/v2", "test-token", "moths_2024", "Test Worker"
)

assert result == [10, 20, 30]

@@ -262,7 +264,10 @@ def test_empty_queue(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 100, self._make_settings()
"quebec_vermont_moths_2023",
100,
self._make_settings(),
"Test Service",
)

assert result is False
@@ -287,7 +292,10 @@ def test_processes_batch_with_real_inference(self):
# Run worker
with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 101, self._make_settings()
"quebec_vermont_moths_2023",
101,
self._make_settings(),
"Test Service",
)

# Validate processing succeeded
@@ -322,7 +330,12 @@ def test_handles_failed_items(self):
antenna_api_server.setup_job(job_id=102, tasks=tasks)

with patch_antenna_api_requests(self.antenna_client):
_process_job("quebec_vermont_moths_2023", 102, self._make_settings())
_process_job(
"quebec_vermont_moths_2023",
102,
self._make_settings(),
"Test Service",
)

posted_results = antenna_api_server.get_posted_results(102)
assert len(posted_results) == 1
@@ -354,7 +367,10 @@ def test_mixed_batch_success_and_failures(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 103, self._make_settings()
"quebec_vermont_moths_2023",
103,
self._make_settings(),
"Test Service",
)

assert result is True
@@ -444,14 +460,14 @@ def test_full_workflow_with_real_inference(self):

# Step 2: Get jobs
job_ids = get_jobs(
"http://testserver/api/v2",
"test-token",
pipeline_slug,
"http://testserver/api/v2", "test-token", pipeline_slug, "Test Worker"
)
assert 200 in job_ids

# Step 3: Process job
result = _process_job(pipeline_slug, 200, self._make_settings())
result = _process_job(
pipeline_slug, 200, self._make_settings(), "Test Worker"
)
assert result is True

# Step 4: Validate results posted
@@ -498,7 +514,10 @@ def test_multiple_batches_processed(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 201, self._make_settings()
"quebec_vermont_moths_2023",
201,
self._make_settings(),
"Test Service",
)

assert result is True
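
The new helper itself is not exercised by the tests above; a hypothetical unit test for it might look like this (the patch target assumes socket is imported at module level in trapdata.antenna.client, as this diff shows):

from unittest.mock import patch

from trapdata.antenna.client import get_full_service_name


def test_get_full_service_name_appends_hostname():
    # Pin the hostname so the assertion is deterministic
    with patch("trapdata.antenna.client.socket.gethostname", return_value="gpu-03"):
        assert get_full_service_name("AMI Data Companion") == "AMI Data Companion (gpu-03)"
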
28 changes: 26 additions & 2 deletions trapdata/antenna/worker.py
@@ -8,7 +8,7 @@
import torch.multiprocessing as mp
import torchvision

from trapdata.antenna.client import get_jobs, post_batch_results
from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
from trapdata.antenna.datasets import get_rest_dataloader
from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
from trapdata.api.api import CLASSIFIER_CHOICES
@@ -40,6 +40,17 @@ def run_worker(pipelines: list[str]):
"Get your auth token from your Antenna project settings."
)

# Validate service name
if not settings.antenna_service_name or not settings.antenna_service_name.strip():
raise ValueError(
"AMI_ANTENNA_SERVICE_NAME configuration setting must be set. "
"Configure it via environment variable or .env file."
)

# Build full service name with hostname
full_service_name = get_full_service_name(settings.antenna_service_name)
logger.info(f"Running worker as: {full_service_name}")

gpu_count = torch.cuda.device_count()

if gpu_count > 1:
@@ -75,6 +86,10 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
f"AMI worker instance {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}"
)

# Build full service name with hostname
full_service_name = get_full_service_name(settings.antenna_service_name)
logger.info(f"Running worker as: {full_service_name}")

while True:
# TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing
# These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they
@@ -86,6 +101,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
pipeline_slug=pipeline,
processing_service_name=full_service_name,
)
for job_id in jobs:
logger.info(
@@ -96,6 +112,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]):
pipeline=pipeline,
job_id=job_id,
settings=settings,
processing_service_name=full_service_name,
)
any_jobs = any_jobs or any_work_done
except Exception as e:
@@ -117,18 +134,24 @@ def _process_job(
pipeline: str,
job_id: int,
settings: Settings,
processing_service_name: str,
) -> bool:
"""Run the worker to process images from the REST API queue.

Args:
pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
job_id: Job ID to process
settings: Settings object with antenna_api_* configuration
processing_service_name: Name of the processing service
Returns:
True if any work was done, False otherwise
"""
did_work = False
loader = get_rest_dataloader(job_id=job_id, settings=settings)
loader = get_rest_dataloader(
job_id=job_id,
settings=settings,
processing_service_name=processing_service_name,
)
classifier = None
detector = None

@@ -317,6 +340,7 @@ def _process_job(
settings.antenna_api_auth_token,
job_id,
batch_results,
processing_service_name,
)
st, t = t("Finished posting results")

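
With this change the worker fails fast when no service name is configured. Roughly, assuming the API URL and auth token are already set and settings are read at the top of run_worker as before:

from trapdata.antenna.worker import run_worker

# With AMI_ANTENNA_SERVICE_NAME unset or blank, this now raises:
#   ValueError: AMI_ANTENNA_SERVICE_NAME configuration setting must be set. ...
run_worker(pipelines=["quebec_vermont_moths_2023"])
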
23 changes: 12 additions & 11 deletions trapdata/cli/worker.py
@@ -16,7 +16,8 @@ def run(
list[str] | None,
typer.Option(
"--pipeline",
help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. Defaults to all pipelines if not specified."
help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. "
"Defaults to all pipelines if not specified.",
),
] = None,
):
@@ -49,13 +50,6 @@ def run(

@cli.command("register")
def register(
name: Annotated[
str,
typer.Argument(
help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). "
"Hostname will be added automatically.",
),
],
project: Annotated[
list[int] | None,
typer.Option(
@@ -70,11 +64,18 @@
This command registers all available pipeline configurations with the Antenna platform
for the specified projects (or all accessible projects if none specified).

The service name is read from the AMI_ANTENNA_SERVICE_NAME configuration setting.
Hostname will be added automatically to the service name.

Examples:
ami worker register "AMI Data Companion on DRAC gpu-03" --project 1 --project 2
ami worker register "My Processing Service" # registers for all accessible projects
ami worker register --project 1 --project 2
ami worker register # registers for all accessible projects
"""
from trapdata.antenna.registration import register_pipelines
from trapdata.settings import read_settings

settings = read_settings()
project_ids = project if project else []
register_pipelines(project_ids=project_ids, service_name=name)
register_pipelines(
project_ids=project_ids, service_name=settings.antenna_service_name
)