diff --git a/.env.example b/.env.example
index 2a9178d..12a099f 100644
--- a/.env.example
+++ b/.env.example
@@ -14,3 +14,4 @@ AMI_NUM_WORKERS=1
 AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
 AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
 AMI_ANTENNA_API_BATCH_SIZE=4
+AMI_ANTENNA_SERVICE_NAME="AMI Data Companion"
diff --git a/trapdata/antenna/client.py b/trapdata/antenna/client.py
index 9ff7f29..bc97367 100644
--- a/trapdata/antenna/client.py
+++ b/trapdata/antenna/client.py
@@ -1,5 +1,7 @@
 """Antenna API client for fetching jobs and posting results."""
 
+import socket
+
 import requests
 
 from trapdata.antenna.schemas import (
@@ -11,19 +13,34 @@
 from trapdata.common.logs import logger
 
 
+def get_full_service_name(service_name: str) -> str:
+    """Build full service name with hostname.
+
+    Args:
+        service_name: Base service name
+
+    Returns:
+        Full service name with hostname appended
+    """
+    hostname = socket.gethostname()
+    return f"{service_name} ({hostname})"
+
+
 def get_jobs(
     base_url: str,
     auth_token: str,
     pipeline_slug: str,
+    processing_service_name: str,
 ) -> list[int]:
     """Fetch job ids from the API for the given pipeline.
 
-    Calls: GET {base_url}/jobs?pipeline__slug=&ids_only=1
+    Calls: GET {base_url}/jobs?pipeline__slug=&ids_only=1&processing_service_name=
 
     Args:
         base_url: Antenna API base URL (e.g., "http://localhost:8000/api/v2")
         auth_token: API authentication token
         pipeline_slug: Pipeline slug to filter jobs
+        processing_service_name: Name of the processing service
 
     Returns:
         List of job ids (possibly empty) on success or error.
@@ -35,6 +52,7 @@ def get_jobs(
         "pipeline__slug": pipeline_slug,
         "ids_only": 1,
         "incomplete_only": 1,
+        "processing_service_name": processing_service_name,
         "dispatch_mode": JobDispatchMode.ASYNC_API,  # Only fetch async_api jobs
     }
 
@@ -57,6 +75,7 @@ def post_batch_results(
     auth_token: str,
     job_id: int,
     results: list[AntennaTaskResult],
+    processing_service_name: str,
 ) -> bool:
     """
     Post batch results back to the API.
@@ -66,6 +85,7 @@ def post_batch_results(
         auth_token: API authentication token
         job_id: Job ID
         results: List of AntennaTaskResult objects
+        processing_service_name: Name of the processing service
 
     Returns:
         True if successful, False otherwise
@@ -75,7 +95,8 @@ def post_batch_results(
 
     with get_http_session(auth_token) as session:
         try:
-            response = session.post(url, json=payload, timeout=60)
+            params = {"processing_service_name": processing_service_name}
+            response = session.post(url, json=payload, params=params, timeout=60)
             response.raise_for_status()
             logger.info(f"Successfully posted {len(results)} results to {url}")
             return True
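Note: a quick sketch of how the two client pieces above compose (illustrative only: the hostname suffix depends on the machine, and the token and slug values are placeholders taken from examples elsewhere in this patch):

    from trapdata.antenna.client import get_full_service_name, get_jobs

    service_name = get_full_service_name("AMI Data Companion")
    # e.g. "AMI Data Companion (gpu-node-03)", depending on socket.gethostname()

    job_ids = get_jobs(
        base_url="http://localhost:8000/api/v2",
        auth_token="your_antenna_auth_token_here",
        pipeline_slug="moths_2024",
        processing_service_name=service_name,
    )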
diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py
index 25b33ba..bb041f5 100644
--- a/trapdata/antenna/datasets.py
+++ b/trapdata/antenna/datasets.py
@@ -108,6 +108,7 @@ def __init__(
         job_id: int,
         batch_size: int = 1,
         image_transforms: torchvision.transforms.Compose | None = None,
+        processing_service_name: str = "",
     ):
         """
         Initialize the REST dataset.
@@ -118,6 +119,7 @@ def __init__(
             job_id: The job ID to fetch tasks for
             batch_size: Number of tasks to request per batch
             image_transforms: Optional transforms to apply to loaded images
+            processing_service_name: Name of the processing service
         """
         super().__init__()
         self.base_url = base_url
@@ -125,6 +127,7 @@ def __init__(
         self.job_id = job_id
         self.batch_size = batch_size
         self.image_transforms = image_transforms or torchvision.transforms.ToTensor()
+        self.processing_service_name = processing_service_name
 
         # These are created lazily in _ensure_sessions() because they contain
         # unpicklable objects (ThreadPoolExecutor has a SimpleQueue) and
@@ -167,7 +170,10 @@ def _fetch_tasks(self) -> list[AntennaPipelineProcessingTask]:
             requests.RequestException: If the request fails (network error, etc.)
         """
         url = f"{self.base_url.rstrip('/')}/jobs/{self.job_id}/tasks"
-        params = {"batch": self.batch_size}
+        params = {
+            "batch": self.batch_size,
+            "processing_service_name": self.processing_service_name,
+        }
 
         self._ensure_sessions()
         assert self._api_session is not None
@@ -374,6 +380,7 @@ def rest_collate_fn(batch: list[dict]) -> dict:
 def get_rest_dataloader(
     job_id: int,
     settings: "Settings",
+    processing_service_name: str,
 ) -> torch.utils.data.DataLoader:
     """Create a DataLoader that fetches tasks from Antenna API.
 
@@ -391,12 +398,14 @@ def get_rest_dataloader(
         - antenna_api_batch_size (tasks per API call)
         - localization_batch_size (images per GPU batch)
         - num_workers (DataLoader subprocesses)
+        - processing_service_name (name of this worker, passed as an argument)
     """
     dataset = RESTDataset(
         base_url=settings.antenna_api_base_url,
         auth_token=settings.antenna_api_auth_token,
         job_id=job_id,
         batch_size=settings.antenna_api_batch_size,
+        processing_service_name=processing_service_name,
     )
 
     return torch.utils.data.DataLoader(
diff --git a/trapdata/antenna/registration.py b/trapdata/antenna/registration.py
index a78a513..4b41e31 100644
--- a/trapdata/antenna/registration.py
+++ b/trapdata/antenna/registration.py
@@ -1,9 +1,8 @@
 """Pipeline registration with Antenna projects."""
 
-import socket
-
 import requests
 
+from trapdata.antenna.client import get_full_service_name
 from trapdata.antenna.schemas import (
     AsyncPipelineRegistrationRequest,
     AsyncPipelineRegistrationResponse,
@@ -96,13 +95,15 @@ def register_pipelines(
         logger.error("AMI_ANTENNA_API_AUTH_TOKEN environment variable not set")
         return
 
-    if service_name is None:
-        logger.error("Service name is required for registration")
+    if not service_name or not service_name.strip():
+        logger.error(
+            "Service name is required for registration. "
+            "Configure AMI_ANTENNA_SERVICE_NAME via environment variable, .env file, or Kivy settings."
+        )
         return
 
     # Add hostname to service name
-    hostname = socket.gethostname()
-    full_service_name = f"{service_name} ({hostname})"
+    full_service_name = get_full_service_name(service_name)
 
     # Get projects to register for
     projects_to_process = []
diff --git a/trapdata/antenna/tests/antenna_api_server.py b/trapdata/antenna/tests/antenna_api_server.py
index 18eafcd..97a389f 100644
--- a/trapdata/antenna/tests/antenna_api_server.py
+++ b/trapdata/antenna/tests/antenna_api_server.py
@@ -24,20 +24,30 @@
 _posted_results: dict[int, list[AntennaTaskResult]] = {}
 _projects: list[dict] = []
 _registered_pipelines: dict[int, list[str]] = {}  # project_id -> pipeline slugs
+_last_get_jobs_service_name: str = ""
 
 
 @app.get("/api/v2/jobs")
-def get_jobs(pipeline__slug: str, ids_only: int, incomplete_only: int):
+def get_jobs(
+    pipeline__slug: str,
+    ids_only: int,
+    incomplete_only: int,
+    processing_service_name: str = "",
+):
     """Return available job IDs.
 
     Args:
         pipeline__slug: Pipeline slug filter
         ids_only: If 1, return only job IDs
         incomplete_only: If 1, return only incomplete jobs
+        processing_service_name: Name of the processing service making the request
 
     Returns:
         AntennaJobsListResponse with list of job IDs
     """
+    global _last_get_jobs_service_name
+    _last_get_jobs_service_name = processing_service_name
+
     # Return all jobs in queue (for testing, we return all registered jobs)
     job_ids = list(_jobs_queue.keys())
     results = [AntennaJobListItem(id=job_id) for job_id in job_ids]
@@ -180,9 +190,21 @@ def get_registered_pipelines(project_id: int) -> list[str]:
     return _registered_pipelines.get(project_id, [])
 
 
+def get_last_get_jobs_service_name() -> str:
+    """Return the processing_service_name received by the last get_jobs call.
+
+    Returns:
+        The processing_service_name value from the most recent GET /jobs request,
+        or an empty string if no request has been made since the last reset().
+    """
+    return _last_get_jobs_service_name
+
+
 def reset():
     """Clear all state between tests."""
+    global _last_get_jobs_service_name
     _jobs_queue.clear()
     _posted_results.clear()
     _projects.clear()
     _registered_pipelines.clear()
+    _last_get_jobs_service_name = ""
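Note: with the tracking added above, tests can assert on the exact service name the server received. A minimal sketch, mirroring the pattern used in test_worker.py below (`client` stands in for a configured test client):

    antenna_api_server.reset()
    antenna_api_server.setup_job(10, [])

    with patch_antenna_api_requests(client):
        get_jobs("http://testserver/api/v2", "test-token", "moths_2024", "Test Worker")

    assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker"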
+ """ + return _last_get_jobs_service_name + + def reset(): """Clear all state between tests.""" + global _last_get_jobs_service_name _jobs_queue.clear() _posted_results.clear() _projects.clear() _registered_pipelines.clear() + _last_get_jobs_service_name = "" diff --git a/trapdata/antenna/tests/test_worker.py b/trapdata/antenna/tests/test_worker.py index e9919ca..feeb6bf 100644 --- a/trapdata/antenna/tests/test_worker.py +++ b/trapdata/antenna/tests/test_worker.py @@ -219,9 +219,12 @@ def test_returns_job_ids(self): antenna_api_server.setup_job(30, []) with patch_antenna_api_requests(self.antenna_client): - result = get_jobs("http://testserver/api/v2", "test-token", "moths_2024") + result = get_jobs( + "http://testserver/api/v2", "test-token", "moths_2024", "Test Worker" + ) assert result == [10, 20, 30] + assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker" # --------------------------------------------------------------------------- @@ -262,7 +265,10 @@ def test_empty_queue(self): with patch_antenna_api_requests(self.antenna_client): result = _process_job( - "quebec_vermont_moths_2023", 100, self._make_settings() + "quebec_vermont_moths_2023", + 100, + self._make_settings(), + "Test Service", ) assert result is False @@ -287,7 +293,10 @@ def test_processes_batch_with_real_inference(self): # Run worker with patch_antenna_api_requests(self.antenna_client): result = _process_job( - "quebec_vermont_moths_2023", 101, self._make_settings() + "quebec_vermont_moths_2023", + 101, + self._make_settings(), + "Test Service", ) # Validate processing succeeded @@ -322,7 +331,12 @@ def test_handles_failed_items(self): antenna_api_server.setup_job(job_id=102, tasks=tasks) with patch_antenna_api_requests(self.antenna_client): - _process_job("quebec_vermont_moths_2023", 102, self._make_settings()) + _process_job( + "quebec_vermont_moths_2023", + 102, + self._make_settings(), + "Test Service", + ) posted_results = antenna_api_server.get_posted_results(102) assert len(posted_results) == 1 @@ -354,7 +368,10 @@ def test_mixed_batch_success_and_failures(self): with patch_antenna_api_requests(self.antenna_client): result = _process_job( - "quebec_vermont_moths_2023", 103, self._make_settings() + "quebec_vermont_moths_2023", + 103, + self._make_settings(), + "Test Service", ) assert result is True @@ -444,14 +461,15 @@ def test_full_workflow_with_real_inference(self): # Step 2: Get jobs job_ids = get_jobs( - "http://testserver/api/v2", - "test-token", - pipeline_slug, + "http://testserver/api/v2", "test-token", pipeline_slug, "Test Worker" ) assert 200 in job_ids + assert antenna_api_server.get_last_get_jobs_service_name() == "Test Worker" # Step 3: Process job - result = _process_job(pipeline_slug, 200, self._make_settings()) + result = _process_job( + pipeline_slug, 200, self._make_settings(), "Test Worker" + ) assert result is True # Step 4: Validate results posted @@ -498,7 +516,10 @@ def test_multiple_batches_processed(self): with patch_antenna_api_requests(self.antenna_client): result = _process_job( - "quebec_vermont_moths_2023", 201, self._make_settings() + "quebec_vermont_moths_2023", + 201, + self._make_settings(), + "Test Service", ) assert result is True diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index bba7905..5e95f2e 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -8,7 +8,7 @@ import torch.multiprocessing as mp import torchvision -from trapdata.antenna.client import get_jobs, post_batch_results +from 
trapdata.antenna.client import get_full_service_name, get_jobs, post_batch_results from trapdata.antenna.datasets import get_rest_dataloader from trapdata.antenna.schemas import AntennaTaskResult, AntennaTaskResultError from trapdata.api.api import CLASSIFIER_CHOICES @@ -40,8 +40,14 @@ def run_worker(pipelines: list[str]): "Get your auth token from your Antenna project settings." ) - gpu_count = torch.cuda.device_count() + # Validate service name + if not settings.antenna_service_name or not settings.antenna_service_name.strip(): + raise ValueError( + "AMI_ANTENNA_SERVICE_NAME configuration setting must be set. " + "Configure it via environment variable or .env file." + ) + gpu_count = torch.cuda.device_count() if gpu_count > 1: logger.info(f"Found {gpu_count} GPUs, spawning one AMI worker instance per GPU") # Don't pass settings through mp.spawn — Settings contains enums that @@ -75,6 +81,10 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): f"AMI worker instance {gpu_id} pinned to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}" ) + # Build full service name with hostname + full_service_name = get_full_service_name(settings.antenna_service_name) + logger.info(f"Running worker as: {full_service_name}") + while True: # TODO CGJS: Support pulling and prioritizing single image tasks, which are used in interactive testing # These should probably come from a dedicated endpoint and should preempt batch jobs under the assumption that they @@ -86,6 +96,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): base_url=settings.antenna_api_base_url, auth_token=settings.antenna_api_auth_token, pipeline_slug=pipeline, + processing_service_name=full_service_name, ) for job_id in jobs: logger.info( @@ -96,6 +107,7 @@ def _worker_loop(gpu_id: int, pipelines: list[str]): pipeline=pipeline, job_id=job_id, settings=settings, + processing_service_name=full_service_name, ) any_jobs = any_jobs or any_work_done except Exception as e: @@ -117,6 +129,7 @@ def _process_job( pipeline: str, job_id: int, settings: Settings, + processing_service_name: str, ) -> bool: """Run the worker to process images from the REST API queue. @@ -124,11 +137,16 @@ def _process_job( pipeline: Pipeline name to use for processing (e.g., moth_binary, panama_moths_2024) job_id: Job ID to process settings: Settings object with antenna_api_* configuration + processing_service_name: Name of the processing service Returns: True if any work was done, False otherwise """ did_work = False - loader = get_rest_dataloader(job_id=job_id, settings=settings) + loader = get_rest_dataloader( + job_id=job_id, + settings=settings, + processing_service_name=processing_service_name, + ) classifier = None detector = None @@ -317,6 +335,7 @@ def _process_job( settings.antenna_api_auth_token, job_id, batch_results, + processing_service_name, ) st, t = t("Finished posting results") diff --git a/trapdata/cli/worker.py b/trapdata/cli/worker.py index 19fb97a..f1b5782 100644 --- a/trapdata/cli/worker.py +++ b/trapdata/cli/worker.py @@ -16,7 +16,8 @@ def run( list[str] | None, typer.Option( "--pipeline", - help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. Defaults to all pipelines if not specified." + help="Pipeline to use for processing (e.g., moth_binary, panama_moths_2024). Can be specified multiple times. 
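Note: the net effect of the worker changes is that the full service name is computed once per worker process and then threaded through every API interaction. A condensed sketch of the flow (pseudocode; error handling and the inference loop omitted):

    full_service_name = get_full_service_name(settings.antenna_service_name)

    for pipeline in pipelines:
        jobs = get_jobs(..., processing_service_name=full_service_name)  # job discovery
        for job_id in jobs:
            loader = get_rest_dataloader(job_id, settings, full_service_name)  # task fetching
            ...
            post_batch_results(..., processing_service_name=full_service_name)  # result upload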
" + "Defaults to all pipelines if not specified.", ), ] = None, ): @@ -49,13 +50,6 @@ def run( @cli.command("register") def register( - name: Annotated[ - str, - typer.Argument( - help="Name for the processing service registration (e.g., 'AMI Data Companion on DRAC gpu-03'). " - "Hostname will be added automatically.", - ), - ], project: Annotated[ list[int] | None, typer.Option( @@ -70,11 +64,18 @@ def register( This command registers all available pipeline configurations with the Antenna platform for the specified projects (or all accessible projects if none specified). + The service name is read from the AMI_ANTENNA_SERVICE_NAME configuration setting. + Hostname will be added automatically to the service name. + Examples: - ami worker register "AMI Data Companion on DRAC gpu-03" --project 1 --project 2 - ami worker register "My Processing Service" # registers for all accessible projects + ami worker register --project 1 --project 2 + ami worker register # registers for all accessible projects """ from trapdata.antenna.registration import register_pipelines + from trapdata.settings import read_settings + settings = read_settings() project_ids = project if project else [] - register_pipelines(project_ids=project_ids, service_name=name) + register_pipelines( + project_ids=project_ids, service_name=settings.antenna_service_name + ) diff --git a/trapdata/settings.py b/trapdata/settings.py index 0020bd5..54324c3 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -40,6 +40,7 @@ class Settings(BaseSettings): # Antenna API worker settings antenna_api_base_url: str = "http://localhost:8000/api/v2" antenna_api_auth_token: str = "" + antenna_service_name: str = "AMI Data Companion" antenna_api_batch_size: int = 16 @pydantic.field_validator("image_base_path", "user_data_path") @@ -169,6 +170,12 @@ class Config: "kivy_type": "numeric", "kivy_section": "antenna", }, + "antenna_service_name": { + "title": "Antenna Service Name", + "description": "Name for the processing service registration (hostname will be added automatically)", + "kivy_type": "string", + "kivy_section": "antenna", + }, } @classmethod