1 change: 1 addition & 0 deletions .env.example
@@ -14,3 +14,4 @@ AMI_NUM_WORKERS=1
AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
AMI_ANTENNA_API_BATCH_SIZE=4
AMI_ANTENNA_SERVICE_NAME="AMI Data Companion"
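The new `AMI_ANTENNA_SERVICE_NAME` variable gives each worker a human-readable identity that Antenna can display and filter on. A minimal sketch of resolving it, assuming a plain environment lookup (the application's own Settings class may load it differently, e.g. from the .env file or Kivy settings):

```python
import os

# Assumed lookup only; the application's Settings class is the real source.
service_name = os.environ.get("AMI_ANTENNA_SERVICE_NAME", "AMI Data Companion")
```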
25 changes: 23 additions & 2 deletions trapdata/antenna/client.py
@@ -1,5 +1,7 @@
"""Antenna API client for fetching jobs and posting results."""

import socket

import requests

from trapdata.antenna.schemas import (
@@ -11,19 +13,34 @@
from trapdata.common.logs import logger


def get_full_service_name(service_name: str) -> str:
"""Build full service name with hostname.

Args:
service_name: Base service name

Returns:
Full service name with hostname appended
"""
hostname = socket.gethostname()
return f"{service_name} ({hostname})"


def get_jobs(
base_url: str,
auth_token: str,
pipeline_slug: str,
processing_service_name: str,
) -> list[int]:
"""Fetch job ids from the API for the given pipeline.

Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1
Calls: GET {base_url}/jobs?pipeline__slug=<pipeline>&ids_only=1&processing_service_name=<name>

Args:
base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
auth_token: API authentication token
pipeline_slug: Pipeline slug to filter jobs
processing_service_name: Name of the processing service

Returns:
List of job ids (possibly empty) on success or error.
@@ -35,6 +52,7 @@ def get_jobs(
"pipeline__slug": pipeline_slug,
"ids_only": 1,
"incomplete_only": 1,
"processing_service_name": processing_service_name,
"dispatch_mode": JobDispatchMode.ASYNC_API, # Only fetch async_api jobs
}

@@ -57,6 +75,7 @@ def post_batch_results(
auth_token: str,
job_id: int,
results: list[AntennaTaskResult],
processing_service_name: str,
) -> bool:
"""
Post batch results back to the API.
@@ -66,6 +85,7 @@
auth_token: API authentication token
job_id: Job ID
results: List of AntennaTaskResult objects
processing_service_name: Name of the processing service

Returns:
True if successful, False otherwise
@@ -75,7 +95,8 @@

with get_http_session(auth_token) as session:
try:
response = session.post(url, json=payload, timeout=60)
params = {"processing_service_name": processing_service_name}
response = session.post(url, json=payload, params=params, timeout=60)
response.raise_for_status()
logger.info(f"Successfully posted {len(results)} results to {url}")
return True
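Taken together, `get_full_service_name()` and the new `processing_service_name` parameters let a worker identify itself on every request. A usage sketch built from the placeholder values in .env.example (the pipeline slug is an example taken from the tests, not a real deployment):

```python
from trapdata.antenna.client import get_full_service_name, get_jobs

# Yields e.g. "AMI Data Companion (my-laptop)"; the suffix is socket.gethostname().
name = get_full_service_name("AMI Data Companion")

job_ids = get_jobs(
    base_url="http://localhost:8000/api/v2",    # placeholder from .env.example
    auth_token="your_antenna_auth_token_here",  # placeholder from .env.example
    pipeline_slug="moths_2024",                 # example slug from the tests
    processing_service_name=name,
)
```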
11 changes: 10 additions & 1 deletion trapdata/antenna/datasets.py
@@ -108,6 +108,7 @@ def __init__(
job_id: int,
batch_size: int = 1,
image_transforms: torchvision.transforms.Compose | None = None,
processing_service_name: str = "",
):
"""
Initialize the REST dataset.
@@ -118,13 +119,15 @@ def __init__(
job_id: The job ID to fetch tasks for
batch_size: Number of tasks to request per batch
image_transforms: Optional transforms to apply to loaded images
processing_service_name: Name of the processing service
"""
super().__init__()
self.base_url = base_url
self.auth_token = auth_token
self.job_id = job_id
self.batch_size = batch_size
self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
self.processing_service_name = processing_service_name

# These are created lazily in _ensure_sessions() because they contain
# unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
@@ -167,7 +170,10 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
requests.RequestException: If the request fails (network error, etc.)
"""
url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
params = {"batch": self.batch_size}
params = {
"batch": self.batch_size,
"processing_service_name": self.processing_service_name,
}

self._ensure_sessions()
assert self._api_session is not None
@@ -374,6 +380,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
def get_rest_dataloader(
job_id: int,
settings: "Settings",
processing_service_name: str,
) -> torch.utils.data.DataLoader:
"""Create a DataLoader that fetches tasks from Antenna API.

@@ -391,12 +398,14 @@
- antenna_api_batch_size (tasks per API call)
- localization_batch_size (images per GPU batch)
- num_workers (DataLoader subprocesses)
- processing_service_name (name of this worker)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
auth_token=settings.antenna_api_auth_token,
job_id=job_id,
batch_size=settings.antenna_api_batch_size,
processing_service_name=processing_service_name,
)

return torch.utils.data.DataLoader(
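Because `RESTDataset` instances run inside DataLoader subprocesses, the service name is passed down as plain constructor state rather than read from global settings. A sketch of the call site, assuming an already-populated `settings` object (its construction is outside this diff) that provides the fields documented above:

```python
from trapdata.antenna.client import get_full_service_name
from trapdata.antenna.datasets import get_rest_dataloader

# `settings` is assumed to carry antenna_api_base_url, antenna_api_auth_token,
# antenna_api_batch_size, and num_workers, as described in the docstring above.
dataloader = get_rest_dataloader(
    job_id=101,
    settings=settings,
    processing_service_name=get_full_service_name("AMI Data Companion"),
)

for batch in dataloader:
    ...  # run inference on each collated batch of tasks
```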
13 changes: 7 additions & 6 deletions trapdata/antenna/registration.py
@@ -1,9 +1,8 @@
"""Pipeline registration with Antenna projects."""

import socket

import requests

from trapdata.antenna.client import get_full_service_name
from trapdata.antenna.schemas import (
AsyncPipelineRegistrationRequest,
AsyncPipelineRegistrationResponse,
@@ -96,13 +95,15 @@ def register_pipelines(
logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
return

if service_name is None:
logger.error("Service name is required for registration")
if not service_name or not service_name.strip():
logger.error(
"Service name is required for registration. "
"Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
)
return

# Add hostname to service name
hostname = socket.gethostname()
full_service_name = f"{service_name} ({hostname})"
full_service_name = get_full_service_name(service_name)

# Get projects to register for
projects_to_process = []
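The tightened guard rejects whitespace-only names as well as `None`. A hypothetical helper illustrating the exact truthiness check used above:

```python
def _is_valid_service_name(service_name: str | None) -> bool:
    """Hypothetical mirror of the guard in register_pipelines()."""
    return bool(service_name and service_name.strip())

assert not _is_valid_service_name(None)
assert not _is_valid_service_name("")
assert not _is_valid_service_name("   ")  # whitespace-only is now rejected too
assert _is_valid_service_name("AMI Data Companion")
```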
24 changes: 23 additions & 1 deletion trapdata/antenna/tests/antenna_api_server.py
@@ -24,20 +24,30 @@
_posted_results: dict[int, list[AntennaTaskResult]] = {}
_projects: list[dict] = []
_registered_pipelines: dict[int, list[str]] = {} # project_id -> pipeline slugs
_last_get_jobs_service_name: str = ""


@app.get("/api/v2/jobs")
def get_jobs(pipeline__slug: str, ids_only: int, incomplete_only: int):
def get_jobs(
pipeline__slug: str,
ids_only: int,
incomplete_only: int,
processing_service_name: str = "",
):
"""Return available job IDs.

Args:
pipeline__slug: Pipeline slug filter
ids_only: If 1, return only job IDs
incomplete_only: If 1, return only incomplete jobs
processing_service_name: Name of the processing service making the request

Returns:
AntennaJobsListResponse with list of job IDs
"""
global _last_get_jobs_service_name
_last_get_jobs_service_name = processing_service_name

# Return all jobs in queue (for testing, we return all registered jobs)
job_ids = list(_jobs_queue.keys())
results = [AntennaJobListItem(id=job_id) for job_id in job_ids]
@@ -180,9 +190,21 @@ def get_registered_pipelines(project_id: int) -> list[str]:
return _registered_pipelines.get(project_id, [])


def get_last_get_jobs_service_name() -> str:
"""Return the processing_service_name received by the last get_jobs call.

Returns:
The processing_service_name value from the most recent GET /jobs request,
or an empty string if no request has been made since the last reset().
"""
return _last_get_jobs_service_name


def reset():
"""Clear all state between tests."""
global _last_get_jobs_service_name
_jobs_queue.clear()
_posted_results.clear()
_projects.clear()
_registered_pipelines.clear()
_last_get_jobs_service_name = ""
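The fake server records the most recent `processing_service_name` in a module-level global so tests can assert what the client actually sent. A sketch of the resulting test pattern (the patched-client plumbing is shown in test_worker.py below):

```python
from trapdata.antenna.tests import antenna_api_server

antenna_api_server.reset()            # clear jobs, results, and the recorded name
antenna_api_server.setup_job(10, [])  # register an empty job

# ... call get_jobs(...) through the patched test client here ...

assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker"
```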
41 changes: 31 additions & 10 deletions trapdata/antenna/tests/test_worker.py
@@ -219,9 +219,12 @@ def test_returns_job_ids(self):
antenna_api_server.setup_job(30, [])

with patch_antenna_api_requests(self.antenna_client):
result = get_jobs("http://testserver/api/v2", "test-token", "moths_2024")
result = get_jobs(
"http://testserver/api/v2", "test-token", "moths_2024", "Test Worker"
)

assert result == [10, 20, 30]
assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker"


# ---------------------------------------------------------------------------
@@ -262,7 +265,10 @@ def test_empty_queue(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 100, self._make_settings()
"quebec_vermont_moths_2023",
100,
self._make_settings(),
"Test Service",
)

assert result is False
@@ -287,7 +293,10 @@ def test_processes_batch_with_real_inference(self):
# Run worker
with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 101, self._make_settings()
"quebec_vermont_moths_2023",
101,
self._make_settings(),
"Test Service",
)

# Validate processing succeeded
@@ -322,7 +331,12 @@ def test_handles_failed_items(self):
antenna_api_server.setup_job(job_id=102, tasks=tasks)

with patch_antenna_api_requests(self.antenna_client):
_process_job("quebec_vermont_moths_2023", 102, self._make_settings())
_process_job(
"quebec_vermont_moths_2023",
102,
self._make_settings(),
"Test Service",
)

posted_results = antenna_api_server.get_posted_results(102)
assert len(posted_results) == 1
@@ -354,7 +368,10 @@ def test_mixed_batch_success_and_failures(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 103, self._make_settings()
"quebec_vermont_moths_2023",
103,
self._make_settings(),
"Test Service",
)

assert result is True
@@ -444,14 +461,15 @@ def test_full_workflow_with_real_inference(self):

# Step 2: Get jobs
job_ids = get_jobs(
"http://testserver/api/v2",
"test-token",
pipeline_slug,
"http://testserver/api/v2", "test-token", pipeline_slug, "Test Worker"
)
assert 200 in job_ids
assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker"

# Step 3: Process job
result = _process_job(pipeline_slug, 200, self._make_settings())
result = _process_job(
pipeline_slug, 200, self._make_settings(), "Test Worker"
)
assert result is True

# Step 4: Validate results posted
@@ -498,7 +516,10 @@ def test_multiple_batches_processed(self):

with patch_antenna_api_requests(self.antenna_client):
result = _process_job(
"quebec_vermont_moths_2023", 201, self._make_settings()
"quebec_vermont_moths_2023",
201,
self._make_settings(),
"Test Service",
)

assert result is True