Skip to content
Open
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
f46e88c
feat: add base and runner classes for generic post-processing framework
mohamedelabbas1996 Sep 18, 2025
d86ea4d
feat: add post-processing framework base post-processing task class
mohamedelabbas1996 Sep 30, 2025
2c0f78f
feat: add small size filter post-processing task class
mohamedelabbas1996 Sep 30, 2025
ffba709
feat: add post processing job type
mohamedelabbas1996 Sep 30, 2025
63cd84b
feat: trigger small size filter post processing task from admin page
mohamedelabbas1996 Sep 30, 2025
cab62bf
feat: add a new algorithm task type for post-processing
mohamedelabbas1996 Sep 30, 2025
6d0e284
chore: deleted runner.py
mohamedelabbas1996 Sep 30, 2025
4cfe2d8
feat: add migration for creating a new job type
mohamedelabbas1996 Sep 30, 2025
b42e069
fix: fix an import error with the AlgorithmTaskType
mohamedelabbas1996 Sep 30, 2025
cb7c83a
feat: update identification history of occurrences in SmallSizeFilter
mohamedelabbas1996 Oct 2, 2025
10103db
feat: add rank rollup
mohamedelabbas1996 Oct 6, 2025
2e81d90
feat: add class masking post processing task
mohamedelabbas1996 Oct 7, 2025
0baf8ce
feat: trigger class masking from admin page
mohamedelabbas1996 Oct 7, 2025
f3caa18
fix: modified log messages
mohamedelabbas1996 Oct 8, 2025
65d4fef
fix: set the classification algorithm to the rank rollup Algorithm w…
mohamedelabbas1996 Oct 8, 2025
e13afc1
feat: trigger rank rollup from admin page
mohamedelabbas1996 Oct 8, 2025
7ecc18c
Remove class_masking.py from framework branch
mohamedelabbas1996 Oct 14, 2025
f214025
fix: initialize post-processing tasks with job context and simplify r…
mohamedelabbas1996 Oct 14, 2025
20ff4b6
feat: add permission to run post-processing jobs
mohamedelabbas1996 Oct 14, 2025
5b66ae3
chore: remove class_masking import
mohamedelabbas1996 Oct 14, 2025
0419eff
refactor: redesign BasePostProcessingTask with job-aware logging, pro…
mohamedelabbas1996 Oct 14, 2025
1ad1e76
refactor: adapt RankRollupTask to new BasePostProcessingTask with sel…
mohamedelabbas1996 Oct 14, 2025
d97e8e0
refactor: update SmallSizeFilter to use BasePostProcessingTask loggin…
mohamedelabbas1996 Oct 14, 2025
2922c86
migrations: update Project options to include post-processing job per…
mohamedelabbas1996 Oct 14, 2025
9012d7f
migrations: update Algorithm.task_type choices to include post-proces…
mohamedelabbas1996 Oct 14, 2025
319bb3d
Merge branch 'main' into feat/postprocessing-framework
mohamedelabbas1996 Oct 14, 2025
787ac0b
migrations: merged migrations
mohamedelabbas1996 Oct 14, 2025
5e85b75
refactor: refactor job runner to initialize post-processing tasks wit…
mohamedelabbas1996 Oct 10, 2025
88ffba8
chore: rebase feat/postprocessing-class-masking onto feat/postprocess…
mohamedelabbas1996 Oct 14, 2025
9519600
chore: remove class masking trigger (moved to feat/postprocessing-cla…
mohamedelabbas1996 Oct 14, 2025
21e6648
feat: improved progress tracking
mohamedelabbas1996 Oct 14, 2025
7135e15
Merge branch 'feat/postprocessing-framework' into feat/postprocessing…
mohamedelabbas1996 Oct 14, 2025
6632c31
feat: add applied_to field to Classification to track source classifi…
mohamedelabbas1996 Oct 15, 2025
23f80fb
tests: added tests for small size filter and rank roll up post-proces…
mohamedelabbas1996 Oct 15, 2025
336636a
fix: create only terminal classifications and remove identification c…
mohamedelabbas1996 Oct 15, 2025
0d90cde
refactor: remove inner transaction.atomic for cleaner transaction man…
mohamedelabbas1996 Oct 15, 2025
23469e2
tests: fixed small size filter test
mohamedelabbas1996 Oct 15, 2025
001464e
Merge branch 'feat/postprocessing-framework' into feat/postprocessing…
mohamedelabbas1996 Oct 15, 2025
916d652
Merge branch 'main' of github.com:RolnickLab/antenna into feat/postpr…
mihow Oct 16, 2025
1b8700e
draft: work towards class masking in new framework
mihow Oct 16, 2025
e4639f6
Merge remote-tracking branch 'origin/main' into feat/postprocessing-c…
mihow Feb 18, 2026
a466a52
feat: add class masking tests, management command, and fix registry
mihow Feb 18, 2026
a107597
fix: address review feedback on class masking and rank rollup
mihow Feb 18, 2026
da9b081
feat: replace hardcoded admin action with dynamic class masking form
mihow Feb 18, 2026
fc3f9e1
docs: add class masking screenshots for PR review
mihow Feb 18, 2026
c96a865
fix: address review feedback — N+1 query, distinct, HTML, test ordering
mihow Feb 18, 2026
6be1239
feat: expose applied_to field in Classification API serializers
mihow Feb 18, 2026
c4311aa
feat: make applied_to a nested object with algorithm details
mihow Feb 18, 2026
daed538
fix: add prefetch for applied_to on occurrence detail endpoint
mihow Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.template.defaultfilters import filesizeformat
from django.urls import reverse
from django.utils.formats import number_format
from django.utils.html import format_html
from guardian.admin import GuardedModelAdmin

import ami.utils
from ami import tasks
from ami.jobs.models import Job
from ami.ml.models.algorithm import Algorithm
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.class_masking import update_single_occurrence
from ami.ml.tasks import remove_duplicate_classifications

from .models import (
Expand Down Expand Up @@ -288,20 +292,29 @@ class ClassificationInline(admin.TabularInline):
model = Classification
extra = 0
fields = (
"classification_link",
"taxon",
"algorithm",
"timestamp",
"terminal",
"created_at",
)
readonly_fields = (
"classification_link",
"taxon",
"algorithm",
"timestamp",
"terminal",
"created_at",
)

@admin.display(description="Classification")
def classification_link(self, obj: Classification) -> str:
if obj.pk:
url = reverse("admin:main_classification_change", args=[obj.pk])
return format_html('<a href="{}">{}</a>', url, f"Classification #{obj.pk}")
return "-"

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related("taxon", "algorithm", "detection")
Expand All @@ -311,20 +324,29 @@ class DetectionInline(admin.TabularInline):
model = Detection
extra = 0
fields = (
"detection_link",
"detection_algorithm",
"source_image",
"timestamp",
"created_at",
"occurrence",
)
readonly_fields = (
"detection_link",
"detection_algorithm",
"source_image",
"timestamp",
"created_at",
"occurrence",
)

@admin.display(description="ID")
def detection_link(self, obj):
if obj.pk:
url = reverse("admin:main_detection_change", args=[obj.pk])
return format_html('<a href="{}">{}</a>', url, obj.pk)
return "-"


@admin.register(Detection)
class DetectionAdmin(admin.ModelAdmin[Detection]):
Expand Down Expand Up @@ -382,7 +404,7 @@ class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
"determination__rank",
"created_at",
)
search_fields = ("determination__name", "determination__search_names")
search_fields = ("id", "determination__name", "determination__search_names")

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
Expand All @@ -404,11 +426,60 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
def detections_count(self, obj) -> int:
return obj.detections_count

@admin.action(description="Update occurrence with Newfoundland species taxa list")
def update_with_newfoundland_species(self, request: HttpRequest, queryset: QuerySet[Occurrence]) -> None:
"""
Update selected occurrences using the 'Newfoundland species' taxa list
and 'Quebec & Vermont Species Classifier - Apr 2024' algorithm.
"""
try:
# Get the taxa list by name
taxa_list = TaxaList.objects.get(name="Newfoundland Species")
except TaxaList.DoesNotExist:
self.message_user(
request,
"Error: TaxaList 'Newfoundland species' not found.",
level="error",
)
return

try:
# Get the algorithm by name
algorithm = Algorithm.objects.get(name="Quebec & Vermont Species Classifier - Apr 2024")
except Algorithm.DoesNotExist:
self.message_user(
request,
"Error: Algorithm 'Quebec & Vermont Species Classifier - Apr 2024' not found.",
level="error",
)
return

# Process each occurrence
count = 0
for occurrence in queryset:
try:
update_single_occurrence(
occurrence=occurrence,
algorithm=algorithm,
taxa_list=taxa_list,
)
count += 1
except Exception as e:
self.message_user(
request,
f"Error processing occurrence {occurrence.pk}: {str(e)}",
level="error",
)

self.message_user(request, f"Successfully updated {count} occurrence(s).")

ordering = ("-created_at",)

# Add classifications as inline
inlines = [DetectionInline]

actions = [update_with_newfoundland_species]


@admin.register(Classification)
class ClassificationAdmin(admin.ModelAdmin[Classification]):
Expand All @@ -432,6 +503,8 @@ class ClassificationAdmin(admin.ModelAdmin[Classification]):
"taxon__rank",
)

autocomplete_fields = ("taxon",)

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
return qs.select_related(
Expand Down Expand Up @@ -662,10 +735,32 @@ def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceI

self.message_user(request, f"Queued Small Size Filter for {queryset.count()} collection(s). Jobs: {jobs}")

@admin.action(description="Run Rank Rollup post-processing task (async)")
def run_rank_rollup(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
"""Trigger the Rank Rollup post-processing job asynchronously."""
jobs = []
DEFAULT_THRESHOLDS = {"species": 0.8, "genus": 0.6, "family": 0.4}

for collection in queryset:
job = Job.objects.create(
name=f"Post-processing: RankRollup on Collection {collection.pk}",
project=collection.project,
job_type_key="post_processing",
params={
"task": "rank_rollup",
"config": {"source_image_collection_id": collection.pk, "thresholds": DEFAULT_THRESHOLDS},
},
)
job.enqueue()
jobs.append(job.pk)

self.message_user(request, f"Queued Rank Rollup for {queryset.count()} collection(s). Jobs: {jobs}")

actions = [
populate_collection,
populate_collection_async,
run_small_size_filter,
run_rank_rollup,
]

# Hide images many-to-many field from form. This would list all source images in the database.
Expand Down
83 changes: 83 additions & 0 deletions ami/ml/management/commands/run_class_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from django.core.management.base import BaseCommand, CommandError

from ami.main.models import SourceImageCollection, TaxaList
from ami.ml.models.algorithm import Algorithm
from ami.ml.post_processing.class_masking import ClassMaskingTask


class Command(BaseCommand):
help = (
"Run class masking post-processing on a source image collection. "
"Masks classifier logits for species not in the given taxa list and recalculates softmax scores."
)

def add_arguments(self, parser):
parser.add_argument("--collection-id", type=int, required=True, help="SourceImageCollection ID to process")
parser.add_argument("--taxa-list-id", type=int, required=True, help="TaxaList ID to use as the species mask")
parser.add_argument(
"--algorithm-id", type=int, required=True, help="Algorithm ID whose classifications to mask"
)
parser.add_argument("--dry-run", action="store_true", help="Show what would be done without making changes")

def handle(self, *args, **options):
collection_id = options["collection_id"]
taxa_list_id = options["taxa_list_id"]
algorithm_id = options["algorithm_id"]
dry_run = options["dry_run"]

# Validate inputs
try:
collection = SourceImageCollection.objects.get(pk=collection_id)
except SourceImageCollection.DoesNotExist:
raise CommandError(f"SourceImageCollection {collection_id} does not exist.")

try:
taxa_list = TaxaList.objects.get(pk=taxa_list_id)
except TaxaList.DoesNotExist:
raise CommandError(f"TaxaList {taxa_list_id} does not exist.")

try:
algorithm = Algorithm.objects.get(pk=algorithm_id)
except Algorithm.DoesNotExist:
raise CommandError(f"Algorithm {algorithm_id} does not exist.")

if not algorithm.category_map:
raise CommandError(f"Algorithm '{algorithm.name}' does not have a category map.")

from ami.main.models import Classification

classification_count = (
Classification.objects.filter(
detection__source_image__collections=collection,
terminal=True,
algorithm=algorithm,
scores__isnull=False,
)
.distinct()
.count()
)

taxa_count = taxa_list.taxa.count()

self.stdout.write(
f"Collection: {collection.name} (id={collection.pk})\n"
f"Taxa list: {taxa_list.name} (id={taxa_list.pk}, {taxa_count} taxa)\n"
f"Algorithm: {algorithm.name} (id={algorithm.pk})\n"
f"Classifications to process: {classification_count}"
)

if classification_count == 0:
raise CommandError("No terminal classifications with scores found for this collection/algorithm.")

if dry_run:
self.stdout.write(self.style.WARNING("Dry run — no changes made."))
return

self.stdout.write("Running class masking...")
task = ClassMaskingTask(
collection_id=collection_id,
taxa_list_id=taxa_list_id,
algorithm_id=algorithm_id,
)
task.run()
self.stdout.write(self.style.SUCCESS("Class masking completed."))
1 change: 0 additions & 1 deletion ami/ml/post_processing/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from . import small_size_filter # noqa: F401
Loading