Skip to content

Commit bfb29e2

Browse files
authored
Merge branch 'main' into fix/2digit-year-timestamps-and-sync-improvements
2 parents bca4478 + 0487cf9 commit bfb29e2

74 files changed

Lines changed: 1766 additions & 786 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

ami/base/permissions.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,73 @@ def add_collection_level_permissions(user: User | None, response_data: dict, mod
7777
return response_data
7878

7979

80+
def add_m2m_object_permissions(user, instance, project, response_data: dict) -> dict:
    """
    Add object-level permissions for models with an M2M relationship to Project.

    The default permission resolution (BaseModel._get_object_perms) relies on
    get_project(), which returns None for M2M-to-Project models (TaxaList, etc.)
    because there's no single owning project. This function resolves permissions
    against a specific project from the request context instead.

    Validates that the instance actually belongs to the given project before
    granting any permissions (prevents cross-project permission leaks).

    This is a temporary approach for the M2M permission gap described in #1120.
    Once that issue is resolved, this should be replaced by a generic permission
    class (Pattern B: Bare M2M) that handles TaxaList, Taxon, ProcessingService,
    Pipeline, and other M2M-to-Project models uniformly.

    :param user: requesting user; may be None or anonymous (consistent with
        add_collection_level_permissions), in which case nothing is granted.
    :param instance: model instance with a ``projects`` M2M manager.
    :param project: active project from the request context (may be None).
    :param response_data: serialized representation to annotate.
    :return: ``response_data`` with ``user_permissions`` set to a de-duplicated list.
    """
    perms = set(response_data.get("user_permissions", []))

    # No project context, or the instance does not belong to that project:
    # grant nothing beyond what is already present (cross-project leak guard).
    if not project or not instance.projects.filter(pk=project.pk).exists():
        response_data["user_permissions"] = list(perms)
        return response_data

    # Anonymous / missing users never receive write permissions. The sibling
    # helper add_collection_level_permissions accepts User | None, so this
    # function must tolerate None as well instead of raising AttributeError.
    if user is None or not getattr(user, "is_authenticated", False):
        response_data["user_permissions"] = list(perms)
        return response_data

    if user.is_superuser:
        perms.update(["update", "delete"])
    else:
        model_name = instance._meta.model_name
        # Project-level perms look like "<action>_<model_name>"; only expose
        # the actions the object-permission API understands.
        for perm in get_perms(user, project):
            if perm.endswith(f"_{model_name}"):
                action = perm.split("_", 1)[0]
                if action in {"update", "delete"}:
                    perms.add(action)

    response_data["user_permissions"] = list(perms)
    return response_data
118+
class IsProjectMemberOrReadOnly(permissions.BasePermission):
    """
    Allow read-only access to anyone; restrict writes to project members.

    Unsafe methods (POST, PUT, PATCH, DELETE) require an authenticated user who
    is either a superuser or a member of the active project, resolved through
    the view's get_active_project (i.e. the view must use ProjectMixin).
    """

    def has_permission(self, request, view):
        # Reads are always permitted.
        if request.method in permissions.SAFE_METHODS:
            return True

        user = request.user
        if not user or not user.is_authenticated:
            return False
        if user.is_superuser:  # type: ignore[union-attr]
            return True

        # The view must expose get_active_project (provided by ProjectMixin);
        # without it we cannot scope the membership check, so deny by default.
        resolver = getattr(view, "get_active_project", None)
        if not resolver:
            return False

        active_project = resolver()
        if not active_project:
            return False
        return active_project.members.filter(pk=user.pk).exists()
80147
class ObjectPermission(permissions.BasePermission):
81148
"""
82149
Generic permission class that delegates to the model's `check_permission(user, action)` method.

ami/jobs/tasks.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,13 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
8484

8585
state_manager = AsyncJobStateManager(job_id)
8686

87-
progress_info = state_manager.update_state(
88-
processed_image_ids, stage="process", request_id=self.request.id, failed_image_ids=failed_image_ids
89-
)
87+
progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids)
9088
if not progress_info:
91-
logger.warning(
92-
f"Another task is already processing results for job {job_id}. "
93-
f"Retrying task {self.request.id} in 5 seconds..."
94-
)
95-
raise self.retry(countdown=5, max_retries=10)
89+
logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
90+
# Acknowledge the task to prevent retries, since we don't know the state
91+
_ack_task_via_nats(reply_subject, logger)
92+
# TODO: cancel the job to fail fast once PR #1144 is merged
93+
return
9694

9795
try:
9896
complete_state = JobState.SUCCESS
@@ -126,6 +124,7 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
126124
_ack_task_via_nats(reply_subject, logger)
127125
return
128126

127+
acked = False
129128
try:
130129
# Save to database (this is the slow operation)
131130
detections_count, classifications_count, captures_count = 0, 0, 0
@@ -145,20 +144,18 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
145144
captures_count = len(pipeline_result.source_images)
146145

147146
_ack_task_via_nats(reply_subject, job.logger)
147+
acked = True
148148
# Update job stage with calculated progress
149149

150150
progress_info = state_manager.update_state(
151151
processed_image_ids,
152152
stage="results",
153-
request_id=self.request.id,
154153
)
155154

156155
if not progress_info:
157-
logger.warning(
158-
f"Another task is already processing results for job {job_id}. "
159-
f"Retrying task {self.request.id} in 5 seconds..."
160-
)
161-
raise self.retry(countdown=5, max_retries=10)
156+
job.logger.error(f"Redis state missing for job {job_id} — job may have been cleaned up prematurely.")
157+
# TODO: cancel the job to fail fast once PR #1144 is merged
158+
return
162159

163160
# update complete state based on latest progress info after saving results
164161
complete_state = JobState.SUCCESS
@@ -176,9 +173,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
176173
)
177174

178175
except Exception as e:
179-
job.logger.error(
180-
f"Failed to process pipeline result for job {job_id}: {e}. NATS will redeliver the task message."
181-
)
176+
error = f"Error processing pipeline result for job {job_id}: {e}"
177+
if not acked:
178+
error += ". NATS will re-deliver the task message."
179+
180+
job.logger.error(error)
182181

183182

184183
def _ack_task_via_nats(reply_subject: str, job_logger: logging.Logger) -> None:
@@ -256,9 +255,33 @@ def _update_job_progress(
256255
state_params["classifications"] = current_classifications + new_classifications
257256
state_params["captures"] = current_captures + new_captures
258257

258+
# Don't overwrite a stage with a stale progress value.
259+
# This guards against the race where a slower worker calls _update_job_progress
260+
# after a faster worker has already marked further progress.
261+
try:
262+
existing_stage = job.progress.get_stage(stage)
263+
progress_percentage = max(existing_stage.progress, progress_percentage)
264+
# Explicitly preserve FAILURE: once a stage is marked FAILURE it should
265+
# never regress to a non-failure state, regardless of enum ordering.
266+
if existing_stage.status == JobState.FAILURE:
267+
complete_state = JobState.FAILURE
268+
except (ValueError, AttributeError):
269+
pass # Stage doesn't exist yet; proceed normally
270+
271+
# Determine the status to write:
272+
# - Stage complete (100%): use complete_state (SUCCESS or FAILURE)
273+
# - Stage incomplete but FAILURE already determined: keep FAILURE visible
274+
# - Stage incomplete, no failure: mark as in-progress (STARTED)
275+
if progress_percentage >= 1.0:
276+
status = complete_state
277+
elif complete_state == JobState.FAILURE:
278+
status = JobState.FAILURE
279+
else:
280+
status = JobState.STARTED
281+
259282
job.progress.update_stage(
260283
stage,
261-
status=complete_state if progress_percentage >= 1.0 else JobState.STARTED,
284+
status=status,
262285
progress=progress_percentage,
263286
**state_params,
264287
)

ami/jobs/test_tasks.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,26 @@
66
"""
77

88
import logging
9+
from concurrent.futures import ThreadPoolExecutor
910
from unittest.mock import AsyncMock, MagicMock, patch
1011

1112
from django.core.cache import cache
12-
from django.test import TestCase
13+
from django.test import TransactionTestCase
1314
from rest_framework.test import APITestCase
1415

1516
from ami.base.serializers import reverse_with_params
1617
from ami.jobs.models import Job, JobDispatchMode, JobState, MLJob
1718
from ami.jobs.tasks import process_nats_pipeline_result
1819
from ami.main.models import Detection, Project, SourceImage, SourceImageCollection
1920
from ami.ml.models import Pipeline
20-
from ami.ml.orchestration.async_job_state import AsyncJobStateManager, _lock_key
21+
from ami.ml.orchestration.async_job_state import AsyncJobStateManager
2122
from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse, SourceImageResponse
2223
from ami.users.models import User
2324

2425
logger = logging.getLogger(__name__)
2526

2627

27-
class TestProcessNatsPipelineResultError(TestCase):
28+
class TestProcessNatsPipelineResultError(TransactionTestCase):
2829
"""E2E tests for process_nats_pipeline_result with error handling."""
2930

3031
def setUp(self):
@@ -237,38 +238,46 @@ def test_process_nats_pipeline_result_mixed_results(self, mock_manager_class):
237238
self.assertEqual(mock_manager.acknowledge_task.call_count, 3)
238239

239240
@patch("ami.jobs.tasks.TaskQueueManager")
240-
def test_process_nats_pipeline_result_error_concurrent_locking(self, mock_manager_class):
241+
def test_process_nats_pipeline_result_concurrent_updates(self, mock_manager_class):
241242
"""
242-
Test that error results respect locking mechanism.
243+
Test that concurrent workers update state independently without contention.
243244
244-
Verifies race condition handling when multiple workers
245-
process error results simultaneously.
245+
Without a lock, two workers processing different images can both call
246+
update_state and receive valid progress — no retry needed, no blocking.
246247
"""
247-
# Simulate lock held by another task
248-
lock_key = _lock_key(self.job.pk)
249-
cache.set(lock_key, "other-task-id", timeout=60)
250-
251-
# Create error result
252-
error_data = self._create_error_result(image_id=str(self.images[0].pk))
253-
reply_subject = "tasks.reply.test789"
248+
mock_manager = self._setup_mock_nats(mock_manager_class)
254249

255-
# Task should raise retry exception when lock not acquired
256-
# The task internally calls self.retry() which raises a Retry exception
257-
from celery.exceptions import Retry
250+
with ThreadPoolExecutor(max_workers=2) as executor:
251+
# Worker 1 processes images[0]
252+
result_1 = executor.submit(
253+
process_nats_pipeline_result.apply,
254+
kwargs={
255+
"job_id": self.job.pk,
256+
"result_data": self._create_error_result(image_id=str(self.images[0].pk)),
257+
"reply_subject": "reply.concurrent.1",
258+
},
259+
)
258260

259-
with self.assertRaises(Retry):
260-
process_nats_pipeline_result.apply(
261+
# Worker 2 processes images[1] — no retry, no lock to wait for
262+
result_2 = executor.submit(
263+
process_nats_pipeline_result.apply,
261264
kwargs={
262265
"job_id": self.job.pk,
263-
"result_data": error_data,
264-
"reply_subject": reply_subject,
265-
}
266+
"result_data": self._create_error_result(image_id=str(self.images[1].pk)),
267+
"reply_subject": "reply.concurrent.2",
268+
},
266269
)
267270

268-
# Assert: Progress was NOT updated (lock not acquired)
271+
self.assertTrue(result_1.result().successful())
272+
self.assertTrue(result_2.result().successful())
273+
274+
# Both images should be marked as processed
269275
manager = AsyncJobStateManager(self.job.pk)
270276
progress = manager.get_progress("process")
271-
self.assertEqual(progress.processed, 0)
277+
self.assertIsNotNone(progress)
278+
self.assertEqual(progress.processed, 2)
279+
self.assertEqual(progress.total, 3)
280+
self.assertEqual(mock_manager.acknowledge_task.call_count, 2)
272281

273282
@patch("ami.jobs.tasks.TaskQueueManager")
274283
def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class):

ami/main/api/serializers.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from rest_framework.request import Request
77

88
from ami.base.fields import DateStringField
9+
from ami.base.permissions import add_m2m_object_permissions
910
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, reverse_with_params
1011
from ami.base.views import get_active_project
1112
from ami.jobs.models import Job
@@ -633,13 +634,23 @@ def get_occurrences(self, obj):
633634
)
634635

635636

636-
class TaxaListSerializer(serializers.ModelSerializer):
637+
class TaxaListSerializer(DefaultSerializer):
637638
taxa = serializers.SerializerMethodField()
638-
projects = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all(), many=True)
639+
taxa_count = serializers.SerializerMethodField()
640+
projects = serializers.SerializerMethodField()
639641

640642
class Meta:
641643
model = TaxaList
642-
fields = ["id", "name", "description", "taxa", "projects"]
644+
fields = [
645+
"id",
646+
"name",
647+
"description",
648+
"taxa",
649+
"taxa_count",
650+
"projects",
651+
"created_at",
652+
"updated_at",
653+
]
643654

644655
def get_taxa(self, obj):
645656
"""
@@ -651,6 +662,43 @@ def get_taxa(self, obj):
651662
params={"taxa_list_id": obj.pk},
652663
)
653664

665+
def get_taxa_count(self, obj):
    """
    Return the number of taxa in this list.

    Uses annotated_taxa_count if available (from the ViewSet) for performance.
    The fallback COUNT query is only executed when the annotation is absent:
    the original ``getattr(obj, "annotated_taxa_count", obj.taxa.count())``
    evaluated the default eagerly, issuing the COUNT query on every call and
    defeating the annotation optimization.
    """
    count = getattr(obj, "annotated_taxa_count", None)
    if count is not None:
        return count
    return obj.taxa.count()
672+
def get_permissions(self, instance, instance_data):
    """Resolve M2M object permissions against the request's active project."""
    req = self.context["request"]
    active_project = get_active_project(request=req)
    return add_m2m_object_permissions(req.user, instance, active_project, instance_data)
677+
def get_projects(self, obj):
    """
    Return the IDs of all projects this taxa list belongs to.

    Read-only: project membership is managed by the server, not via this field.
    """
    project_ids = obj.projects.values_list("id", flat=True)
    return [project_id for project_id in project_ids]
684+
685+
class TaxaListTaxonInputSerializer(serializers.Serializer):
    """Input serializer for adding a single taxon to a taxa list."""

    taxon_id = serializers.IntegerField(required=True)

    def validate_taxon_id(self, value):
        """Reject IDs that do not correspond to an existing Taxon."""
        taxon_exists = Taxon.objects.filter(id=value).exists()
        if taxon_exists:
            return value
        raise serializers.ValidationError("Taxon does not exist.")
696+
697+
class TaxaListTaxonSerializer(TaxonNoParentNestedSerializer):
    """Simplified taxon representation used for taxa listed within a taxa list."""
654702

655703
class CaptureTaxonSerializer(DefaultSerializer):
656704
parent = TaxonNoParentNestedSerializer(read_only=True)

0 commit comments

Comments
 (0)