Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .agents/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ Processing services are FastAPI applications that implement the AMI ML API contr
**Health Checks:**
- Cached status with 3 retries and exponential backoff (0s, 2s, 4s)
- Celery Beat task runs periodic checks (`ami.ml.tasks.check_processing_services_online`)
- Status stored in `ProcessingService.last_checked_live` boolean field
- Status stored in `ProcessingService.last_seen_live` boolean field
- Async/pull-mode services update status via `mark_seen()` when they register pipelines
- UI shows red/green indicator based on cached status

Location: `processing_services/` directory contains example implementations
Expand Down
5 changes: 3 additions & 2 deletions .agents/DATABASE_SCHEMA.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ erDiagram
bigint id PK
string name
string endpoint_url
boolean last_checked_live
float last_checked_latency
datetime last_seen
boolean last_seen_live
float last_seen_latency
}

ProjectPipelineConfig {
Expand Down
28 changes: 28 additions & 0 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@
logger = logging.getLogger(__name__)


def _mark_pipeline_pull_services_seen(job: "Job") -> None:
    """
    Record a heartbeat for every async (pull-mode) processing service on the job's pipeline.

    Invoked from the task-fetch and result-submit endpoints so that a worker's normal
    polling traffic keeps last_seen/last_seen_live fresh. The periodic
    check_processing_services_online task flips services back offline once this
    heartbeat has been absent longer than PROCESSING_SERVICE_LAST_SEEN_MAX.

    Note: caller identity is not verified here — any authenticated token can hit these
    endpoints. A future application-token scheme (see PR #1117) will allow tying requests
    to a specific processing service so the heartbeat can be scoped more precisely.

    NOTE(review): uses naive datetime.now() to match the rest of the ML module — confirm
    against the project's USE_TZ setting.
    """
    import datetime

    if not job.pipeline_id:
        # No pipeline attached: nothing to mark.
        return

    heartbeat_at = datetime.datetime.now()
    pull_services = job.pipeline.processing_services.async_services()
    pull_services.update(last_seen=heartbeat_at, last_seen_live=True)


class JobFilterSet(filters.FilterSet):
"""Custom filterset to enable pipeline name filtering."""

Expand Down Expand Up @@ -245,6 +267,9 @@ def tasks(self, request, pk=None):
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Get tasks from NATS JetStream
from ami.ml.orchestration.nats_queue import TaskQueueManager

Expand Down Expand Up @@ -272,6 +297,9 @@ def result(self, request, pk=None):

job = self.get_object()

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
Expand Down
26 changes: 26 additions & 0 deletions ami/ml/migrations/0027_rename_last_checked_to_last_seen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from django.db import migrations

# Renames the ProcessingService health-check fields from last_checked* to last_seen*,
# reflecting that pull-mode services report heartbeats rather than being actively checked.
_FIELD_RENAMES = [
    ("last_checked", "last_seen"),
    ("last_checked_live", "last_seen_live"),
    ("last_checked_latency", "last_seen_latency"),
]


class Migration(migrations.Migration):

    dependencies = [
        ("ml", "0026_make_processing_service_endpoint_url_nullable"),
    ]

    operations = [
        migrations.RenameField(
            model_name="processingservice",
            old_name=old_name,
            new_name=new_name,
        )
        for old_name, new_name in _FIELD_RENAMES
    ]
13 changes: 5 additions & 8 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ def online(self, project: Project) -> PipelineQuerySet:
"""
return self.filter(
processing_services__projects=project,
processing_services__last_checked_live=True,
processing_services__last_seen_live=True,
).distinct()


Expand Down Expand Up @@ -1048,7 +1048,7 @@ def collect_images(
def choose_processing_service_for_pipeline(
self, job_id: int | None, pipeline_name: str, project_id: int
) -> ProcessingService:
# @TODO use the cached `last_checked_latency` and a max age to avoid checking every time
# @TODO use the cached `last_seen_latency` and a max age to avoid checking every time

job = None
task_logger = logger
Expand All @@ -1070,13 +1070,10 @@ def choose_processing_service_for_pipeline(
processing_services_online = False

for processing_service in processing_services:
if processing_service.last_checked_live:
if processing_service.last_seen_live:
processing_services_online = True
if (
processing_service.last_checked_latency
and processing_service.last_checked_latency < lowest_latency
):
lowest_latency = processing_service.last_checked_latency
if processing_service.last_seen_latency and processing_service.last_seen_latency < lowest_latency:
lowest_latency = processing_service.last_seen_latency
# pick the processing service that has lowest latency
processing_service_lowest_latency = processing_service

Expand Down
97 changes: 75 additions & 22 deletions ami/ml/models/processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,34 @@

logger = logging.getLogger(__name__)

# Max age of last_seen before a pull-mode (no-endpoint) service is considered offline.
# Pull-mode workers poll every ~5s, so 60s gives 12x buffer for transient failures.
# Read by ProcessingService.get_status() and the check_processing_services_online task.
PROCESSING_SERVICE_LAST_SEEN_MAX: datetime.timedelta = datetime.timedelta(seconds=60)

class ProcessingServiceManager(models.Manager.from_queryset(BaseQuerySet)):

class ProcessingServiceQuerySet(BaseQuerySet):
    """Queryset helpers for partitioning processing services by dispatch mode."""

    @staticmethod
    def _no_endpoint_q() -> models.Q:
        # A NULL or empty endpoint_url marks a pull-mode (async) service.
        return models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")

    def async_services(self) -> "ProcessingServiceQuerySet":
        """
        Filter to pull-mode (async) processing services — those with no endpoint URL.

        These correspond to jobs with dispatch_mode=ASYNC_API. Instead of Antenna calling
        out to them, they poll Antenna for tasks and push results back. Their liveness is
        tracked via heartbeats from mark_seen() rather than active health checks.
        """
        return self.filter(self._no_endpoint_q())

    def sync_services(self) -> "ProcessingServiceQuerySet":
        """
        Filter to push-mode (sync) processing services — those with a configured endpoint URL.

        These correspond to jobs with dispatch_mode=SYNC_API. Antenna actively calls their
        /readyz and /process endpoints. Their liveness is tracked by the periodic
        check_processing_services_online Celery task.
        """
        return self.exclude(self._no_endpoint_q())


class ProcessingServiceManager(models.Manager.from_queryset(ProcessingServiceQuerySet)):
"""Custom manager for ProcessingService to handle specific queries."""

def create(self, **kwargs) -> "ProcessingService":
Expand All @@ -41,12 +67,21 @@ class ProcessingService(BaseModel):
projects = models.ManyToManyField("main.Project", related_name="processing_services", blank=True)
endpoint_url = models.CharField(max_length=1024, null=True, blank=True)
pipelines = models.ManyToManyField("ml.Pipeline", related_name="processing_services", blank=True)
last_checked = models.DateTimeField(null=True)
last_checked_live = models.BooleanField(null=True)
last_checked_latency = models.FloatField(null=True)
last_seen = models.DateTimeField(null=True)
last_seen_live = models.BooleanField(null=True)
last_seen_latency = models.FloatField(null=True)

objects = ProcessingServiceManager()

@property
def is_async(self) -> bool:
    """
    Whether this service operates in pull mode (dispatch_mode=ASYNC_API).

    A service with no endpoint URL is pull-mode: it polls Antenna for work instead
    of being called. A configured endpoint means push-mode (dispatch_mode=SYNC_API).
    """
    has_endpoint = bool(self.endpoint_url)
    return not has_endpoint

def __str__(self):
endpoint_display = self.endpoint_url or "async"
return f'#{self.pk} "{self.name}" ({endpoint_display})'
Expand Down Expand Up @@ -139,6 +174,15 @@ def create_pipelines(
algorithms_created=algorithms_created,
)

def mark_seen(self, live: bool = True) -> None:
    """
    Register a heartbeat from this processing service.

    Pull-mode (async) services have no endpoint for Antenna to poll, so their
    liveness is recorded via this method whenever they contact Antenna.
    Persists only the two heartbeat fields to avoid clobbering concurrent updates.
    """
    now = datetime.datetime.now()
    self.last_seen = now
    self.last_seen_live = live
    self.save(update_fields=["last_seen", "last_seen_live"])

def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
"""
Check the status of the processing service.
Expand All @@ -152,16 +196,25 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
Args:
timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts)
"""
# If no endpoint URL is configured, return a no-op response
if self.endpoint_url is None:
# If no endpoint URL is configured, derive status from last registration heartbeat
if not self.endpoint_url:
is_live = bool(
self.last_seen
and self.last_seen_live
and (datetime.datetime.now() - self.last_seen) < PROCESSING_SERVICE_LAST_SEEN_MAX
)
if not is_live and self.last_seen_live:
# Heartbeat has expired — mark stale
self.last_seen_live = False
self.save(update_fields=["last_seen_live"])
pipeline_names = list(self.pipelines.values_list("name", flat=True))
return ProcessingServiceStatusResponse(
timestamp=datetime.datetime.now(),
request_successful=False,
server_live=None,
pipelines_online=[],
timestamp=self.last_seen or datetime.datetime.now(),
request_successful=is_live,
server_live=is_live,
pipelines_online=pipeline_names,
pipeline_configs=[],
endpoint_url=self.endpoint_url,
error="No endpoint URL configured - service operates in pull mode",
endpoint_url=None,
latency=0.0,
)

Expand All @@ -171,7 +224,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
pipeline_configs = []
pipelines_online = []
timestamp = datetime.datetime.now()
self.last_checked = timestamp
self.last_seen = timestamp
resp = None

# Create session with retry logic for connection errors and timeouts
Expand All @@ -184,23 +237,23 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
try:
resp = session.get(ready_check_url, timeout=timeout)
resp.raise_for_status()
self.last_checked_live = True
self.last_seen_live = True
except requests.exceptions.RequestException as e:
error = f"Error connecting to {ready_check_url}: {e}"
logger.error(error)
self.last_checked_live = False
self.last_seen_live = False
finally:
latency = time.time() - start_time
self.last_checked_latency = latency
self.last_seen_latency = latency
self.save(
update_fields=[
"last_checked",
"last_checked_live",
"last_checked_latency",
"last_seen",
"last_seen_live",
"last_seen_latency",
]
)

if self.last_checked_live:
if self.last_seen_live:
# The specific pipeline statuses are not required for the status response
# but the intention is to show which ones are loaded into memory and ready to use.
# @TODO: this may be overkill, but it is displayed in the UI now.
Expand All @@ -214,7 +267,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
response = ProcessingServiceStatusResponse(
timestamp=timestamp,
request_successful=resp.ok if resp else False,
server_live=self.last_checked_live,
server_live=self.last_seen_live,
pipelines_online=pipelines_online,
pipeline_configs=pipeline_configs,
endpoint_url=self.endpoint_url,
Expand All @@ -229,7 +282,7 @@ def get_pipeline_configs(self, timeout=6):
Get the pipeline configurations from the processing service.
This can be a long response as it includes the full category map for each algorithm.
"""
if self.endpoint_url is None:
if not self.endpoint_url:
return []

info_url = urljoin(self.endpoint_url, "info")
Expand Down
8 changes: 4 additions & 4 deletions ami/ml/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class Meta:
"id",
"details",
"endpoint_url",
"last_checked",
"last_checked_live",
"last_seen",
"last_seen_live",
"created_at",
"updated_at",
]
Expand Down Expand Up @@ -153,8 +153,8 @@ class Meta:
"pipelines",
"created_at",
"updated_at",
"last_checked",
"last_checked_live",
"last_seen",
"last_seen_live",
]

def create(self, validated_data):
Expand Down
27 changes: 20 additions & 7 deletions ami/ml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,35 @@ def remove_duplicate_classifications(project_id: int | None = None, dry_run: boo
@celery_app.task(soft_time_limit=10, time_limit=20)
def check_processing_services_online():
    """
    Check the status of all processing services and update last_seen/last_seen_live fields.

    - Sync services (dispatch_mode=SYNC_API, endpoint URL set): actively polled via /readyz.
    - Async services (dispatch_mode=ASYNC_API, no endpoint URL): heartbeat is updated by
      mark_seen() on registration and by _mark_pipeline_pull_services_seen() on task polling.
      This task marks them offline if last_seen has exceeded PROCESSING_SERVICE_LAST_SEEN_MAX.

    @TODO make this async to check all services in parallel
    """
    import datetime

    from ami.ml.models import PROCESSING_SERVICE_LAST_SEEN_MAX, ProcessingService

    logger.info("Checking which processing services are online.")

    for service in ProcessingService.objects.sync_services():
        logger.info(f"Checking push-mode service {service}")
        try:
            status_response = service.get_status()
            logger.debug(status_response)
        except Exception as e:
            # Best-effort: one unreachable service must not abort the sweep.
            logger.error(f"Error checking service {service}: {e}")
            continue

    # Expire async services whose heartbeat has gone stale.
    # NOTE(review): naive datetime.now() matches the rest of this module, but will
    # mis-compare against aware DB values if USE_TZ=True — confirm project setting.
    stale_cutoff = datetime.datetime.now() - PROCESSING_SERVICE_LAST_SEEN_MAX
    stale = ProcessingService.objects.async_services().filter(last_seen_live=True, last_seen__lt=stale_cutoff)
    # QuerySet.update() returns the number of rows changed — one query instead of
    # count() + update(), and no race between counting and updating.
    count = stale.update(last_seen_live=False)
    if count:
        logger.info(
            f"Marked {count} async service(s) offline (no heartbeat within {PROCESSING_SERVICE_LAST_SEEN_MAX})."
        )
Loading
Loading