18 changes: 9 additions & 9 deletions ami/jobs/views.py
@@ -55,11 +55,12 @@ def filter_queryset(self, request, queryset, view):
incomplete_only = url_boolean_param(request, "incomplete_only", default=False)
# Filter to incomplete jobs if requested (checks "results" stage status)
if incomplete_only:
# Create filters for each final state to exclude
# Exclude jobs with a terminal top-level status
queryset = queryset.exclude(status__in=JobState.final_states())

# Also exclude jobs where the "results" stage has a final state status
final_states = JobState.final_states()
exclude_conditions = Q()

# Exclude jobs where the "results" stage has a final state status
for state in final_states:
# JSON path query to check if results stage status is in final states
# @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere
@@ -233,6 +234,10 @@ def tasks(self, request, pk=None):
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
raise ValidationError("Only async_api jobs have fetchable tasks")

# Don't fetch tasks from completed/failed/revoked jobs
if job.status in JobState.final_states():
return Response({"tasks": []})

# Validate that the job has a pipeline
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")
@@ -241,13 +246,8 @@
from ami.ml.orchestration.nats_queue import TaskQueueManager

async def get_tasks():
tasks = []
async with TaskQueueManager() as manager:
for _ in range(batch):
task = await manager.reserve_task(job.pk, timeout=0.1)
if task:
tasks.append(task.dict())
return tasks
return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch)]

# Use async_to_sync to properly handle the async call
tasks = async_to_sync(get_tasks)()
61 changes: 28 additions & 33 deletions ami/ml/orchestration/nats_queue.py
@@ -24,7 +24,12 @@


async def get_connection(nats_url: str):
nc = await nats.connect(nats_url)
nc = await nats.connect(
nats_url,
connect_timeout=5,
allow_reconnect=False,
max_reconnect_attempts=0,
)
js = nc.jetstream()
return nc, js

@@ -39,8 +44,8 @@ class TaskQueueManager:
Use as an async context manager:
async with TaskQueueManager() as manager:
await manager.publish_task('job123', {'data': 'value'})
task = await manager.reserve_task('job123')
await manager.acknowledge_task(task['reply_subject'])
tasks = await manager.reserve_tasks('job123', count=64)
await manager.acknowledge_task(tasks[0].reply_subject)
"""

def __init__(self, nats_url: str | None = None):
@@ -156,62 +161,52 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
logger.error(f"Failed to publish task to stream for job '{job_id}': {e}")
return False

async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None:
async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]:
"""
Reserve a task from the specified stream.
Reserve up to `count` tasks from the specified stream in a single NATS fetch.

Args:
job_id: The job ID (integer primary key) to pull tasks from
timeout: Timeout in seconds for reservation (default: 5 seconds)
count: Maximum number of tasks to reserve
timeout: Timeout in seconds waiting for messages (default: 5 seconds)

Returns:
PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available
List of PipelineProcessingTask objects with reply_subject set for acknowledgment.
May return fewer than `count` if the queue has fewer messages available.
"""
if self.js is None:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

if timeout is None:
timeout = 5

try:
# Ensure stream and consumer exist
await self._ensure_stream(job_id)
await self._ensure_consumer(job_id)

consumer_name = self._get_consumer_name(job_id)
subject = self._get_subject(job_id)

# Create ephemeral subscription for this pull
psub = await self.js.pull_subscribe(subject, consumer_name)

try:
# Fetch a single message
msgs = await psub.fetch(1, timeout=timeout)

if msgs:
msg = msgs[0]
task_data = json.loads(msg.data.decode())
metadata = msg.metadata

# Parse the task data into PipelineProcessingTask
task = PipelineProcessingTask(**task_data)
# Set the reply_subject for acknowledgment
task.reply_subject = msg.reply

logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
return task

msgs = await psub.fetch(count, timeout=timeout)
except nats.errors.TimeoutError:
# No messages available
logger.debug(f"No tasks available in stream for job '{job_id}'")
return None
return []
finally:
# Always unsubscribe
await psub.unsubscribe()

tasks = []
for msg in msgs:
task_data = json.loads(msg.data.decode())
task = PipelineProcessingTask(**task_data)
task.reply_subject = msg.reply
tasks.append(task)

logger.info(f"Reserved {len(tasks)} tasks from stream for job '{job_id}'")
return tasks

except Exception as e:
logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}")
return None
logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}")
return []

async def acknowledge_task(self, reply_subject: str) -> bool:
"""
63 changes: 45 additions & 18 deletions ami/ml/orchestration/tests/test_nats_queue.py
@@ -62,47 +62,74 @@ async def test_publish_task_creates_stream_and_consumer(self):
self.assertIn("job_456", str(js.add_stream.call_args))
js.add_consumer.assert_called_once()

async def test_reserve_task_success(self):
"""Test successful task reservation."""
async def test_reserve_tasks_success(self):
"""Test successful batch task reservation."""
nc, js = self._create_mock_nats_connection()
sample_task = self._create_sample_task()

# Mock message with task data
mock_msg = MagicMock()
mock_msg.data = sample_task.json().encode()
mock_msg.reply = "reply.subject.123"
mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1))
# Mock messages with task data
mock_msg1 = MagicMock()
mock_msg1.data = sample_task.json().encode()
mock_msg1.reply = "reply.subject.1"

mock_msg2 = MagicMock()
mock_msg2.data = sample_task.json().encode()
mock_msg2.reply = "reply.subject.2"

mock_psub = MagicMock()
mock_psub.fetch = AsyncMock(return_value=[mock_msg])
mock_psub.fetch = AsyncMock(return_value=[mock_msg1, mock_msg2])
mock_psub.unsubscribe = AsyncMock()
js.pull_subscribe = AsyncMock(return_value=mock_psub)

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager() as manager:
task = await manager.reserve_task(123)
tasks = await manager.reserve_tasks(123, count=5)

self.assertIsNotNone(task)
self.assertEqual(task.id, sample_task.id)
self.assertEqual(task.reply_subject, "reply.subject.123")
self.assertEqual(len(tasks), 2)
self.assertEqual(tasks[0].id, sample_task.id)
self.assertEqual(tasks[0].reply_subject, "reply.subject.1")
self.assertEqual(tasks[1].reply_subject, "reply.subject.2")
mock_psub.fetch.assert_called_once_with(5, timeout=5)
mock_psub.unsubscribe.assert_called_once()

async def test_reserve_task_no_messages(self):
"""Test reserve_task when no messages are available."""
async def test_reserve_tasks_no_messages(self):
"""Test reserve_tasks when no messages are available (timeout)."""
nc, js = self._create_mock_nats_connection()
import nats.errors

mock_psub = MagicMock()
mock_psub.fetch = AsyncMock(return_value=[])
mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError)
mock_psub.unsubscribe = AsyncMock()
js.pull_subscribe = AsyncMock(return_value=mock_psub)

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager() as manager:
task = await manager.reserve_task(123)
tasks = await manager.reserve_tasks(123, count=5)

self.assertIsNone(task)
self.assertEqual(tasks, [])
mock_psub.unsubscribe.assert_called_once()

async def test_reserve_tasks_single(self):
"""Test reserving a single task."""
nc, js = self._create_mock_nats_connection()
sample_task = self._create_sample_task()

mock_msg = MagicMock()
mock_msg.data = sample_task.json().encode()
mock_msg.reply = "reply.subject.123"

mock_psub = MagicMock()
mock_psub.fetch = AsyncMock(return_value=[mock_msg])
mock_psub.unsubscribe = AsyncMock()
js.pull_subscribe = AsyncMock(return_value=mock_psub)

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager() as manager:
tasks = await manager.reserve_tasks(123, count=1)

self.assertEqual(len(tasks), 1)
self.assertEqual(tasks[0].reply_subject, "reply.subject.123")

async def test_acknowledge_task_success(self):
"""Test successful task acknowledgment."""
nc, js = self._create_mock_nats_connection()
@@ -144,7 +171,7 @@ async def test_operations_without_connection_raise_error(self):
await manager.publish_task(123, sample_task)

with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
await manager.reserve_task(123)
await manager.reserve_tasks(123, count=1)

with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
await manager.delete_stream(123)
125 changes: 125 additions & 0 deletions docs/claude/planning/nats-flooding-prevention.md
@@ -0,0 +1,125 @@
# NATS Flooding Prevention & Event Loop Blocking

**Date:** 2026-02-16
**Context:** PSv2 integration test exposed Django becoming unresponsive due to NATS connection issues

## Problem

When NATS becomes temporarily unreachable or connections are interrupted, Django's entire HTTP server hangs. This was observed during integration testing when:

1. A stale job (1365) in STARTED status continuously attempted to reserve NATS tasks
2. The ADC worker spawned 16 DataLoader subprocesses, all hammering `/jobs/1365/tasks?batch=64`
3. Each `/tasks` request opens a new NATS connection and blocks a uvicorn worker thread
4. NATS connections timed out, triggering the nats.py client's reconnection loop
5. The reconnection loop consumed Django's shared event loop, blocking ALL HTTP requests
6. Even endpoints that don't use NATS (like `/ml/pipelines/`) became unreachable

## Root Causes

### 1. `nats.connect()` uses default reconnection behavior
**File:** `ami/ml/orchestration/nats_queue.py:26-29`
```python
async def get_connection(nats_url: str):
nc = await nats.connect(nats_url) # No connect_timeout, allow_reconnect defaults to True
js = nc.jetstream()
return nc, js
```

**Fix (APPLIED):** Added `connect_timeout=5, allow_reconnect=False, max_reconnect_attempts=0` to `nats.connect()`. Since we create a new connection per operation via context manager, we never need the client's built-in reconnection.
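The applied version (see the `nats_queue.py` diff above):

```python
async def get_connection(nats_url: str):
    # Fail fast and never enter the client's reconnection loop; a fresh
    # connection is opened per operation via the context manager anyway.
    nc = await nats.connect(
        nats_url,
        connect_timeout=5,
        allow_reconnect=False,
        max_reconnect_attempts=0,
    )
    js = nc.jetstream()
    return nc, js
```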

### 2. `/tasks` endpoint doesn't check job status
**File:** `ami/jobs/views.py:232-255`
The endpoint checked `dispatch_mode` but not job status. A FAILURE/SUCCESS job still tried to fetch from NATS.

**Fix (APPLIED):** Added guard: `if job.status in JobState.final_states(): return Response({"tasks": []})`.
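The guard sits right after the dispatch-mode check in `JobViewSet.tasks()` (from the `views.py` diff above):

```python
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
    raise ValidationError("Only async_api jobs have fetchable tasks")

# Don't fetch tasks from completed/failed/revoked jobs
if job.status in JobState.final_states():
    return Response({"tasks": []})
```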

### 3. `incomplete_only` filter only checked progress JSON, not top-level status
**File:** `ami/jobs/views.py:50-69` (`IncompleteJobFilter`)
The filter only checked the "results" stage status in the progress JSON. Jobs manually set to FAILURE (without updating progress stages) slipped through.

**Fix (APPLIED):** Added `queryset.exclude(status__in=JobState.final_states())` before the progress JSON check.
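As applied in `IncompleteJobFilter.filter_queryset()` (from the `views.py` diff above):

```python
incomplete_only = url_boolean_param(request, "incomplete_only", default=False)
if incomplete_only:
    # Exclude jobs with a terminal top-level status
    queryset = queryset.exclude(status__in=JobState.final_states())

    # The existing JSON-path checks on the "results" stage status follow here
    final_states = JobState.final_states()
    exclude_conditions = Q()
```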

### 4. No timeout on stream/consumer operations
**File:** `ami/ml/orchestration/nats_queue.py:77-124`
`_ensure_stream()` and `_ensure_consumer()` call JetStream API without explicit timeouts. If NATS is slow, these block indefinitely.

**Status:** TODO
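One possible shape for the fix, a minimal sketch only: the stream name and subjects passed to `js.add_stream()` are placeholders, not the real configuration from `_ensure_stream()`.

```python
import asyncio

JS_API_TIMEOUT = 5  # seconds; hypothetical bound for JetStream management calls


async def _ensure_stream(self, job_id: int) -> None:
    subject = self._get_subject(job_id)
    try:
        # Bound the JetStream API call so a slow or hung NATS server raises
        # instead of blocking the shared event loop indefinitely.
        await asyncio.wait_for(
            self.js.add_stream(name=f"tasks_job_{job_id}", subjects=[subject]),
            timeout=JS_API_TIMEOUT,
        )
    except asyncio.TimeoutError:
        logger.error(f"Timed out ensuring stream for job '{job_id}'")
        raise
```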

### 5. Leaked NATS connections from interrupted requests
When an HTTP request is interrupted (client disconnect, test script killed), the `TaskQueueManager.__aexit__` may not run, leaving a NATS connection open. With `allow_reconnect=True` (the old default), that connection's reconnection callbacks consumed the event loop.

**Status:** Mitigated by `allow_reconnect=False` fix.

### 6. `async_to_sync()` blocks Django worker threads
**Files:** `ami/jobs/views.py:253`, `ami/ml/orchestration/jobs.py:119`, `ami/jobs/tasks.py:191`

Every NATS operation wraps async code with `async_to_sync()`, which creates or reuses a thread-local event loop. If the async operation hangs (stuck NATS connection), the Django worker thread is permanently blocked.

**Status:** TODO — wrap with `asyncio.wait_for()` inside the async function.
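A minimal sketch of that wrapper for the `/tasks` endpoint; `NATS_CALL_TIMEOUT` and `fetch_tasks()` are hypothetical names, and the caller keeps using `async_to_sync()` as today:

```python
import asyncio

from asgiref.sync import async_to_sync

from ami.ml.orchestration.nats_queue import TaskQueueManager

NATS_CALL_TIMEOUT = 10  # seconds; hypothetical bound for the whole NATS round trip


def fetch_tasks(job_pk: int, batch: int) -> list[dict]:
    async def get_tasks() -> list[dict]:
        async with TaskQueueManager() as manager:
            return [task.dict() for task in await manager.reserve_tasks(job_pk, count=batch)]

    async def get_tasks_bounded() -> list[dict]:
        try:
            return await asyncio.wait_for(get_tasks(), timeout=NATS_CALL_TIMEOUT)
        except asyncio.TimeoutError:
            # Give up rather than pin the worker thread on a hung connection
            return []

    return async_to_sync(get_tasks_bounded)()
```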

### 7. Stale ADC workers compete for tasks (test infrastructure issue)
The test script starts an ADC worker but doesn't kill stale workers from previous runs. With 2 GPUs, `mp.spawn(nprocs=2)` forks 2 child processes. If a previous worker is still running, its DataLoader subprocesses race with the new worker for NATS messages. In the 2026-02-16 test, 147 `/tasks` requests were logged — the stale worker consumed all 20 NATS messages, leaving 0 for the new worker.

**Fix:** Add `pkill -f "ami worker"` cleanup before starting the worker in the test script.

## Additional TODOs from Integration Testing

### 8. `/tasks/` endpoint should support multiple pipelines
The endpoint should allow workers to pass in multiple pipeline slugs, or return all available tasks for projects the token has access to (no pipeline filter = all).

**Status:** TODO

### 9. ADC worker should use trailing slashes
The ADC worker requests `/api/v2/jobs/1365/tasks?batch=64` without trailing slash, causing 301 redirects. Each redirect doubles the request overhead.

**Status:** TODO (ADC-side fix in `ami-data-companion`)

### 10. `dispatch_mode` should be set on job init, not `run()`
Currently `dispatch_mode` is set when the job starts running. It should be set at job creation time so the API can filter by it before the job runs.

**Status:** TODO

### 11. Processing service online status (GitHub #1122)
Show online status of registered processing services.
**See:** https://github.com/RolnickLab/antenna/issues/1122

### 12. Show which workers pick up a job/task (GitHub #1112)
At minimum, log which worker processes each task.
**See:** https://github.com/RolnickLab/antenna/issues/1112

## Applied Changes Summary

| File | Change | Status |
|------|--------|--------|
| `ami/ml/orchestration/nats_queue.py:26-32` | `connect_timeout=5, allow_reconnect=False` | APPLIED |
| `ami/jobs/views.py:237-238` | Guard `/tasks` for terminal status jobs | APPLIED |
| `ami/jobs/views.py:59` | `incomplete_only` also checks top-level status | APPLIED |

## Remaining TODOs

| Priority | Issue | Impact |
|----------|-------|--------|
| P1 | Timeout on JetStream stream/consumer ops | Prevents indefinite blocking |
| P1 | `async_to_sync()` timeout wrapper | Prevents thread exhaustion |
| P2 | `/tasks/` multi-pipeline support | Worker efficiency |
| P2 | ADC trailing slashes | Removes 301 overhead |
| P2 | `dispatch_mode` on job init | Correct filtering at creation time |
| P3 | Stale job auto-cleanup (Celery Beat) | Prevents future flooding |
| P3 | Circuit breaker for NATS failures | Graceful degradation |
| P3 | #1122: Processing service online status | UX |
| P3 | #1112: Worker tracking in logs | Observability |

## Related Files

| File | Lines | What |
|------|-------|------|
| `ami/ml/orchestration/nats_queue.py` | 26-32 | `get_connection()` — FIXED with timeouts |
| `ami/ml/orchestration/nats_queue.py` | 77-124 | Stream/consumer ops — needs timeouts |
| `ami/ml/orchestration/nats_queue.py` | 159-214 | `reserve_tasks()` — has timeout but connection may block |
| `ami/jobs/views.py` | 50-69 | `IncompleteJobFilter` — FIXED |
| `ami/jobs/views.py` | 237-238 | `/tasks` status guard — FIXED |
| `ami/jobs/views.py` | 243-256 | `/tasks/` endpoint — `async_to_sync()` blocks thread |
| `ami/ml/orchestration/jobs.py` | 119 | `queue_images_to_nats()` — `async_to_sync()` blocks thread |
| `ami/jobs/tasks.py` | 184-199 | `_ack_task_via_nats()` — per-ACK connection (expensive) |
| `docs/claude/debugging/nats-triage.md` | Full | Previous NATS debugging findings |
| `docs/claude/nats-todo.md` | Full | NATS infrastructure improvements tracker |