From a6948de27b38aaba39c8db4dccb748d28717d242 Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Wed, 4 Feb 2026 14:54:34 -0800
Subject: [PATCH 1/5] Send processing_service_name in requests to antenna

---
 .env.example                     |  1 +
 trapdata/antenna/client.py       | 25 +++++++++++++++++++++++--
 trapdata/antenna/datasets.py     | 11 ++++++++++-
 trapdata/antenna/registration.py |  6 ++----
 trapdata/antenna/worker.py       | 17 +++++++++++++++--
 trapdata/cli/worker.py           | 23 ++++++++++++-----------
 trapdata/settings.py             |  7 +++++++
 7 files changed, 70 insertions(+), 20 deletions(-)

diff --git a/.env.example b/.env.example
index 2a9178db..4d1a0848 100644
--- a/.env.example
+++ b/.env.example
@@ -14,3 +14,4 @@ AMI_NUM_WORKERS=1
 AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
 AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
 AMI_ANTENNA_API_BATCH_SIZE=4
+AMI_ANTENNA_SERVICE_NAME=AMI Data Companion
diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py
index 3e500310..1ac2cf8f 100644
--- a/trapdata/antenna/client.py
+++ b/trapdata/antenna/client.py
@@ -1,5 +1,7 @@
 """Antenna API client for fetching jobs and posting results."""
 
+import socket
+
 import requests
 
 from trapdata.antenna.schemas import AntennaJobsListResponse, AntennaTaskResult
@@ -7,19 +9,34 @@
 from trapdata.common.logs import logger
 
 
+def get_full_service_name(service_name: str) -> str:
+    """Build full service name with hostname.
+
+    Args:
+        service_name: Base service name
+
+    Returns:
+        Full service name with hostname appended
+    """
+    hostname = socket.gethostname()
+    return f"{service_name} ({hostname})"
+
+
 def get_jobs(
     base_url: str,
     auth_token: str,
     pipeline_slug: str,
+    processing_service_name: str,
 ) -> list[int]:
     """Fetch job ids from the API for the given pipeline.
 
-    Calls: GET {base_url}/jobs?pipeline__slug=<pipeline_slug>&ids_only=1
+    Calls: GET {base_url}/jobs?pipeline__slug=<pipeline_slug>&ids_only=1&processing_service_name=<processing_service_name>
 
     Args:
        base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
        auth_token: API authentication token
        pipeline_slug: Pipeline slug to filter jobs
+       processing_service_name: Name of the processing service
 
    Returns:
        List of job ids (possibly empty) on success or error.
@@ -31,6 +48,7 @@
         "pipeline__slug": pipeline_slug,
         "ids_only": 1,
         "incomplete_only": 1,
+        "processing_service_name": processing_service_name,
     }
 
     resp = session.get(url, params=params, timeout=30)
@@ -52,6 +70,7 @@ def post_batch_results(
     auth_token: str,
     job_id: int,
     results: list[AntennaTaskResult],
+    processing_service_name: str,
 ) -> bool:
     """
     Post batch results back to the API.
@@ -61,6 +80,7 @@ def post_batch_results(
         auth_token: API authentication token
         job_id: Job ID
         results: List of AntennaTaskResult objects
+        processing_service_name: Name of the processing service
 
     Returns:
         True if successful, False otherwise
@@ -70,7 +90,8 @@
 
     with get_http_session(auth_token) as session:
         try:
-            response = session.post(url, json=payload, timeout=60)
+            params = {"processing_service_name": processing_service_name}
+            response = session.post(url, json=payload, params=params, timeout=60)
             response.raise_for_status()
             logger.info(f"Successfully posted {len(results)} results to {url}")
             return True
diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py
index faf56b8f..97f06ff6 100644
--- a/trapdata/antenna/datasets.py
+++ b/trapdata/antenna/datasets.py
@@ -45,6 +45,7 @@ def __init__(
         job_id: int,
         batch_size: int = 1,
         image_transforms: torchvision.transforms.Compose | None = None,
+        processing_service_name: str = "",
     ):
         """
         Initialize the REST dataset.
@@ -55,12 +56,14 @@
             job_id: The job ID to fetch tasks for
             batch_size: Number of tasks to request per batch
             image_transforms: Optional transforms to apply to loaded images
+            processing_service_name: Name of the processing service
         """
         super().__init__()
         self.base_url = base_url
         self.job_id = job_id
         self.batch_size = batch_size
         self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
+        self.processing_service_name = processing_service_name
 
         # Create persistent sessions for connection pooling
         self.api_session = get_http_session(auth_token)
@@ -84,7 +87,10 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
             requests.RequestException: If the request fails (network error, etc.)
         """
         url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
-        params = {"batch": self.batch_size}
+        params = {
+            "batch": self.batch_size,
+            "processing_service_name": self.processing_service_name,
+        }
 
         response = self.api_session.get(url, params=params, timeout=30)
         response.raise_for_status()
@@ -251,6 +257,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
 def get_rest_dataloader(
     job_id: int,
     settings: "Settings",
+    processing_service_name: str,
 ) -> torch.utils.data.DataLoader:
     """
     Create a DataLoader that fetches tasks from Antenna API.
@@ -264,12 +271,14 @@ def get_rest_dataloader(
     Args:
         job_id: Job ID to fetch tasks for
         settings: Settings object with antenna_api_* configuration
+        processing_service_name: Name of the processing service
     """
     dataset = RESTDataset(
         base_url=settings.antenna_api_base_url,
         auth_token=settings.antenna_api_auth_token,
         job_id=job_id,
         batch_size=settings.antenna_api_batch_size,
+        processing_service_name=processing_service_name,
     )
 
     return torch.utils.data.DataLoader(
diff --git a/trapdata/antenna/registration.py b/trapdata/antenna/registration.py
index a78a513f..ae9ada2b 100644
--- a/trapdata/antenna/registration.py
+++ b/trapdata/antenna/registration.py
@@ -1,9 +1,8 @@
 """Pipeline registration with Antenna projects."""
 
-import socket
-
 import requests
 
+from trapdata.antenna.client import get_full_service_name
 from trapdata.antenna.schemas import (
     AsyncPipelineRegistrationRequest,
     AsyncPipelineRegistrationResponse,
@@ -101,8 +100,7 @@ def register_pipelines(
         return
 
     # Add hostname to service name
-    hostname = socket.gethostname()
-    full_service_name = f"{service_name} ({hostname})"
+    full_service_name = get_full_service_name(service_name)
 
     # Get projects to register for
     projects_to_process = []
diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py
index 2fbf3b54..9382d4b7 100644
--- a/trapdata/antenna/worker.py
+++ b/trapdata/antenna/worker.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
 
-from trapdata.antenna.client import get_jobs, post_batch_results
+from trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results
 from trapdata.antenna.datasets import get_rest_dataloader
 from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError
 from trapdata.api.api import CLASSIFIER_CHOICES
@@ -34,6 +34,10 @@ def run_worker(pipelines: list[str]):
             "Get your auth token from your Antenna project settings."
         )
 
+    # Build full service name with hostname
+    full_service_name = get_full_service_name(settings.antenna_service_name)
+    logger.info(f"Running worker as: {full_service_name}")
+
     while True:
         # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing
         # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they
@@ -45,6 +49,7 @@
                 base_url=settings.antenna_api_base_url,
                 auth_token=settings.antenna_api_auth_token,
                 pipeline_slug=pipeline,
+                processing_service_name=full_service_name,
             )
             for job_id in jobs:
                 logger.info(f"Processing job {job_id} with pipeline {pipeline}")
@@ -53,6 +58,7 @@
                     pipeline=pipeline,
                     job_id=job_id,
                     settings=settings,
+                    processing_service_name=full_service_name,
                 )
                 any_jobs = any_jobs or any_work_done
             except Exception as e:
@@ -72,6 +78,7 @@ def _process_job(
     pipeline: str,
     job_id: int,
     settings: Settings,
+    processing_service_name: str,
 ) -> bool:
     """Run the worker to process images from the REST API queue.
@@ -79,11 +86,16 @@
         pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024)
         job_id: Job ID to process
         settings: Settings object with antenna_api_* configuration
+        processing_service_name: Name of the processing service
     Returns:
         True if any work was done, False otherwise
     """
     did_work = False
-    loader = get_rest_dataloader(job_id=job_id, settings=settings)
+    loader = get_rest_dataloader(
+        job_id=job_id,
+        settings=settings,
+        processing_service_name=processing_service_name,
+    )
 
     classifier = None
     detector = None
@@ -232,6 +244,7 @@
             settings.antenna_api_auth_token,
             job_id,
             batch_results,
+            processing_service_name,
         )
         st, t = t("Finished posting results")
diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py
index 19fb97aa..f1b5782e 100644
--- a/trapdata/cli/worker.py
+++ b/trapdata/cli/worker.py
@@ -16,7 +16,8 @@ def run(
         list[str] | None,
         typer.Option(
             "--pipeline",
-            help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. Defaults to all pipelines if not specified."
+            help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. "
+            "Defaults to all pipelines if not specified.",
         ),
     ] = None,
 ):
@@ -49,13 +50,6 @@ def run(
 
 @cli.command("register")
 def register(
-    name: Annotated[
-        str,
-        typer.Argument(
-            help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). "
-            "Hostname will be added automatically.",
-        ),
-    ],
     project: Annotated[
         list[int] | None,
         typer.Option(
@@ -70,11 +64,18 @@ def register(
     This command registers all available pipeline configurations with the Antenna
     platform for the specified projects (or all accessible projects if none specified).
 
+    The service name is read from the AMI_ANTENNA_SERVICE_NAME configuration setting.
+    Hostname will be added automatically to the service name.
+
     Examples:
-        ami worker register "AMI Data Companion on DRAC gpu-03" --project 1 --project 2
-        ami worker register "My Processing Service"  # registers for all accessible projects
+        ami worker register --project 1 --project 2
+        ami worker register  # registers for all accessible projects
     """
     from trapdata.antenna.registration import register_pipelines
+    from trapdata.settings import read_settings
 
+    settings = read_settings()
     project_ids = project if project else []
-    register_pipelines(project_ids=project_ids, service_name=name)
+    register_pipelines(
+        project_ids=project_ids, service_name=settings.antenna_service_name
+    )
diff --git a/trapdata/settings.py b/trapdata/settings.py
index f4b83f16..b02168ba 100644
--- a/trapdata/settings.py
+++ b/trapdata/settings.py
@@ -41,6 +41,7 @@ class Settings(BaseSettings):
     antenna_api_base_url: str = "http://localhost:8000/api/v2"
     antenna_api_auth_token: str = ""
     antenna_api_batch_size: int = 4
+    antenna_service_name: str = "AMI Data Companion"
 
     @pydantic.field_validator("image_base_path", "user_data_path")
     def validate_path(cls, v):
@@ -166,6 +167,12 @@ class Config:
             "kivy_type": "numeric",
             "kivy_section": "antenna",
         },
+        "antenna_service_name": {
+            "title": "Antenna Service Name",
+            "description": "Name for the processing service registration (hostname will be added automatically)",
+            "kivy_type": "string",
+            "kivy_section": "antenna",
+        },
     }
 
     @classmethod

From 2fba0a53eaa2fafe1f6107d106991d4682570404 Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Wed, 4 Feb 2026 17:26:59 -0800
Subject: [PATCH 2/5] More validation

---
 .env.example                     | 2 +-
 trapdata/antenna/registration.py | 7 +++++--
 trapdata/antenna/worker.py       | 7 +++++++
 3 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/.env.example b/.env.example
index 4d1a0848..12a099fa 100644
--- a/.env.example
+++ b/.env.example
@@ -14,4 +14,4 @@ AMI_NUM_WORKERS=1
 AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
 AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
 AMI_ANTENNA_API_BATCH_SIZE=4
-AMI_ANTENNA_SERVICE_NAME=AMI Data Companion
+AMI_ANTENNA_SERVICE_NAME="AMI Data Companion"
diff --git a/trapdata/antenna/registration.py b/trapdata/antenna/registration.py
index ae9ada2b..4b41e319 100644
--- a/trapdata/antenna/registration.py
+++ b/trapdata/antenna/registration.py
@@ -95,8 +95,11 @@ def register_pipelines(
         logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
         return
 
-    if service_name is None:
-        logger.error("Service name is required for registration")
+    if not service_name or not service_name.strip():
+        logger.error(
+            "Service name is required for registration. "
+            "Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
+        )
         return
 
     # Add hostname to service name
diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py
index 9382d4b7..ddd2c98a 100644
--- a/trapdata/antenna/worker.py
+++ b/trapdata/antenna/worker.py
@@ -34,6 +34,13 @@ def run_worker(pipelines: list[str]):
             "Get your auth token from your Antenna project settings."
         )
 
+    # Validate service name
+    if not settings.antenna_service_name or not settings.antenna_service_name.strip():
+        raise ValueError(
+            "AMI_ANTENNA_SERVICE_NAME configuration setting must be set. "
+            "Configure it via environment variable or .env file."
+        )
+
     # Build full service name with hostname
     full_service_name = get_full_service_name(settings.antenna_service_name)
     logger.info(f"Running worker as: {full_service_name}")

From 008155b78ab3454bced72f154f4e1381656fa77f Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Wed, 4 Feb 2026 17:36:10 -0800
Subject: [PATCH 3/5] Update tests

---
 trapdata/antenna/tests/test_worker.py | 39 ++++++++++++++++++++-------
 1 file changed, 29 insertions(+), 10 deletions(-)

diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py
index 4a83958a..1e9913b6 100644
--- a/trapdata/antenna/tests/test_worker.py
+++ b/trapdata/antenna/tests/test_worker.py
@@ -194,7 +194,9 @@ def test_returns_job_ids(self):
         antenna_api_server.setup_job(30, [])
 
         with patch_antenna_api_requests(self.antenna_client):
-            result = get_jobs("http://testserver/api/v2", "test-token", "moths_2024")
+            result = get_jobs(
+                "http://testserver/api/v2", "test-token", "moths_2024", "Test Worker"
+            )
 
         assert result == [10, 20, 30]
 
@@ -237,7 +239,10 @@ def test_empty_queue(self):
 
         with patch_antenna_api_requests(self.antenna_client):
             result = _process_job(
-                "quebec_vermont_moths_2023", 100, self._make_settings()
+                "quebec_vermont_moths_2023",
+                100,
+                self._make_settings(),
+                "Test Service",
             )
 
         assert result is False
@@ -262,7 +267,10 @@ def test_processes_batch_with_real_inference(self):
         # Run worker
         with patch_antenna_api_requests(self.antenna_client):
             result = _process_job(
-                "quebec_vermont_moths_2023", 101, self._make_settings()
+                "quebec_vermont_moths_2023",
+                101,
+                self._make_settings(),
+                "Test Service",
             )
 
         # Validate processing succeeded
@@ -297,7 +305,12 @@ def test_handles_failed_items(self):
         antenna_api_server.setup_job(job_id=102, tasks=tasks)
 
         with patch_antenna_api_requests(self.antenna_client):
-            _process_job("quebec_vermont_moths_2023", 102, self._make_settings())
+            _process_job(
+                "quebec_vermont_moths_2023",
+                102,
+                self._make_settings(),
+                "Test Service",
+            )
 
         posted_results = antenna_api_server.get_posted_results(102)
         assert len(posted_results) == 1
@@ -329,7 +342,10 @@ def test_mixed_batch_success_and_failures(self):
 
         with patch_antenna_api_requests(self.antenna_client):
             result = _process_job(
-                "quebec_vermont_moths_2023", 103, self._make_settings()
+                "quebec_vermont_moths_2023",
+                103,
+                self._make_settings(),
+                "Test Service",
             )
 
         assert result is True
@@ -419,14 +435,14 @@ def test_full_workflow_with_real_inference(self):
 
         # Step 2: Get jobs
         job_ids = get_jobs(
-            "http://testserver/api/v2",
-            "test-token",
-            pipeline_slug,
+            "http://testserver/api/v2", "test-token", pipeline_slug, "Test Worker"
         )
         assert 200 in job_ids
 
         # Step 3: Process job
-        result = _process_job(pipeline_slug, 200, self._make_settings())
+        result = _process_job(
+            pipeline_slug, 200, self._make_settings(), "Test Worker"
+        )
         assert result is True
 
         # Step 4: Validate results posted
@@ -473,7 +489,10 @@ def test_multiple_batches_processed(self):
 
         with patch_antenna_api_requests(self.antenna_client):
             result = _process_job(
-                "quebec_vermont_moths_2023", 201, self._make_settings()
+                "quebec_vermont_moths_2023",
+                201,
+                self._make_settings(),
+                "Test Service",
             )
 
         assert result is True

From 9181b47d4efa1e27f033ab31736e4e44cfe63297 Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Wed, 18 Feb 2026 09:22:17 -0800
Subject: [PATCH 4/5] fix merge

---
 trapdata/antenna/worker.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py
index b1d960fc..5e95f2e1 100644
--- a/trapdata/antenna/worker.py
+++ b/trapdata/antenna/worker.py
@@ -47,12 +47,7 @@ def run_worker(pipelines: list[str]):
             "Configure it via environment variable or .env file."
         )
 
-    # Build full service name with hostname
-    full_service_name = get_full_service_name(settings.antenna_service_name)
-    logger.info(f"Running worker as: {full_service_name}")
-
     gpu_count = torch.cuda.device_count()
-
     if gpu_count > 1:
         logger.info(f"Found {gpu_count} GPUs, spawning one AMI worker instance per GPU")
         # Don't pass settings through mp.spawn — Settings contains enums that

From 684519d0eb3631e5f2c96947ed983f49b2d9a641 Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Wed, 18 Feb 2026 10:23:35 -0800
Subject: [PATCH 5/5] Update tests from CR feedback

---
 trapdata/antenna/tests/antenna_api_server.py | 24 +++++++++++++++++++-
 trapdata/antenna/tests/test_worker.py        |  2 ++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/trapdata/antenna/tests/antenna_api_server.py b/trapdata/antenna/tests/antenna_api_server.py
index 18eafcd8..97a389fd 100644
--- a/trapdata/antenna/tests/antenna_api_server.py
+++ b/trapdata/antenna/tests/antenna_api_server.py
@@ -24,20 +24,30 @@
 _posted_results: dict[int, list[AntennaTaskResult]] = {}
 _projects: list[dict] = []
 _registered_pipelines: dict[int, list[str]] = {}  # project_id -> pipeline slugs
+_last_get_jobs_service_name: str = ""
 
 
 @app.get("/api/v2/jobs")
-def get_jobs(pipeline__slug: str, ids_only: int, incomplete_only: int):
+def get_jobs(
+    pipeline__slug: str,
+    ids_only: int,
+    incomplete_only: int,
+    processing_service_name: str = "",
+):
     """Return available job IDs.
 
     Args:
         pipeline__slug: Pipeline slug filter
         ids_only: If 1, return only job IDs
         incomplete_only: If 1, return only incomplete jobs
+        processing_service_name: Name of the processing service making the request
 
     Returns:
         AntennaJobsListResponse with list of job IDs
     """
+    global _last_get_jobs_service_name
+    _last_get_jobs_service_name = processing_service_name
+
     # Return all jobs in queue (for testing, we return all registered jobs)
     job_ids = list(_jobs_queue.keys())
     results = [AntennaJobListItem(id=job_id) for job_id in job_ids]
@@ -180,9 +190,21 @@ def get_registered_pipelines(project_id: int) -> list[str]:
     return _registered_pipelines.get(project_id, [])
 
 
+def get_last_get_jobs_service_name() -> str:
+    """Return the processing_service_name received by the last get_jobs call.
+
+    Returns:
+        The processing_service_name value from the most recent GET /jobs request,
+        or an empty string if no request has been made since the last reset().
+ """ + return _last_get_jobs_service_name + + def reset(): """Clear all state between tests.""" + global _last_get_jobs_service_name _jobs_queue.clear() _posted_results.clear() _projects.clear() _registered_pipelines.clear() + _last_get_jobs_service_name = "" diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index 5232abab..feeb6bff 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -224,6 +224,7 @@ def test_returns_job_ids(self): ) assert result == [10, 20, 30] + assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker" # --------------------------------------------------------------------------- @@ -463,6 +464,7 @@ def test_full_workflow_with_real_inference(self): "http://testserver/api/v2", "test-token", pipeline_slug, "Test Worker" ) assert 200 in job_ids + assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker" # Step 3: Process job result = _process_job(