Merged
15 changes: 10 additions & 5 deletions ami/jobs/models.py
@@ -461,9 +461,7 @@ def run(cls, job: "Job"):
         # End image collection stage
         job.save()

-        if job.project.feature_flags.async_pipeline_workers:
-            job.dispatch_mode = JobDispatchMode.ASYNC_API
-            job.save(update_fields=["dispatch_mode"])
+        if job.dispatch_mode == JobDispatchMode.ASYNC_API:
             queued = queue_images_to_nats(job, images)
             if not queued:
                 job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk)
@@ -473,8 +471,6 @@ def run(cls, job: "Job"):
                 job.save()
                 return
         else:
-            job.dispatch_mode = JobDispatchMode.SYNC_API
-            job.save(update_fields=["dispatch_mode"])
             cls.process_images(job, images)

     @classmethod
@@ -919,6 +915,15 @@ def setup(self, save=True):
             self.progress.add_stage_param(delay_stage.key, "Mood", "😴")

         if self.pipeline:
+            # Set dispatch mode based on project feature flags at creation time
+            # so the UI can show the correct mode before the job runs.
+            # Only override if still at the default (INTERNAL), to allow explicit overrides.
+            if self.dispatch_mode == JobDispatchMode.INTERNAL:
+                if self.project and self.project.feature_flags.async_pipeline_workers:
+                    self.dispatch_mode = JobDispatchMode.ASYNC_API
+                else:
+                    self.dispatch_mode = JobDispatchMode.SYNC_API
+
             collect_stage = self.progress.add_stage("Collect")
             self.progress.add_stage_param(collect_stage.key, "Total Images", "")
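Note on the setup() change above: a mode is only assigned when the job is still at the default (JobDispatchMode.INTERNAL), so callers can still pin a dispatch mode explicitly at creation time, as the dispatch-mode tests below do. A minimal sketch, assuming a `project` with `async_pipeline_workers` enabled and a `pipeline` and `source_image_collection` already in scope:

```python
# Because setup() only overrides jobs still at INTERNAL, an explicit
# dispatch_mode passed at creation time is preserved.
job = Job.objects.create(
    job_type_key=MLJob.key,
    project=project,  # assume project.feature_flags.async_pipeline_workers is True
    name="Pinned Sync Job",
    pipeline=pipeline,
    source_image_collection=source_image_collection,
    dispatch_mode=JobDispatchMode.SYNC_API,
)
assert job.dispatch_mode == JobDispatchMode.SYNC_API  # not flipped to ASYNC_API
```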
69 changes: 65 additions & 4 deletions ami/jobs/tests.py
@@ -384,6 +384,36 @@ def test_filter_by_pipeline_slug(self):
         self.assertEqual(data["count"], 1)
         self.assertEqual(data["results"][0]["id"], job_with_pipeline.pk)

+    def test_filter_by_pipeline_slug_in(self):
+        """Test filtering jobs by pipeline__slug__in (multiple slugs)."""
+        pipeline_a = self._create_pipeline("Pipeline A", "pipeline-a")
+        pipeline_b = Pipeline.objects.create(name="Pipeline B", slug="pipeline-b", description="B")
+        pipeline_b.projects.add(self.project)
+        pipeline_c = Pipeline.objects.create(name="Pipeline C", slug="pipeline-c", description="C")
+        pipeline_c.projects.add(self.project)
+
+        job_a = self._create_ml_job("Job A", pipeline_a)
+        job_b = self._create_ml_job("Job B", pipeline_b)
+        job_c = self._create_ml_job("Job C", pipeline_c)
+
+        self.client.force_authenticate(user=self.user)
+
+        # Filter for two of the three pipelines
+        jobs_list_url = reverse_with_params(
+            "api:job-list",
+            params={"project_id": self.project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"},
+        )
+        resp = self.client.get(jobs_list_url)
+
+        self.assertEqual(resp.status_code, 200)
+        data = resp.json()
+        returned_ids = {job["id"] for job in data["results"]}
+        self.assertIn(job_a.pk, returned_ids)
+        self.assertIn(job_b.pk, returned_ids)
+        self.assertNotIn(job_c.pk, returned_ids)
+        # Original setUp job (no pipeline) should also be excluded
+        self.assertNotIn(self.job.pk, returned_ids)
+
     def test_search_jobs(self):
         """Test searching jobs by name and pipeline name."""
         pipeline = self._create_pipeline("SearchablePipeline", "searchable-pipeline")
@@ -571,13 +601,11 @@ def test_dispatch_mode_filtering(self):
             dispatch_mode=JobDispatchMode.ASYNC_API,
         )

-        # Create a job with default dispatch_mode (should be "internal")
+        # Create a non-ML job without a pipeline (dispatch_mode stays "internal")
         internal_job = Job.objects.create(
-            job_type_key=MLJob.key,
+            job_type_key="data_storage_sync",
             project=self.project,
             name="Internal Job",
-            pipeline=self.pipeline,
-            source_image_collection=self.source_image_collection,
         )

         self.client.force_authenticate(user=self.user)
@@ -614,6 +642,39 @@ def test_dispatch_mode_filtering(self):
         expected_ids = {sync_job.pk, async_job.pk, internal_job.pk}
         self.assertEqual(returned_ids, expected_ids)

+    def test_ml_job_dispatch_mode_set_on_creation(self):
+        """Test that ML jobs get dispatch_mode set based on project feature flags at creation time."""
+        # Without async flag, ML job should default to sync_api
+        sync_job = Job.objects.create(
+            job_type_key=MLJob.key,
+            project=self.project,
+            name="Auto Sync Job",
+            pipeline=self.pipeline,
+            source_image_collection=self.source_image_collection,
+        )
+        self.assertEqual(sync_job.dispatch_mode, JobDispatchMode.SYNC_API)
+
+        # Enable async flag on project
+        self.project.feature_flags.async_pipeline_workers = True
+        self.project.save()
+
+        async_job = Job.objects.create(
+            job_type_key=MLJob.key,
+            project=self.project,
+            name="Auto Async Job",
+            pipeline=self.pipeline,
+            source_image_collection=self.source_image_collection,
+        )
+        self.assertEqual(async_job.dispatch_mode, JobDispatchMode.ASYNC_API)
+
+        # Non-pipeline job should stay internal regardless of feature flag
+        internal_job = Job.objects.create(
+            job_type_key="data_storage_sync",
+            project=self.project,
+            name="Internal Job",
+        )
+        self.assertEqual(internal_job.dispatch_mode, JobDispatchMode.INTERNAL)
+
     def test_tasks_endpoint_rejects_non_async_jobs(self):
         """Test that /tasks endpoint returns 400 for non-async_api jobs."""
         from ami.base.serializers import reverse_with_params
30 changes: 18 additions & 12 deletions ami/jobs/views.py
@@ -1,5 +1,7 @@
+import asyncio
 import logging

+import nats.errors
 import pydantic
 from asgiref.sync import async_to_sync
 from django.db.models import Q
@@ -32,6 +34,7 @@ class JobFilterSet(filters.FilterSet):
     """Custom filterset to enable pipeline name filtering."""

     pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact")
+    pipeline__slug__in = filters.BaseInFilter(field_name="pipeline__slug", lookup_expr="in")

     class Meta:
         model = Job
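A note on the filter added above: django-filter's BaseInFilter splits a comma-separated query value into a list before applying the `in` lookup, so no custom parsing is needed. Roughly, as exercised by the new test (a sketch; `project` stands in for whatever project object is at hand):

```python
from ami.base.serializers import reverse_with_params

# ?pipeline__slug__in=pipeline-a,pipeline-b is parsed into a list by
# BaseInFilter and applied as an `in` lookup, roughly equivalent to:
#   Job.objects.filter(pipeline__slug__in=["pipeline-a", "pipeline-b"])

# Building the request URL with the same helper the tests use:
jobs_list_url = reverse_with_params(
    "api:job-list",
    params={"project_id": project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"},
)
```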
@@ -55,11 +58,12 @@ def filter_queryset(self, request, queryset, view):
         incomplete_only = url_boolean_param(request, "incomplete_only", default=False)
         # Filter to incomplete jobs if requested (checks "results" stage status)
         if incomplete_only:
-            # Create filters for each final state to exclude
+            # Exclude jobs with a terminal top-level status
+            queryset = queryset.exclude(status__in=JobState.final_states())
+
+            # Also exclude jobs where the "results" stage has a final state status
             final_states = JobState.final_states()
             exclude_conditions = Q()
-
-            # Exclude jobs where the "results" stage has a final state status
             for state in final_states:
                 # JSON path query to check if results stage status is in final states
                 # @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere
@@ -233,6 +237,10 @@ def tasks(self, request, pk=None):
         if job.dispatch_mode != JobDispatchMode.ASYNC_API:
             raise ValidationError("Only async_api jobs have fetchable tasks")

+        # Don't fetch tasks from completed/failed/revoked jobs
+        if job.status in JobState.final_states():
+            return Response({"tasks": []})
+
         # Validate that the job has a pipeline
         if not job.pipeline:
             raise ValidationError("This job does not have a pipeline configured")
@@ -241,16 +249,14 @@
         from ami.ml.orchestration.nats_queue import TaskQueueManager

         async def get_tasks():
-            tasks = []
             async with TaskQueueManager() as manager:
-                for _ in range(batch):
-                    task = await manager.reserve_task(job.pk, timeout=0.1)
-                    if task:
-                        tasks.append(task.dict())
-            return tasks
-
-        # Use async_to_sync to properly handle the async call
-        tasks = async_to_sync(get_tasks)()
+                return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)]

+        try:
+            tasks = async_to_sync(get_tasks)()
+        except (asyncio.TimeoutError, OSError, nats.errors.Error) as e:
+            logger.warning("NATS unavailable while fetching tasks for job %s: %s", job.pk, e)
+            return Response({"error": "Task queue temporarily unavailable"}, status=503)

         return Response({"tasks": tasks})

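The rewritten tasks endpoint relies on a batched `TaskQueueManager.reserve_tasks()` method implemented outside this diff (in ami/ml/orchestration/nats_queue.py). A hypothetical sketch of what batched reservation could look like with nats-py's JetStream pull API; the subject and durable names below are illustrative, not taken from this PR:

```python
import json

import nats.errors


async def reserve_tasks(js, job_pk: int, count: int, timeout: float) -> list[dict]:
    """Hypothetical: one pull fetch with a single shared timeout, replacing
    the old pattern of `count` sequential reserve_task() calls."""
    # Illustrative subject/durable names; the real ones live in nats_queue.py.
    sub = await js.pull_subscribe(f"tasks.job.{job_pk}", durable=f"job-{job_pk}")
    try:
        msgs = await sub.fetch(batch=count, timeout=timeout)
    except nats.errors.TimeoutError:
        return []  # nothing queued within the timeout; the view returns {"tasks": []}
    for msg in msgs:
        await msg.ack()  # real code may nak/term on processing failure instead
    return [json.loads(msg.data) for msg in msgs]
```

Whatever the real implementation does, a single shared timeout bounds the endpoint's worst case at roughly 0.5 s, whereas the old per-task loop could block for up to batch × 0.1 s.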