Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .agents/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ Processing services are FastAPI applications that implement the AMI ML API contr
**Health Checks:**
- Cached status with 3 retries and exponential backoff (0s, 2s, 4s)
- Celery Beat task runs periodic checks (`ami.ml.tasks.check_processing_services_online`)
- Status stored in `ProcessingService.last_checked_live` boolean field
- Status stored in `ProcessingService.last_seen_live` boolean field
- Async/pull-mode services update status via `mark_seen()` when they register pipelines
- UI shows red/green indicator based on cached status

Location: `processing_services/` directory contains example implementations
Expand Down
5 changes: 3 additions & 2 deletions .agents/DATABASE_SCHEMA.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ erDiagram
bigint id PK
string name
string endpoint_url
boolean last_checked_live
float last_checked_latency
datetime last_seen
boolean last_seen_live
float last_seen_latency
}

ProjectPipelineConfig {
Expand Down
30 changes: 30 additions & 0 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@
logger = logging.getLogger(__name__)


def _mark_pipeline_pull_services_seen(job: "Job") -> None:
"""
Record a heartbeat for async (pull-mode) processing services linked to the job's pipeline.

Called on every task-fetch and result-submit request so that the worker's polling activity
keeps last_seen/last_seen_live current. The periodic check_processing_services_online task
will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX.

IMPORTANT: This marks ALL async services on the pipeline within this project as live, not just
the specific service that made the request. If multiple async services share the same pipeline
within a project, a single worker polling will keep all of them appearing online.
Once application-token auth is available (PR #1117), this should be scoped to the individual
calling service instead.
"""
import datetime

if not job.pipeline_id:
return
job.pipeline.processing_services.async_services().filter(projects=job.project_id).update(
last_seen=datetime.datetime.now(),
last_seen_live=True,
)


class JobFilterSet(filters.FilterSet):
"""Custom filterset to enable pipeline name filtering."""

Expand Down Expand Up @@ -245,6 +269,9 @@ def tasks(self, request, pk=None):
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Get tasks from NATS JetStream
from ami.ml.orchestration.nats_queue import TaskQueueManager

Expand Down Expand Up @@ -272,6 +299,9 @@ def result(self, request, pk=None):

job = self.get_object()

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)

# Validate request data is a list
if isinstance(request.data, list):
results = request.data
Expand Down
26 changes: 26 additions & 0 deletions ami/ml/migrations/0027_rename_last_checked_to_last_seen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from django.db import migrations


class Migration(migrations.Migration):
    """Rename ProcessingService `last_checked*` fields to `last_seen*`."""

    dependencies = [
        ("ml", "0026_make_processing_service_endpoint_url_nullable"),
    ]

    # One rename per tracked attribute: the timestamp itself, the liveness
    # flag, and the measured latency. Column data is preserved; only the
    # field names change.
    operations = [
        migrations.RenameField(
            model_name="processingservice",
            old_name=f"last_checked{suffix}",
            new_name=f"last_seen{suffix}",
        )
        for suffix in ("", "_live", "_latency")
    ]
13 changes: 5 additions & 8 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ def online(self, project: Project) -> PipelineQuerySet:
"""
return self.filter(
processing_services__projects=project,
processing_services__last_checked_live=True,
processing_services__last_seen_live=True,
).distinct()


Expand Down Expand Up @@ -1048,7 +1048,7 @@ def collect_images(
def choose_processing_service_for_pipeline(
self, job_id: int | None, pipeline_name: str, project_id: int
) -> ProcessingService:
# @TODO use the cached `last_checked_latency` and a max age to avoid checking every time
# @TODO use the cached `last_seen_latency` and a max age to avoid checking every time

job = None
task_logger = logger
Expand All @@ -1070,13 +1070,10 @@ def choose_processing_service_for_pipeline(
processing_services_online = False

for processing_service in processing_services:
if processing_service.last_checked_live:
if processing_service.last_seen_live:
processing_services_online = True
if (
processing_service.last_checked_latency
and processing_service.last_checked_latency < lowest_latency
):
lowest_latency = processing_service.last_checked_latency
if processing_service.last_seen_latency and processing_service.last_seen_latency < lowest_latency:
lowest_latency = processing_service.last_seen_latency
# pick the processing service that has lowest latency
processing_service_lowest_latency = processing_service

Expand Down
97 changes: 75 additions & 22 deletions ami/ml/models/processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,34 @@

logger = logging.getLogger(__name__)

# Max age of last_seen before a pull-mode (no-endpoint) service is considered offline.
# Pull-mode workers poll every ~5s, so 60s gives 12x buffer for transient failures.
PROCESSING_SERVICE_LAST_SEEN_MAX = datetime.timedelta(seconds=60)

class ProcessingServiceManager(models.Manager.from_queryset(BaseQuerySet)):

class ProcessingServiceQuerySet(BaseQuerySet):
    """QuerySet helpers for splitting processing services by dispatch mode."""

    def async_services(self) -> "ProcessingServiceQuerySet":
        """
        Return only pull-mode (async) processing services — those with no endpoint URL.

        These correspond to jobs with dispatch_mode=ASYNC_API: instead of Antenna
        calling out to them, they poll Antenna for tasks and push results back.
        Their liveness is tracked via heartbeats from mark_seen() rather than
        active health checks.
        """
        missing_endpoint = models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")
        return self.filter(missing_endpoint)

    def sync_services(self) -> "ProcessingServiceQuerySet":
        """
        Return only push-mode (sync) processing services — those with an endpoint URL.

        These correspond to jobs with dispatch_mode=SYNC_API: Antenna actively calls
        their /readyz and /process endpoints, and the periodic
        check_processing_services_online Celery task tracks their liveness.
        """
        missing_endpoint = models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact="")
        return self.exclude(missing_endpoint)


class ProcessingServiceManager(models.Manager.from_queryset(ProcessingServiceQuerySet)):
"""Custom manager for ProcessingService to handle specific queries."""

def create(self, **kwargs) -> "ProcessingService":
Expand All @@ -41,12 +67,21 @@ class ProcessingService(BaseModel):
projects = models.ManyToManyField("main.Project", related_name="processing_services", blank=True)
endpoint_url = models.CharField(max_length=1024, null=True, blank=True)
pipelines = models.ManyToManyField("ml.Pipeline", related_name="processing_services", blank=True)
last_checked = models.DateTimeField(null=True)
last_checked_live = models.BooleanField(null=True)
last_checked_latency = models.FloatField(null=True)
last_seen = models.DateTimeField(null=True)
last_seen_live = models.BooleanField(null=True)
last_seen_latency = models.FloatField(null=True)

objects = ProcessingServiceManager()

@property
def is_async(self) -> bool:
"""
True if this is a pull-mode (async) service with no endpoint URL, corresponding to
jobs with dispatch_mode=ASYNC_API. False for push-mode services with a configured
endpoint, corresponding to jobs with dispatch_mode=SYNC_API.
"""
return not self.endpoint_url

def __str__(self):
endpoint_display = self.endpoint_url or "async"
return f'#{self.pk} "{self.name}" ({endpoint_display})'
Expand Down Expand Up @@ -139,6 +174,15 @@ def create_pipelines(
algorithms_created=algorithms_created,
)

def mark_seen(self, live: bool = True) -> None:
"""
Record that we heard from this processing service.
Used by async/pull-mode services that don't have an endpoint to check.
"""
self.last_seen = datetime.datetime.now()
self.last_seen_live = live
self.save(update_fields=["last_seen", "last_seen_live"])

def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
"""
Check the status of the processing service.
Expand All @@ -152,16 +196,25 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
Args:
timeout: Request timeout in seconds per attempt (default: 90s for serverless cold starts)
"""
# If no endpoint URL is configured, return a no-op response
if self.endpoint_url is None:
# If no endpoint URL is configured, derive status from last registration heartbeat
if not self.endpoint_url:
is_live = bool(
self.last_seen
and self.last_seen_live
and (datetime.datetime.now() - self.last_seen) < PROCESSING_SERVICE_LAST_SEEN_MAX
)
if not is_live and self.last_seen_live:
# Heartbeat has expired — mark stale
self.last_seen_live = False
self.save(update_fields=["last_seen_live"])
pipeline_names = list(self.pipelines.values_list("name", flat=True))
return ProcessingServiceStatusResponse(
timestamp=datetime.datetime.now(),
request_successful=False,
server_live=None,
pipelines_online=[],
timestamp=self.last_seen or datetime.datetime.now(),
request_successful=is_live,
server_live=is_live,
pipelines_online=pipeline_names,
pipeline_configs=[],
endpoint_url=self.endpoint_url,
error="No endpoint URL configured - service operates in pull mode",
endpoint_url=None,
latency=0.0,
)

Expand All @@ -171,7 +224,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
pipeline_configs = []
pipelines_online = []
timestamp = datetime.datetime.now()
self.last_checked = timestamp
self.last_seen = timestamp
resp = None

# Create session with retry logic for connection errors and timeouts
Expand All @@ -184,23 +237,23 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
try:
resp = session.get(ready_check_url, timeout=timeout)
resp.raise_for_status()
self.last_checked_live = True
self.last_seen_live = True
except requests.exceptions.RequestException as e:
error = f"Error connecting to {ready_check_url}: {e}"
logger.error(error)
self.last_checked_live = False
self.last_seen_live = False
finally:
latency = time.time() - start_time
self.last_checked_latency = latency
self.last_seen_latency = latency
self.save(
update_fields=[
"last_checked",
"last_checked_live",
"last_checked_latency",
"last_seen",
"last_seen_live",
"last_seen_latency",
]
)

if self.last_checked_live:
if self.last_seen_live:
# The specific pipeline statuses are not required for the status response
# but the intention is to show which ones are loaded into memory and ready to use.
# @TODO: this may be overkill, but it is displayed in the UI now.
Expand All @@ -214,7 +267,7 @@ def get_status(self, timeout=90) -> ProcessingServiceStatusResponse:
response = ProcessingServiceStatusResponse(
timestamp=timestamp,
request_successful=resp.ok if resp else False,
server_live=self.last_checked_live,
server_live=self.last_seen_live,
pipelines_online=pipelines_online,
pipeline_configs=pipeline_configs,
endpoint_url=self.endpoint_url,
Expand All @@ -229,7 +282,7 @@ def get_pipeline_configs(self, timeout=6):
Get the pipeline configurations from the processing service.
This can be a long response as it includes the full category map for each algorithm.
"""
if self.endpoint_url is None:
if not self.endpoint_url:
return []

info_url = urljoin(self.endpoint_url, "info")
Expand Down
13 changes: 9 additions & 4 deletions ami/ml/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ class Meta:


class ProcessingServiceNestedSerializer(DefaultSerializer):
    """Compact ProcessingService representation for embedding in related objects' payloads."""

    # Read-only; surfaces the model's `is_async` property (no endpoint URL => pull mode).
    is_async = serializers.BooleanField(read_only=True)

    class Meta:
        model = ProcessingService
        fields = [
            "name",
            "id",
            "details",
            "endpoint_url",
            "is_async",
            "last_seen",
            "last_seen_live",
            "created_at",
            "updated_at",
        ]
Expand Down Expand Up @@ -134,6 +137,7 @@ class Meta:

class ProcessingServiceSerializer(DefaultSerializer):
pipelines = PipelineNestedSerializer(many=True, read_only=True)
is_async = serializers.BooleanField(read_only=True)
project = serializers.PrimaryKeyRelatedField(
write_only=True,
queryset=Project.objects.all(),
Expand All @@ -150,11 +154,12 @@ class Meta:
"projects",
"project",
"endpoint_url",
"is_async",
"pipelines",
"created_at",
"updated_at",
"last_checked",
"last_checked_live",
"last_seen",
"last_seen_live",
]

def create(self, validated_data):
Expand Down
Loading