Commit 77dd024

mihow and claude committed
feat: async PS liveness tracking and ProcessingServiceQuerySet API
Add structured queryset methods and a heartbeat mechanism so async (pull-mode) processing services stay in sync with their actual liveness.

ProcessingService:
- New ProcessingServiceQuerySet with async_services() / sync_services() methods — single canonical filter for endpoint_url null-or-empty, used everywhere instead of ad-hoc Q expressions
- is_async property (derived from endpoint_url, no DB column)
- Docstrings reference Job.dispatch_mode ASYNC_API / SYNC_API for context

Liveness tracking:
- PROCESSING_SERVICE_LAST_SEEN_MAX = 60s constant (12× the worker's 5s poll interval) — async services are considered offline after this
- check_processing_services_online task now handles both modes: sync → active /readyz poll; async → bulk mark stale via async_services()
- _mark_pipeline_pull_services_seen() helper in jobs/views.py: single bulk UPDATE via job.pipeline.processing_services.async_services(), called at the top of both /jobs/{id}/tasks/ and /jobs/{id}/result/ so every worker poll cycle refreshes last_seen without needing a separate registration

Async job cleanup (from carlosg/redisatomic):
- Rename _cleanup_job_if_needed → cleanup_async_job_if_needed and export it so Job.cancel() can call it directly without a local import
- JobLogHandler: refresh_from_db before appending to avoid last-writer-wins race across concurrent worker processes
- Job.logger: update existing handler's job reference instead of always adding a new handler (process-level singleton leak fix)

Co-Authored-By: Claude <[email protected]>
1 parent 932824b commit 77dd024

5 files changed: 125 additions & 28 deletions
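The new queryset API in a nutshell: a hedged usage sketch. The model, manager, and method names are taken from the diffs below; the loop and assertion are illustrative only, not part of the commit.

```python
from ami.ml.models import ProcessingService

# Canonical split introduced by this commit: a service with a null or empty
# endpoint_url is pull-mode (async); one with an endpoint_url is push-mode (sync).
sync_qs = ProcessingService.objects.sync_services()    # endpoint_url set
async_qs = ProcessingService.objects.async_services()  # endpoint_url null or ""

# is_async is derived from endpoint_url (no DB column), so the property
# always agrees with the queryset filters:
for service in ProcessingService.objects.all():
    assert service.is_async == (not service.endpoint_url)
```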


ami/jobs/models.py

Lines changed: 20 additions & 17 deletions

```diff
@@ -15,7 +15,7 @@

 from ami.base.models import BaseModel
 from ami.base.schemas import ConfigurableStage, ConfigurableStageParam
-from ami.jobs.tasks import run_job
+from ami.jobs.tasks import cleanup_async_job_if_needed, run_job
 from ami.main.models import Deployment, Project, SourceImage, SourceImageCollection
 from ami.ml.models import Pipeline
 from ami.ml.post_processing.registry import get_postprocessing_task
@@ -336,7 +336,11 @@ def emit(self, record: logging.LogRecord):
         # Log to the current app logger
         logger.log(record.levelno, self.format(record))

-        # Write to the logs field on the job instance
+        # Write to the logs field on the job instance.
+        # Refresh from DB first to reduce the window for concurrent overwrites — each
+        # worker holds its own stale in-memory copy of `logs`, so without a refresh the
+        # last writer always wins and earlier entries are silently dropped.
+        self.job.refresh_from_db(fields=["logs"])
         timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         msg = f"[{timestamp}] {record.levelname} {self.format(record)}"
         if msg not in self.job.logs.stdout:
@@ -355,7 +359,6 @@ def emit(self, record: logging.LogRecord):
                 self.job.save(update_fields=["logs"], update_progress=False)
             except Exception as e:
                 logger.error(f"Failed to save logs for job #{self.job.pk}: {e}")
-                pass


 @dataclass
@@ -975,24 +978,20 @@ def cancel(self):
         and transition through CANCELING → REVOKED. For other jobs,
         revoke the Celery task.
         """
-        from ami.jobs.tasks import _cleanup_job_if_needed
-
         self.status = JobState.CANCELING
         self.save()

-        if self.dispatch_mode == JobDispatchMode.ASYNC_API:
-            # For async jobs, the Celery task has already completed (it just queued
-            # images to NATS). Clean up NATS/Redis resources to stop task delivery,
-            # then mark as REVOKED.
-            _cleanup_job_if_needed(self)
-            self.status = JobState.REVOKED
-            self.finished_at = datetime.datetime.now()
-            self.save()
-        elif self.task_id:
+        cleanup_async_job_if_needed(self)
+        if self.task_id:
             task = run_job.AsyncResult(self.task_id)
             if task:
                 task.revoke(terminate=True)
                 self.save()
+            if self.dispatch_mode == JobDispatchMode.ASYNC_API:
+                # For async jobs we need to set the status to revoked here since the task already
+                # finished (it only queues the images).
+                self.status = JobState.REVOKED
+                self.save()
         else:
             self.status = JobState.REVOKED
             self.save()
@@ -1102,11 +1101,15 @@ def get_default_progress(cls) -> JobProgress:
     def logger(self) -> logging.Logger:
         _logger = logging.getLogger(f"ami.jobs.{self.pk}")

-        # Only add JobLogHandler if not already present
-        if not any(isinstance(h, JobLogHandler) for h in _logger.handlers):
-            # Also log output to a field on thie model instance
+        # Update or add JobLogHandler, always pointing to the current instance.
+        # The logger is a process-level singleton so its handler may reference a stale
+        # job instance from a previous task execution in this worker process.
+        handler = next((h for h in _logger.handlers if isinstance(h, JobLogHandler)), None)
+        if handler is None:
             logger.info("Adding JobLogHandler to logger for job %s", self.pk)
             _logger.addHandler(JobLogHandler(self))
+        else:
+            handler.job = self
         _logger.propagate = False
         return _logger
```
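An aside on the Job.logger change above: it matters because logging.getLogger() returns a process-level singleton, so handlers accumulate across task executions in the same worker. A minimal, self-contained illustration; the logger name is invented for the example.

```python
import logging

# Same name -> same logger object for the lifetime of the process.
first = logging.getLogger("ami.jobs.42")
second = logging.getLogger("ami.jobs.42")
assert first is second

first.addHandler(logging.NullHandler())
# A handler added during one task run is still attached on the next run in the
# same worker process. Without the new "update existing handler" branch, an old
# JobLogHandler would keep writing to the stale job instance it was built with.
assert len(second.handlers) == 1
```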

ami/jobs/tasks.py

Lines changed: 25 additions & 5 deletions

```diff
@@ -178,6 +178,26 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
         job.logger.error(error)


+def _fail_job(job_id: int, reason: str) -> None:
+    from ami.jobs.models import Job, JobState
+    from ami.ml.orchestration.jobs import cleanup_async_job_resources
+
+    try:
+        with transaction.atomic():
+            job = Job.objects.select_for_update().get(pk=job_id)
+            if job.status in (JobState.CANCELING, *JobState.final_states()):
+                return
+            job.status = JobState.FAILURE
+            job.finished_at = datetime.datetime.now()
+            job.save(update_fields=["status", "finished_at"])
+
+        job.logger.error(f"Job {job_id} marked as FAILURE: {reason}")
+        cleanup_async_job_resources(job.pk, job.logger)
+    except Job.DoesNotExist:
+        logger.error(f"Cannot fail job {job_id}: not found")
+        cleanup_async_job_resources(job_id, logger)
+
+
 def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
     try:
@@ -293,10 +313,10 @@ def _update_job_progress(
     # Clean up async resources for completed jobs that use NATS/Redis
     if job.progress.is_complete():
         job = Job.objects.get(pk=job_id)  # Re-fetch outside transaction
-        _cleanup_job_if_needed(job)
+        cleanup_async_job_if_needed(job)


-def _cleanup_job_if_needed(job) -> None:
+def cleanup_async_job_if_needed(job) -> None:
     """
     Clean up async resources (NATS/Redis) if this job uses them.

@@ -312,7 +332,7 @@ def _cleanup_job_if_needed(job) -> None:
     # import here to avoid circular imports
     from ami.ml.orchestration.jobs import cleanup_async_job_resources

-    cleanup_async_job_resources(job)
+    cleanup_async_job_resources(job.pk, job.logger)


 @task_prerun.connect(sender=run_job)
@@ -351,7 +371,7 @@ def update_job_status(sender, task_id, task, state: str, retval=None, **kwargs):

     # Clean up async resources for revoked jobs
     if state == JobState.REVOKED:
-        _cleanup_job_if_needed(job)
+        cleanup_async_job_if_needed(job)


 @task_failure.connect(sender=run_job, retry=False)
@@ -366,7 +386,7 @@ def update_job_failure(sender, task_id, exception, *args, **kwargs):
     job.save()

     # Clean up async resources for failed jobs
-    _cleanup_job_if_needed(job)
+    cleanup_async_job_if_needed(job)


 def log_time(start: float = 0, msg: str | None = None) -> tuple[float, Callable]:
```

ami/jobs/views.py

Lines changed: 28 additions & 0 deletions

```diff
@@ -30,6 +30,28 @@
 logger = logging.getLogger(__name__)


+def _mark_pipeline_pull_services_seen(job: "Job") -> None:
+    """
+    Record a heartbeat for all async (pull-mode) processing services linked to the job's pipeline.
+
+    Called on every task-fetch and result-submit request so that the worker's polling activity
+    keeps last_seen/last_seen_live current. The periodic check_processing_services_online task
+    will mark services offline if this heartbeat stops arriving within PROCESSING_SERVICE_LAST_SEEN_MAX.
+
+    Note: caller identity is not verified here — any authenticated token can hit these endpoints.
+    A future application-token scheme (see PR #1117) will allow tying requests to a specific
+    processing service so the heartbeat can be scoped more precisely.
+    """
+    import datetime
+
+    if not job.pipeline_id:
+        return
+    job.pipeline.processing_services.async_services().update(
+        last_seen=datetime.datetime.now(),
+        last_seen_live=True,
+    )
+
+
 class JobFilterSet(filters.FilterSet):
     """Custom filterset to enable pipeline name filtering."""

@@ -245,6 +267,9 @@ def tasks(self, request, pk=None):
         if not job.pipeline:
             raise ValidationError("This job does not have a pipeline configured")

+        # Record heartbeat for async processing services on this pipeline
+        _mark_pipeline_pull_services_seen(job)
+
         # Get tasks from NATS JetStream
         from ami.ml.orchestration.nats_queue import TaskQueueManager

@@ -272,6 +297,9 @@ def result(self, request, pk=None):

         job = self.get_object()

+        # Record heartbeat for async processing services on this pipeline
+        _mark_pipeline_pull_services_seen(job)
+
         # Validate request data is a list
         if isinstance(request.data, list):
             results = request.data
```
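For context on how the heartbeat in these views gets exercised: the pull-mode worker's ordinary poll loop is itself the heartbeat. A rough sketch of such a loop from the worker's side; the endpoint paths come from this diff, while the base URL, auth header, response shape, and the process() stub are assumptions (the 5 s interval is taken from the commit message).

```python
import time

import requests

BASE_URL = "https://antenna.example.org/api/v2"  # assumed base URL
HEADERS = {"Authorization": "Token <worker-token>"}  # any authenticated token suffices today
JOB_ID = 123  # hypothetical job


def process(task: dict) -> dict:
    """Hypothetical stand-in for the worker's actual pipeline inference."""
    return {"task_id": task.get("id"), "detections": []}


while True:
    # Fetching tasks doubles as a heartbeat: the view calls
    # _mark_pipeline_pull_services_seen() before handing out work.
    resp = requests.get(f"{BASE_URL}/jobs/{JOB_ID}/tasks/", headers=HEADERS, timeout=10)
    for task in resp.json().get("tasks", []):  # assumed response shape
        result = process(task)
        # Submitting results refreshes last_seen the same way.
        requests.post(f"{BASE_URL}/jobs/{JOB_ID}/result/", json=[result], headers=HEADERS, timeout=10)
    time.sleep(5)  # worker poll interval cited in the commit message
```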

ami/ml/models/processing_service.py

Lines changed: 32 additions & 1 deletion

```diff
@@ -23,7 +23,29 @@
 logger = logging.getLogger(__name__)


-class ProcessingServiceManager(models.Manager.from_queryset(BaseQuerySet)):
+class ProcessingServiceQuerySet(BaseQuerySet):
+    def async_services(self) -> "ProcessingServiceQuerySet":
+        """
+        Filter to pull-mode (async) processing services — those with no endpoint URL.
+
+        These correspond to jobs with dispatch_mode=ASYNC_API. Instead of Antenna calling
+        out to them, they poll Antenna for tasks and push results back. Their liveness is
+        tracked via heartbeats from mark_seen() rather than active health checks.
+        """
+        return self.filter(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact=""))
+
+    def sync_services(self) -> "ProcessingServiceQuerySet":
+        """
+        Filter to push-mode (sync) processing services — those with a configured endpoint URL.
+
+        These correspond to jobs with dispatch_mode=SYNC_API. Antenna actively calls their
+        /readyz and /process endpoints. Their liveness is tracked by the periodic
+        check_processing_services_online Celery task.
+        """
+        return self.exclude(models.Q(endpoint_url__isnull=True) | models.Q(endpoint_url__exact=""))
+
+
+class ProcessingServiceManager(models.Manager.from_queryset(ProcessingServiceQuerySet)):
     """Custom manager for ProcessingService to handle specific queries."""

     def create(self, **kwargs) -> "ProcessingService":
@@ -47,6 +69,15 @@ class ProcessingService(BaseModel):

     objects = ProcessingServiceManager()

+    @property
+    def is_async(self) -> bool:
+        """
+        True if this is a pull-mode (async) service with no endpoint URL, corresponding to
+        jobs with dispatch_mode=ASYNC_API. False for push-mode services with a configured
+        endpoint, corresponding to jobs with dispatch_mode=SYNC_API.
+        """
+        return not self.endpoint_url
+
     def __str__(self):
         endpoint_display = self.endpoint_url or "async"
         return f'#{self.pk} "{self.name}" ({endpoint_display})'
```

ami/ml/tasks.py

Lines changed: 20 additions & 5 deletions

```diff
@@ -98,16 +98,22 @@ def remove_duplicate_classifications(project_id: int | None = None, dry_run: boo
 @celery_app.task(soft_time_limit=10, time_limit=20)
 def check_processing_services_online():
     """
-    Check the status of all v1 synchronous processing services and update the last_seen/last_seen_live fields.
-    Asynchronous (pull-mode) services are updated via mark_seen() when they register pipelines.
+    Check the status of all processing services and update last_seen/last_seen_live fields.
+
+    - Sync services (dispatch_mode=SYNC_API, endpoint URL set): actively polled via /readyz.
+    - Async services (dispatch_mode=ASYNC_API, no endpoint URL): heartbeat is updated by
+      mark_seen() on registration and by _mark_pipeline_pull_services_seen() on task polling.
+      This task marks them offline if last_seen has exceeded PROCESSING_SERVICE_LAST_SEEN_MAX.

     @TODO make this async to check all services in parallel
     """
-    from ami.ml.models import ProcessingService
+    import datetime
+
+    from ami.ml.models import PROCESSING_SERVICE_LAST_SEEN_MAX, ProcessingService

-    logger.info("Checking which synchronous processing services are online.")
+    logger.info("Checking which processing services are online.")

-    services = ProcessingService.objects.exclude(endpoint_url__isnull=True).exclude(endpoint_url__exact="").all()
+    services = ProcessingService.objects.sync_services()

     for service in services:
         logger.info(f"Checking service {service}")
@@ -117,3 +123,12 @@ def check_processing_services_online():
         except Exception as e:
             logger.error(f"Error checking service {service}: {e}")
             continue
+
+    stale_cutoff = datetime.datetime.now() - PROCESSING_SERVICE_LAST_SEEN_MAX
+    stale = ProcessingService.objects.async_services().filter(last_seen_live=True, last_seen__lt=stale_cutoff)
+    count = stale.count()
+    if count:
+        logger.info(
+            f"Marking {count} async service(s) offline (no heartbeat within {PROCESSING_SERVICE_LAST_SEEN_MAX})."
+        )
+        stale.update(last_seen_live=False)
```
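One note: PROCESSING_SERVICE_LAST_SEEN_MAX is imported from ami.ml.models above, but its definition is not visible in these diffs. Given the commit message (60 s, 12× the worker's 5 s poll interval) and its use in timedelta arithmetic here, a plausible definition is:

```python
import datetime

# Assumed from the commit message: an async service counts as offline after 60 s
# without a heartbeat, i.e. 12 missed cycles of the worker's 5 s poll interval.
PROCESSING_SERVICE_LAST_SEEN_MAX = datetime.timedelta(seconds=60)
```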
