diff --git a/ami/jobs/migrations/0013_add_job_logs.py b/ami/jobs/migrations/0013_add_job_logs.py
new file mode 100644
index 000000000..5a8ab2e9d
--- /dev/null
+++ b/ami/jobs/migrations/0013_add_job_logs.py
@@ -0,0 +1,53 @@
+# Generated by Django 4.2.10 on 2024-12-17 20:01
+
+import django_pydantic_field.fields
+from django.db import migrations
+
+import ami.jobs.models
+
+
+def migrate_logs_forward(apps, schema_editor):
+ """Move logs from Job.progress to Job.logs"""
+ Job = apps.get_model("jobs", "Job")
+ jobs_to_update = []
+ for job in Job.objects.filter(progress__isnull=False):
+ if job.progress.logs or job.progress.errors:
+ # Move logs from progress to the new logs field
+ job.logs.stdout = job.progress.logs
+ job.logs.stderr = job.progress.errors
+ jobs_to_update.append(job)
+ # Update all jobs in a single query
+ Job.objects.bulk_update(jobs_to_update, ["logs"])
+
+
+def migrate_logs_backward(apps, schema_editor):
+ """Move logs from Job.logs back to Job.progress"""
+ Job = apps.get_model("jobs", "Job")
+ jobs_to_update = []
+ for job in Job.objects.filter(logs__isnull=False):
+ # Move logs back to progress
+ job.progress.logs = job.logs.stdout
+ job.progress.errors = job.logs.stderr
+ jobs_to_update.append(job)
+ # Update all jobs in a single query
+ Job.objects.bulk_update(jobs_to_update, ["progress"])
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("jobs", "0012_alter_job_limit"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="job",
+ name="logs",
+ field=django_pydantic_field.fields.PydanticSchemaField(
+ config=None, default={"stderr": [], "stdout": []}, schema=ami.jobs.models.JobLogs
+ ),
+ ),
+ migrations.RunPython(
+ migrate_logs_forward,
+ migrate_logs_backward,
+ ),
+ ]
diff --git a/ami/jobs/migrations/0014_alter_job_progress.py b/ami/jobs/migrations/0014_alter_job_progress.py
new file mode 100644
index 000000000..09c9b756a
--- /dev/null
+++ b/ami/jobs/migrations/0014_alter_job_progress.py
@@ -0,0 +1,23 @@
+# Generated by Django 4.2.10 on 2024-12-17 20:13
+
+import ami.jobs.models
+from django.db import migrations
+import django_pydantic_field.fields
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("jobs", "0013_add_job_logs"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="job",
+ name="progress",
+ field=django_pydantic_field.fields.PydanticSchemaField(
+ config=None,
+ default={"errors": [], "logs": [], "stages": [], "summary": {"progress": 0.0, "status": "CREATED"}},
+ schema=ami.jobs.models.JobProgress,
+ ),
+ ),
+ ]
diff --git a/ami/jobs/migrations/0015_merge_20250117_2100.py b/ami/jobs/migrations/0015_merge_20250117_2100.py
new file mode 100644
index 000000000..9b04fd33a
--- /dev/null
+++ b/ami/jobs/migrations/0015_merge_20250117_2100.py
@@ -0,0 +1,12 @@
+# Generated by Django 4.2.10 on 2025-01-17 21:00
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("jobs", "0013_merge_0011_alter_job_limit_0012_alter_job_limit"),
+ ("jobs", "0014_alter_job_progress"),
+ ]
+
+ operations = []
diff --git a/ami/jobs/models.py b/ami/jobs/models.py
index 3bd192106..e839c7737 100644
--- a/ami/jobs/models.py
+++ b/ami/jobs/models.py
@@ -85,14 +85,10 @@ class JobProgressSummary(pydantic.BaseModel):
status: JobState = JobState.CREATED
progress: float = 0
- status_label: str = ""
- @pydantic.validator("status_label", always=True)
- def serialize_status_label(cls, value, values) -> str:
- if "status" not in values or "progress" not in values:
- # Does this happen if status label gets initialized before status and progress?
- return ""
- return get_status_label(values["status"], values["progress"])
+ @property
+ def status_label(self) -> str:
+ return get_status_label(self.status, self.progress)
class Config:
use_enum_values = True
@@ -112,15 +108,15 @@ class JobProgress(pydantic.BaseModel):
summary: JobProgressSummary
stages: list[JobProgressStageDetail]
- errors: list[str] = []
- logs: list[str] = []
+ errors: list[str] = [] # Deprecated, @TODO remove in favor of logs.stderr
+ logs: list[str] = [] # Deprecated, @TODO remove in favor of logs.stdout
- def get_stage_key(self, name: str) -> str:
+ def make_key(self, name: str) -> str:
"""Generate a key for a stage or param based on its name"""
return python_slugify(name)
def add_stage(self, name: str, key: str | None = None) -> JobProgressStageDetail:
- key = key or self.get_stage_key(name)
+ key = key or self.make_key(name)
try:
return self.get_stage(key)
except ValueError:
@@ -146,26 +142,28 @@ def get_stage_param(self, stage_key: str, param_key: str) -> ConfigurableStagePa
return param
raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'")
- def add_stage_param(self, stage_key: str, name: str, value: typing.Any = None) -> ConfigurableStageParam:
+ def add_stage_param(self, stage_key: str, param_name: str, value: typing.Any = None) -> ConfigurableStageParam:
stage = self.get_stage(stage_key)
try:
- return self.get_stage_param(stage_key, self.get_stage_key(name))
+ return self.get_stage_param(stage_key, self.make_key(param_name))
except ValueError:
param = ConfigurableStageParam(
- name=name,
- key=self.get_stage_key(name),
+ name=param_name,
+ key=self.make_key(param_name),
value=value,
)
stage.params.append(param)
return param
- def add_or_update_stage_param(self, stage_key: str, name: str, value: typing.Any = None) -> ConfigurableStageParam:
+ def add_or_update_stage_param(
+ self, stage_key: str, param_name: str, value: typing.Any = None
+ ) -> ConfigurableStageParam:
try:
- param = self.get_stage_param(stage_key, self.get_stage_key(name))
+ param = self.get_stage_param(stage_key, self.make_key(param_name))
param.value = value
return param
except ValueError:
- return self.add_stage_param(stage_key, name, value)
+ return self.add_stage_param(stage_key, param_name, value)
def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgressStageDetail | None:
""" "
@@ -176,7 +174,7 @@ def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgres
This is the preferred method to update a stage's parameters.
"""
- stage_key = self.get_stage_key(stage_key_or_name) # Allow both title or key to be used for lookup
+ stage_key = self.make_key(stage_key_or_name) # Allow both title or key to be used for lookup
stage = self.get_stage(stage_key)
if stage.key == stage_key:
@@ -243,6 +241,11 @@ def default_ml_job_progress() -> JobProgress:
)
+class JobLogs(pydantic.BaseModel):
+ stdout: list[str] = pydantic.Field(default_factory=list, alias="stdout", title="All messages")
+ stderr: list[str] = pydantic.Field(default_factory=list, alias="stderr", title="Error messages")
+
+
class JobLogHandler(logging.Handler):
"""
Class for handling logs from a job and writing them to the job instance.
@@ -254,24 +257,26 @@ def __init__(self, job: "Job", *args, **kwargs):
self.job = job
super().__init__(*args, **kwargs)
- def emit(self, record):
+ def emit(self, record: logging.LogRecord):
# Log to the current app logger
logger.log(record.levelno, self.format(record))
# Write to the logs field on the job instance
timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
msg = f"[{timestamp}] {record.levelname} {self.format(record)}"
- if msg not in self.job.progress.logs:
- self.job.progress.logs.insert(0, msg)
+ if msg not in self.job.logs.stdout:
+ self.job.logs.stdout.insert(0, msg)
# Write a simpler copy of any errors to the errors field
if record.levelno >= logging.ERROR:
- if record.message not in self.job.progress.errors:
- self.job.progress.errors.insert(0, record.message)
+ if record.message not in self.job.logs.stderr:
+ self.job.logs.stderr.insert(0, record.message)
- if len(self.job.progress.logs) > self.max_log_length:
- self.job.progress.logs = self.job.progress.logs[: self.max_log_length]
- self.job.save()
+ if len(self.job.logs.stdout) > self.max_log_length:
+ self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length]
+
+            # @TODO consider saving logs to the database periodically rather than on every log message
+ self.job.save(update_fields=["logs"], update_progress=False)
@dataclass
@@ -312,6 +317,10 @@ def run(cls, job: "Job"):
job.finished_at = None
job.save()
+ # Keep track of sub-tasks for saving results, pair with batch number
+ save_tasks: list[tuple[int, AsyncResult]] = []
+ save_tasks_completed: list[tuple[int, AsyncResult]] = []
+
if job.delay:
update_interval_seconds = 2
last_update = time.time()
@@ -337,91 +346,155 @@ def run(cls, job: "Job"):
)
job.save()
- if job.pipeline:
- job.progress.update_stage(
- "collect",
- status=JobState.STARTED,
- progress=0,
- )
+ if not job.pipeline:
+ raise ValueError("No pipeline specified to process images in ML job")
- images = list(
- # @TODO return generator plus image count
- # @TODO pass to celery group chain?
- job.pipeline.collect_images(
- collection=job.source_image_collection,
- deployment=job.deployment,
- source_images=[job.source_image_single] if job.source_image_single else None,
- job_id=job.pk,
- skip_processed=True,
- # shuffle=job.shuffle,
- )
- )
- source_image_count = len(images)
- job.progress.update_stage("collect", total_images=source_image_count)
-
- if job.shuffle and source_image_count > 1:
- job.logger.info("Shuffling images")
- random.shuffle(images)
-
- if job.limit and source_image_count > job.limit:
- job.logger.warn(f"Limiting number of images to {job.limit} (out of {source_image_count})")
- images = images[: job.limit]
- image_count = len(images)
- job.progress.add_stage_param("collect", "Limit", image_count)
- else:
- image_count = source_image_count
+ job.progress.update_stage(
+ "collect",
+ status=JobState.STARTED,
+ progress=0,
+ )
- job.progress.update_stage(
- "collect",
- status=JobState.SUCCESS,
- progress=1,
+ images = list(
+ # @TODO return generator plus image count
+ # @TODO pass to celery group chain?
+ job.pipeline.collect_images(
+ collection=job.source_image_collection,
+ deployment=job.deployment,
+ source_images=[job.source_image_single] if job.source_image_single else None,
+ job_id=job.pk,
+ skip_processed=True,
+ # shuffle=job.shuffle,
)
+ )
+ source_image_count = len(images)
+ job.progress.update_stage("collect", total_images=source_image_count)
+
+ if job.shuffle and source_image_count > 1:
+ job.logger.info("Shuffling images")
+ random.shuffle(images)
+
+ if job.limit and source_image_count > job.limit:
+ job.logger.warn(f"Limiting number of images to {job.limit} (out of {source_image_count})")
+ images = images[: job.limit]
+ image_count = len(images)
+ job.progress.add_stage_param("collect", "Limit", image_count)
+ else:
+ image_count = source_image_count
- total_detections = 0
- total_classifications = 0
+ job.progress.update_stage(
+ "collect",
+ status=JobState.SUCCESS,
+ progress=1,
+ )
- CHUNK_SIZE = 2 # Keep it low to see more progress updates
- chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa
+ # End image collection stage
+ job.save()
- for i, chunk in enumerate(chunks):
- try:
- results = job.pipeline.process_images(
- images=chunk,
- job_id=job.pk,
- )
- except Exception as e:
- # Log error about image batch and continue
- job.logger.error(f"Failed to process image batch {i} of {len(chunks)}: {e}")
- continue
+ total_captures = 0
+ total_detections = 0
+ total_classifications = 0
+ CHUNK_SIZE = 4 # Keep it low to see more progress updates
+ chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa
+ request_failed_images = []
+
+ for i, chunk in enumerate(chunks):
+ request_sent = time.time()
+ job.logger.info(f"Processing image batch {i+1} of {len(chunks)}")
+ try:
+ results = job.pipeline.process_images(
+ images=chunk,
+ job_id=job.pk,
+ )
+ job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s")
+ except Exception as e:
+ # Log error about image batch and continue
+ job.logger.error(f"Failed to process image batch {i+1}: {e}")
+ request_failed_images.extend([img.pk for img in chunk])
+ else:
+ total_captures += len(results.source_images)
total_detections += len(results.detections)
total_classifications += len([c for d in results.detections for c in d.classifications])
- job.progress.update_stage(
- "process",
- status=JobState.STARTED,
- progress=(i + 1) / len(chunks),
- processed=(i + 1) * CHUNK_SIZE,
- remaining=image_count - (i + 1) * CHUNK_SIZE,
- detections=total_detections,
- classifications=total_classifications,
- )
- job.save()
if results.source_images or results.detections:
- save_results_task = job.pipeline.save_results_async(results=results, job_id=job.pk)
- job.logger.info(f"Saving results in sub-task {save_results_task.id}")
+                    # @TODO add a callback to report errors while saving results and mark the job as failed
+ save_results_task: AsyncResult = job.pipeline.save_results_async(results=results, job_id=job.pk)
+ save_tasks.append((i + 1, save_results_task))
+ job.logger.info(f"Saving results for batch {i+1} in sub-task {save_results_task.id}")
job.progress.update_stage(
"process",
- status=JobState.SUCCESS,
+ status=JobState.STARTED,
+ progress=(i + 1) / len(chunks),
+ processed=min((i + 1) * CHUNK_SIZE, image_count),
+ failed=len(request_failed_images),
+ remaining=max(image_count - ((i + 1) * CHUNK_SIZE), 0),
)
+
+ # count the completed, successful, and failed save_tasks:
+ save_tasks_completed = [t for t in save_tasks if t[1].ready()]
+ failed_save_tasks = [t for t in save_tasks_completed if not t[1].successful()]
+
+ for failed_batch_num, failed_task in failed_save_tasks:
+ # First log all errors and update the job status. Then raise exception if any failed.
+ job.logger.error(f"Failed to save results from batch {failed_batch_num} (sub-task {failed_task.id})")
+
job.progress.update_stage(
"results",
- status=JobState.SUCCESS,
+ status=JobState.FAILURE if failed_save_tasks else JobState.STARTED,
+ progress=len(save_tasks_completed) / len(chunks),
+ captures=total_captures,
+ detections=total_detections,
+ classifications=total_classifications,
)
+ job.save()
+
+ # Stop processing if any save tasks have failed
+            # Otherwise, calculate the percentage of images that were processed successfully
+ throw_on_save_error = True
+ for failed_batch_num, failed_task in failed_save_tasks:
+ if throw_on_save_error:
+ failed_task.maybe_throw()
+
+ percent_successful = 1 - len(request_failed_images) / image_count if image_count else 0
+ job.logger.info(f"Processed {percent_successful:.0%} of images successfully.")
+
+ # Check all Celery sub-tasks if they have completed saving results
+ save_tasks_remaining = set(save_tasks) - set(save_tasks_completed)
+ job.logger.info(
+ f"Checking the status of {len(save_tasks_remaining)} remaining sub-tasks that are still saving results."
+ )
+ for batch_num, sub_task in save_tasks:
+ if not sub_task.ready():
+ job.logger.info(f"Waiting for batch {batch_num} to finish saving results (sub-task {sub_task.id})")
+ # @TODO this is not recommended! Use a group or chain. But we need to refactor.
+ # https://docs.celeryq.dev/en/latest/userguide/tasks.html#avoid-launching-synchronous-subtasks
+ sub_task.wait(disable_sync_subtasks=False, timeout=60)
+ if not sub_task.successful():
+ error: Exception = sub_task.result
+ job.logger.error(f"Failed to save results from batch {batch_num}! (sub-task {sub_task.id}): {error}")
+ sub_task.maybe_throw()
+
+ job.logger.info(f"All tasks completed for job {job.pk}")
+
+ FAILURE_THRESHOLD = 0.5
+ if image_count and (percent_successful < FAILURE_THRESHOLD):
+ job.progress.update_stage("process", status=JobState.FAILURE)
+ job.save()
+ raise Exception(f"Failed to process more than {int(FAILURE_THRESHOLD * 100)}% of images")
- job.update_status(JobState.SUCCESS)
- job.update_progress()
+ job.progress.update_stage(
+ "process",
+ status=JobState.SUCCESS,
+ progress=1,
+ )
+ job.progress.update_stage(
+ "results",
+ status=JobState.SUCCESS,
+ progress=1,
+ )
+ job.update_status(JobState.SUCCESS, save=False)
job.finished_at = datetime.datetime.now()
job.save()
@@ -445,7 +518,9 @@ def run(cls, job: "Job"):
job.finished_at = None
job.save()
- if job.deployment:
+ if not job.deployment:
+ raise ValueError("No deployment provided for data storage sync job")
+ else:
job.logger.info(f"Syncing captures for deployment {job.deployment}")
job.progress.update_stage(
cls.key,
@@ -465,10 +540,7 @@ def run(cls, job: "Job"):
)
job.update_status(JobState.SUCCESS)
job.save()
- else:
- job.update_status(JobState.FAILURE)
- job.update_progress()
job.finished_at = datetime.datetime.now()
job.save()
@@ -492,11 +564,7 @@ def run(cls, job: "Job"):
job.save()
if not job.source_image_collection:
- job.logger.error("No source image collection provided")
- job.update_status(JobState.FAILURE)
- job.finished_at = datetime.datetime.now()
- job.save()
- return
+ raise ValueError("No source image collection provided")
job.logger.info(f"Populating source image collection {job.source_image_collection}")
job.update_status(JobState.STARTED)
@@ -508,7 +576,7 @@ def run(cls, job: "Job"):
progress=0.10,
captures_added=0,
)
- job.update_progress(save=True)
+ job.save()
job.source_image_collection.populate_sample(job=job)
job.logger.info(f"Finished populating source image collection {job.source_image_collection}")
@@ -525,7 +593,6 @@ def run(cls, job: "Job"):
)
job.finished_at = datetime.datetime.now()
job.update_status(JobState.SUCCESS, save=False)
- job.update_progress(save=False)
job.save()
@@ -535,9 +602,7 @@ class UnknownJobType(JobType):
@classmethod
def run(cls, job: "Job"):
- job.logger.error(f"Unknown job type '{job.job_type()}'")
- job.update_status(JobState.UNKNOWN)
- job.save()
+ raise ValueError(f"Unknown job type '{job.job_type()}'")
VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType]
@@ -580,6 +645,7 @@ class Job(BaseModel):
# @TODO can we use an Enum or Pydantic model for status?
status = models.CharField(max_length=255, default=JobState.CREATED.name, choices=JobState.choices())
progress: JobProgress = SchemaField(JobProgress, default=default_job_progress())
+ logs: JobLogs = SchemaField(JobLogs, default=JobLogs())
result = models.JSONField(null=True, blank=True)
task_id = models.CharField(max_length=255, null=True, blank=True)
delay = models.IntegerField("Delay in seconds", default=0, help_text="Delay before running the job")
@@ -676,11 +742,12 @@ def setup(self, save=True):
pipeline_stage = self.progress.add_stage("Process")
self.progress.add_stage_param(pipeline_stage.key, "Processed", "")
self.progress.add_stage_param(pipeline_stage.key, "Remaining", "")
- self.progress.add_stage_param(pipeline_stage.key, "Detections", "")
- self.progress.add_stage_param(pipeline_stage.key, "Classifications", "")
+ self.progress.add_stage_param(pipeline_stage.key, "Failed", "")
saving_stage = self.progress.add_stage("Results")
- self.progress.add_stage_param(saving_stage.key, "Objects created", "")
+ self.progress.add_stage_param(saving_stage.key, "Captures", "")
+ self.progress.add_stage_param(saving_stage.key, "Detections", "")
+ self.progress.add_stage_param(saving_stage.key, "Classifications", "")
if save:
self.save()
@@ -700,6 +767,7 @@ def retry(self, async_task=True):
Retry the job.
"""
self.logger.info(f"Re-running job {self}")
+ self.finished_at = None
self.progress.reset()
self.status = JobState.RETRY
self.save()
@@ -733,11 +801,11 @@ def update_status(self, status=None, save=True):
status = task.status
if not status:
- self.logger.warn(f"Could not determine status of job {self.pk}")
+ self.logger.warning(f"Could not determine status of job {self.pk}")
return
if status != self.status:
- self.logger.info(f"Changing status of job {self.pk} to {status}")
+ self.logger.info(f"Changing status of job {self.pk} from {self.status} to {status}")
self.status = status
self.progress.summary.status = status
@@ -757,10 +825,10 @@ def update_progress(self, save=True):
if stage.progress > 0 and stage.status == JobState.CREATED:
# Update any stages that have started but are still in the CREATED state
stage.status = JobState.STARTED
- elif stage.status == JobState.SUCCESS and stage.progress < 1:
+ elif stage.status in JobState.final_states() and stage.progress < 1:
# Update any stages that are complete but have a progress less than 1
stage.progress = 1
- elif stage.progress == 1 and stage.status == JobState.STARTED:
+ elif stage.progress == 1 and stage.status not in JobState.final_states():
# Update any stages that are complete but are still in the STARTED state
stage.status = JobState.SUCCESS
total_progress = sum([stage.progress for stage in self.progress.stages]) / len(self.progress.stages)
@@ -768,18 +836,18 @@ def update_progress(self, save=True):
self.progress.summary.progress = total_progress
if save:
- self.save()
+ self.save(update_progress=False)
def duration(self) -> datetime.timedelta | None:
if self.started_at and self.finished_at:
return self.finished_at - self.started_at
return None
- def save(self, *args, **kwargs):
+ def save(self, update_progress=True, *args, **kwargs):
"""
Create the job stages if they don't exist.
"""
- if self.pk and self.progress.stages:
+ if self.pk and self.progress.stages and update_progress:
self.update_progress(save=False)
else:
self.setup(save=False)
diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py
index ab07ef320..f1fef491d 100644
--- a/ami/jobs/serializers.py
+++ b/ami/jobs/serializers.py
@@ -11,7 +11,7 @@
from ami.ml.models import Pipeline
from ami.ml.serializers import PipelineNestedSerializer
-from .models import Job, JobProgress
+from .models import Job, JobLogs, JobProgress, MLJob
class JobProjectNestedSerializer(DefaultSerializer):
@@ -37,7 +37,11 @@ class JobListSerializer(DefaultSerializer):
source_image_collection = SourceImageCollectionNestedSerializer(read_only=True)
source_image_single = SourceImageNestedSerializer(read_only=True)
progress = SchemaField(schema=JobProgress, read_only=True)
+ logs = SchemaField(schema=JobLogs, read_only=True)
job_type = JobTypeSerializer(read_only=True)
+ # All jobs created from the Jobs UI are ML jobs (datasync, etc. are created for the user)
+    # @TODO Remove this when the UI is updated to pass a job type. This should be a required field.
+ job_type_key = serializers.SlugField(write_only=True, default=MLJob.key)
project_id = serializers.PrimaryKeyRelatedField(
label="Project",
@@ -109,7 +113,9 @@ class Meta:
"finished_at",
"duration",
"progress",
+ "logs",
"job_type",
+ "job_type_key",
# "duration",
# "duration_label",
# "progress_label",
diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py
index 4c9d41447..59d3cfd59 100644
--- a/ami/jobs/tests.py
+++ b/ami/jobs/tests.py
@@ -100,8 +100,10 @@ def test_create_job_unauthenticated(self):
jobs_create_url = reverse_with_params("api:job-list")
job_data = {
"project_id": self.project.pk,
+ "source_image_collection_id": self.source_image_collection.pk,
"name": "Test job unauthenticated",
"delay": 0,
+ "job_type_key": SourceImageCollectionPopulateJob.key,
}
self.client.force_authenticate(user=None)
resp = self.client.post(jobs_create_url, job_data)
@@ -113,9 +115,10 @@ def _create_job(self, name: str, start_now: bool = True):
job_data = {
"project_id": self.job.project.pk,
"name": name,
- "collection_id": self.source_image_collection.pk,
+ "source_image_collection_id": self.source_image_collection.pk,
"delay": 0,
"start_now": start_now,
+ "job_type_key": SourceImageCollectionPopulateJob.key,
}
resp = self.client.post(jobs_create_url, job_data)
self.client.force_authenticate(user=None)
diff --git a/ami/jobs/views.py b/ami/jobs/views.py
index e783ff9a5..49a87afc0 100644
--- a/ami/jobs/views.py
+++ b/ami/jobs/views.py
@@ -11,7 +11,7 @@
from ami.utils.fields import url_boolean_param
from ami.utils.requests import get_active_project, project_id_doc_param
-from .models import Job, JobState, MLJob
+from .models import Job, JobState
from .serializers import JobListSerializer, JobSerializer
logger = logging.getLogger(__name__)
@@ -117,11 +117,6 @@ def perform_create(self, serializer):
If the ``start_now`` parameter is passed, enqueue the job immediately.
"""
- # All jobs created from the Jobs UI are ML jobs.
- # @TODO Remove this when the UI is updated pass a job type
- if not serializer.validated_data.get("job_type_key"):
- serializer.validated_data["job_type_key"] = MLJob.key
-
job: Job = serializer.save() # type: ignore
if url_boolean_param(self.request, "start_now", default=False):
# job.run()
diff --git a/ami/main/admin.py b/ami/main/admin.py
index 19184a6fe..9429bf45c 100644
--- a/ami/main/admin.py
+++ b/ami/main/admin.py
@@ -13,6 +13,7 @@
from .models import (
BlogPost,
+ Classification,
Deployment,
Device,
Event,
@@ -236,11 +237,80 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
class OccurrenceAdmin(admin.ModelAdmin[Occurrence]):
"""Admin panel example for ``Occurrence`` model."""
- list_display = ("id", "determination", "project", "deployment", "event")
+ list_display = (
+ "id",
+ "determination",
+ "project",
+ "deployment",
+ "event",
+ "detections_count",
+ "created_at",
+ "updated_at",
+ )
+
+ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
+ qs = super().get_queryset(request)
+ qs = qs.select_related("determination", "project", "deployment", "event")
+ # Add detections count to queryset
+ qs = qs.annotate(detections_count=models.Count("detections"))
+ # Add min, max and avg detection__classifications counts to queryset
+ # qs = qs.annotate(
+ # min_detection_classifications=models.Min("detections__classifications"),
+ # max_detection_classifications=models.Max("detections__classifications"),
+ # avg_detection_classifications=models.Avg("detections__classifications"),
+ # )
+ return qs
+
+ @admin.display(
+ description="Detections",
+ ordering="detections_count",
+ )
+ def detections_count(self, obj) -> int:
+ return obj.detections_count
+
+ ordering = ("-created_at",)
+
+
+@admin.register(Classification)
+class ClassificationAdmin(admin.ModelAdmin[Classification]):
+ list_display = (
+ "__str__",
+ "taxon",
+ "algorithm",
+ "num_scores",
+ "num_logits",
+ "detection_date",
+ "timestamp",
+ "terminal",
+ "created_at",
+ )
+
+ list_filter = (
+ "algorithm",
+ "terminal",
+ "created_at",
+ "detection__source_image__project",
+ "taxon__rank",
+ )
def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
- return qs.select_related("determination", "project", "deployment", "event")
+ return qs.select_related(
+ "taxon", "detection", "detection__source_image", "detection__source_image__project"
+ ).annotate(
+ detection_date=models.F("detection__timestamp"),
+ )
+
+ @admin.display()
+ def detection_date(self, obj: Classification) -> str:
+ # This property comes from the annotation in get_queryset, not the model
+ return obj.detection_date # type: ignore
+
+ def num_scores(self, obj: Classification) -> int:
+ return len(obj.scores) if obj.scores else 0
+
+ def num_logits(self, obj: Classification) -> int:
+ return len(obj.logits) if obj.logits else 0
class TaxonParentFilter(admin.SimpleListFilter):
diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py
index b37221aa2..9ba46c27c 100644
--- a/ami/main/api/serializers.py
+++ b/ami/main/api/serializers.py
@@ -451,6 +451,7 @@ class Meta:
"name",
"rank",
"details",
+ "gbif_taxon_key",
]
@@ -720,6 +721,7 @@ class Meta:
"detections_count",
"events_count",
"occurrences",
+ "gbif_taxon_key",
]
@@ -738,10 +740,50 @@ class Meta:
]
+class ClassificationPredictionItemSerializer(serializers.Serializer):
+ taxon = TaxonNestedSerializer(read_only=True)
+ score = serializers.FloatField(read_only=True)
+ logit = serializers.FloatField(read_only=True)
+
+
class ClassificationSerializer(DefaultSerializer):
taxon = TaxonNestedSerializer(read_only=True)
algorithm = AlgorithmSerializer(read_only=True)
+ top_n = ClassificationPredictionItemSerializer(many=True, read_only=True)
+
+ class Meta:
+ model = Classification
+ fields = [
+ "id",
+ "details",
+ "taxon",
+ "score",
+ "algorithm",
+ "scores",
+ "logits",
+ "top_n",
+ "created_at",
+ "updated_at",
+ ]
+
+
+class ClassificationWithTaxaSerializer(ClassificationSerializer):
+ """
+ Return all possible taxa objects in the category map with the classification.
+
+ This is slow for large category maps.
+ It's recommended to retrieve and cache the category map with taxa ahead of time.
+ """
+ taxa = TaxonNestedSerializer(many=True, read_only=True)
+
+ class Meta(ClassificationSerializer.Meta):
+ fields = ClassificationSerializer.Meta.fields + [
+ "taxa",
+ ]
+
+
+class ClassificationListSerializer(DefaultSerializer):
class Meta:
model = Classification
fields = [
@@ -751,11 +793,22 @@ class Meta:
"score",
"algorithm",
"created_at",
+ "updated_at",
]
-class OccurrenceClassificationSerializer(ClassificationSerializer):
- pass
+class ClassificationNestedSerializer(ClassificationSerializer):
+ class Meta:
+ model = Classification
+ fields = [
+ "id",
+ "details",
+ "taxon",
+ "score",
+ "terminal",
+ "algorithm",
+ "created_at",
+ ]
class CaptureDetectionsSerializer(DefaultSerializer):
@@ -767,6 +820,7 @@ class Meta:
# queryset = Detection.objects.prefetch_related("classifications")
fields = [
"id",
+ "details",
"url",
"width",
"height",
@@ -800,7 +854,7 @@ class Meta:
class DetectionNestedSerializer(DefaultSerializer):
- classifications = ClassificationSerializer(many=True, read_only=True)
+ classifications = ClassificationNestedSerializer(many=True, read_only=True)
capture = DetectionCaptureNestedSerializer(read_only=True, source="source_image")
class Meta:
@@ -808,6 +862,7 @@ class Meta:
# queryset = Detection.objects.prefetch_related("classifications")
fields = [
"id",
+ "details",
"timestamp",
"url",
"capture",
@@ -842,6 +897,7 @@ class DetectionSerializer(DefaultSerializer):
detection_algorithm_id = serializers.PrimaryKeyRelatedField(
queryset=Algorithm.objects.all(), source="detection_algorithm", write_only=True
)
+ classifications = ClassificationNestedSerializer(many=True, read_only=True)
class Meta:
model = Detection
@@ -849,6 +905,7 @@ class Meta:
"source_image",
"detection_algorithm",
"detection_algorithm_id",
+ "classifications",
]
@@ -1077,6 +1134,7 @@ class Meta:
"determination_details",
"identifications",
"created_at",
+ "updated_at",
]
def get_determination_details(self, obj: Occurrence):
@@ -1099,7 +1157,7 @@ def get_determination_details(self, obj: Occurrence):
if identification or not obj.best_prediction:
prediction = None
else:
- prediction = OccurrenceClassificationSerializer(obj.best_prediction, context=context).data
+ prediction = ClassificationNestedSerializer(obj.best_prediction, context=context).data
return dict(
taxon=taxon,
@@ -1112,7 +1170,8 @@ def get_determination_details(self, obj: Occurrence):
class OccurrenceSerializer(OccurrenceListSerializer):
determination = CaptureTaxonSerializer(read_only=True)
detections = DetectionNestedSerializer(many=True, read_only=True)
- predictions = OccurrenceClassificationSerializer(many=True, read_only=True)
+ identifications = OccurrenceIdentificationSerializer(many=True, read_only=True)
+ predictions = ClassificationNestedSerializer(many=True, read_only=True)
deployment = DeploymentNestedSerializer(read_only=True)
event = EventNestedSerializer(read_only=True)
# first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
diff --git a/ami/main/api/views.py b/ami/main/api/views.py
index 724eebee9..74fd8860e 100644
--- a/ami/main/api/views.py
+++ b/ami/main/api/views.py
@@ -49,7 +49,9 @@
update_detection_counts,
)
from .serializers import (
+ ClassificationListSerializer,
ClassificationSerializer,
+ ClassificationWithTaxaSerializer,
DeploymentListSerializer,
DeploymentSerializer,
DetectionListSerializer,
@@ -429,7 +431,13 @@ class SourceImageViewSet(DefaultViewSet):
queryset = SourceImage.objects.all()
serializer_class = SourceImageSerializer
- filterset_fields = ["event", "deployment", "deployment__project", "collections"]
+ filterset_fields = [
+ "event",
+ "deployment",
+ "deployment__project",
+ "collections",
+ "project",
+ ]
ordering_fields = [
"created_at",
"updated_at",
@@ -732,9 +740,9 @@ class DetectionViewSet(DefaultViewSet):
API endpoint that allows detections to be viewed or edited.
"""
- queryset = Detection.objects.all()
+ queryset = Detection.objects.all().select_related("source_image", "detection_algorithm")
serializer_class = DetectionSerializer
- filterset_fields = ["source_image", "detection_algorithm"]
+ filterset_fields = ["source_image", "detection_algorithm", "source_image__project"]
ordering_fields = ["created_at", "updated_at", "detection_score", "timestamp"]
def get_serializer_class(self):
@@ -1262,15 +1270,34 @@ class ClassificationViewSet(DefaultViewSet):
API endpoint for viewing and adding classification results from a model.
"""
- queryset = Classification.objects.all()
+ queryset = Classification.objects.all() # .select_related("taxon", "algorithm", "detection")
serializer_class = ClassificationSerializer
- filterset_fields = ["detection", "detection__occurrence", "taxon", "algorithm"]
+ filterset_fields = [
+ "detection",
+ "detection__occurrence",
+ "taxon",
+ "algorithm",
+ "detection__source_image",
+ "detection__source_image__project",
+ ]
ordering_fields = [
"created_at",
"updated_at",
"score",
]
+ def get_serializer_class(self):
+ """
+ Return a different serializer for list and detail views.
+ If "with_taxa" is in the query params, return a different serializer.
+ """
+ if self.action == "list":
+ return ClassificationListSerializer
+ elif "with_taxa" in self.request.query_params:
+ return ClassificationWithTaxaSerializer
+ else:
+ return ClassificationSerializer
+
class SummaryView(GenericAPIView):
permission_classes = [IsActiveStaffOrReadOnly]
diff --git a/ami/main/charts.py b/ami/main/charts.py
index 352e2e2ac..4d9809817 100644
--- a/ami/main/charts.py
+++ b/ami/main/charts.py
@@ -199,7 +199,7 @@ def detections_per_hour(project_pk: int):
def occurrences_accumulated(project_pk: int):
- # Line chart of the accumulated number of occurrnces over time throughout the season
+ # Line chart of the accumulated number of occurrences over time throughout the season
Occurrence = apps.get_model("main", "Occurrence")
occurrences_per_day = (
@@ -212,7 +212,8 @@ def occurrences_accumulated(project_pk: int):
.order_by("event__start")
)
- if occurrences_per_day.count():
+ occurrences_exist = Occurrence.objects.filter(project=project_pk).exists()
+ if occurrences_exist:
days, counts = list(zip(*occurrences_per_day))
# Accumulate the counts
counts = list(itertools.accumulate(counts))
diff --git a/ami/main/migrations/0039_remove_classification_raw_output_and_more.py b/ami/main/migrations/0039_remove_classification_raw_output_and_more.py
new file mode 100644
index 000000000..50d8cf7d0
--- /dev/null
+++ b/ami/main/migrations/0039_remove_classification_raw_output_and_more.py
@@ -0,0 +1,54 @@
+# Generated by Django 4.2.10 on 2024-11-26 18:52
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0007_algorithmcategorymap_algorithm_category_map"),
+ ("main", "0038_alter_detection_path_alter_sourceimage_event_and_more"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="classification",
+ name="raw_output",
+ ),
+ migrations.RemoveField(
+ model_name="classification",
+ name="softmax_output",
+ ),
+ migrations.AddField(
+ model_name="classification",
+ name="category_map",
+ field=models.ForeignKey(
+ null=True, on_delete=django.db.models.deletion.PROTECT, to="ml.algorithmcategorymap"
+ ),
+ ),
+ migrations.AddField(
+ model_name="classification",
+ name="logits",
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.FloatField(), null=True, size=None),
+ ),
+ migrations.AddField(
+ model_name="classification",
+ name="scores",
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.FloatField(), null=True, size=None),
+ ),
+ migrations.AddField(
+ model_name="classification",
+ name="terminal",
+ field=models.BooleanField(
+ default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?"
+ ),
+ ),
+ migrations.AddField(
+ model_name="taxon",
+ name="search_names",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.CharField(max_length=255), blank=True, null=True, size=None
+ ),
+ ),
+ ]
diff --git a/ami/main/migrations/0040_alter_classification_logits_and_more.py b/ami/main/migrations/0040_alter_classification_logits_and_more.py
new file mode 100644
index 000000000..8a2aff79f
--- /dev/null
+++ b/ami/main/migrations/0040_alter_classification_logits_and_more.py
@@ -0,0 +1,33 @@
+# Generated by Django 4.2.10 on 2024-12-05 01:49
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("main", "0039_remove_classification_raw_output_and_more"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="classification",
+ name="logits",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.FloatField(),
+ help_text="The raw output of the last fully connected layer of the model",
+ null=True,
+ size=None,
+ ),
+ ),
+ migrations.AlterField(
+ model_name="classification",
+ name="scores",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.FloatField(),
+ help_text="The probabilities the model, calibrated by the model maker, likely the softmax output",
+ null=True,
+ size=None,
+ ),
+ ),
+ ]
diff --git a/ami/main/migrations/0044_merge_20250124_2333.py b/ami/main/migrations/0044_merge_20250124_2333.py
new file mode 100644
index 000000000..1cae6da87
--- /dev/null
+++ b/ami/main/migrations/0044_merge_20250124_2333.py
@@ -0,0 +1,12 @@
+# Generated by Django 4.2.10 on 2025-01-24 23:33
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("main", "0039_project_users_squashed_0043_rename_users_project_members"),
+ ("main", "0040_alter_classification_logits_and_more"),
+ ]
+
+ operations = []
diff --git a/ami/main/models.py b/ami/main/models.py
index a77b39335..41f887ebf 100644
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -12,6 +12,7 @@
import pydantic
from django.apps import apps
from django.conf import settings
+from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
from django.core.files.storage import default_storage
from django.db import IntegrityError, models
@@ -153,7 +154,7 @@ def summary_data(self):
plots.append(charts.captures_per_hour(project_pk=self.pk))
if self.occurrences.exists():
plots.append(charts.detections_per_hour(project_pk=self.pk))
- plots.append(charts.occurrences_accumulated(project_pk=self.pk))
+ # plots.append(charts.occurrences_accumulated(project_pk=self.pk))
else:
plots.append(charts.events_per_month(project_pk=self.pk))
# plots.append(charts.captures_per_month(project_pk=self.pk))
@@ -1248,6 +1249,7 @@ class SourceImage(BaseModel):
blank=True,
)
+ event_id: int | None
detections: models.QuerySet["Detection"]
collections: models.QuerySet["SourceImageCollection"]
jobs: models.QuerySet["Job"]
@@ -1782,27 +1784,27 @@ class Classification(BaseModel):
related_name="classifications",
)
- # occurrence = models.ForeignKey(
- # "Occurrence",
- # on_delete=models.SET_NULL,
- # null=True,
- # related_name="predictions",
- # )
-
taxon = models.ForeignKey("Taxon", on_delete=models.SET_NULL, null=True, related_name="classifications")
score = models.FloatField(null=True)
timestamp = models.DateTimeField()
- # terminal = models.BooleanField(
- # default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?"
- # )
-
- softmax_output = models.JSONField(null=True) # scores for all classes
- raw_output = models.JSONField(null=True) # raw output from the model
+ terminal = models.BooleanField(
+ default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?"
+ )
+ logits = ArrayField(
+ models.FloatField(), null=True, help_text="The raw output of the last fully connected layer of the model"
+ )
+ scores = ArrayField(
+ models.FloatField(),
+ null=True,
+ help_text="The probabilities the model, calibrated by the model maker, likely the softmax output",
+ )
+ category_map = models.ForeignKey("ml.AlgorithmCategoryMap", on_delete=models.PROTECT, null=True)
algorithm = models.ForeignKey(
"ml.Algorithm",
on_delete=models.SET_NULL,
null=True,
+ related_name="classifications",
)
# job = models.CharField(max_length=255, null=True)
@@ -1816,7 +1818,91 @@ class Meta:
ordering = ["-created_at", "-score"]
def __str__(self) -> str:
- return f"#{self.pk} to Taxon #{self.taxon_id} ({self.score:.2f}) by Algorithm #{self.algorithm_id}"
+ terminal = "Terminal" if self.terminal else "Intermediate"
+ if logger.getEffectiveLevel() == logging.DEBUG:
+ # Query the related objects to get the names
+ return f"#{self.pk} to Taxon {self.taxon} ({self.score:.2f}) by Algorithm {self.algorithm} ({terminal})"
+ return (
+ f"#{self.pk} to Taxon #{self.taxon_id} ({self.score:.2f}) by Algorithm #{self.algorithm_id} ({terminal})"
+ )
+
+ def top_scores_with_index(self, n: int | None = None) -> typing.Iterable[tuple[int, float]]:
+ """
+ Return the scores with their index, but sorted by score.
+ """
+ if self.scores:
+ top_scores_by_index = sorted(enumerate(self.scores), key=lambda x: x[1], reverse=True)[:n]
+ return top_scores_by_index
+ else:
+ return []
+
+ def predictions(self, sort=True) -> typing.Iterable[tuple[str, float]]:
+ """
+ Return all label-score pairs for this classification using the category map.
+ """
+ if not self.category_map:
+ raise ValueError("Classification must have a category map to get predictions.")
+ scores = self.scores or []
+ preds = zip(self.category_map.labels, scores)
+ if sort:
+ return sorted(preds, key=lambda x: x[1], reverse=True)
+ else:
+ return preds
+
+ def predictions_with_taxa(self, sort=True) -> typing.Iterable[tuple["Taxon", float]]:
+ """
+ Return taxa objects and their scores for this classification using the category map.
+
+        @TODO make this more efficient with numpy and/or postgres array functions, especially when we only need
+        the top N out of thousands of taxa.
+ """
+ if not self.category_map:
+ raise ValueError("Classification must have a category map to get predictions.")
+ scores = self.scores or []
+ category_data_with_taxa = self.category_map.with_taxa()
+ taxa_sorted_by_index = [cat["taxon"] for cat in sorted(category_data_with_taxa, key=lambda cat: cat["index"])]
+ preds = zip(taxa_sorted_by_index, scores)
+ if sort:
+ return sorted(preds, key=lambda x: x[1], reverse=True)
+ else:
+ return preds
+
+ def taxa(self) -> typing.Iterable["Taxon"]:
+ """
+ Return the taxa objects for this classification using the category map.
+ """
+ if not self.category_map:
+ return []
+ category_data_with_taxa = self.category_map.with_taxa()
+ taxa_sorted_by_index = [cat["taxon"] for cat in sorted(category_data_with_taxa, key=lambda cat: cat["index"])]
+ return taxa_sorted_by_index
+
+ def top_n(self, n: int = 3) -> list[dict[str, "Taxon | float | None"]]:
+ """Return top N taxa and scores for this classification."""
+ if not self.category_map:
+ raise ValueError("Classification must have a category map to get top N.")
+
+ top_scored = self.top_scores_with_index(n) # (index, score) pairs
+ indexes = [idx for idx, _ in top_scored]
+ category_data = self.category_map.with_taxa(only_indexes=indexes)
+ index_to_taxon = {cat["index"]: cat["taxon"] for cat in category_data}
+
+ return [
+ {
+ "taxon": index_to_taxon[i],
+ "score": s,
+ "logit": self.logits[i] if self.logits else None,
+ }
+ for i, s in top_scored
+ ]
+
+ def save(self, *args, **kwargs):
+ """
+ Set the category map based on the algorithm.
+ """
+ if self.algorithm and not self.category_map:
+ self.category_map = self.algorithm.category_map
+ super().save(*args, **kwargs)
@final
@@ -2107,10 +2193,22 @@ def best_detection(self):
@functools.cached_property
def best_prediction(self):
- return self.predictions().first()
+ """
+ Use the best prediction as the best identification if there are no human identifications.
+
+ Uses the highest scoring classification (from any algorithm) as the best prediction.
+ Considers terminal classifications first, then non-terminal ones.
+ (Terminal classifications are the final classifications of a pipeline, non-terminal are intermediate models.)
+ """
+ return self.predictions().order_by("-terminal", "-score").first()
@functools.cached_property
def best_identification(self):
+ """
+ The most recent human identification is used as the best identification.
+
+ @TODO this could use a confidence level chosen manually by the users/experts.
+ """
return Identification.objects.filter(occurrence=self, withdrawn=False).order_by("-created_at").first()
def get_determination_score(self) -> float | None:
@@ -2163,6 +2261,7 @@ def save(self, update_determination=True, *args, **kwargs):
if self.determination and not self.determination_score:
# This may happen for legacy occurrences that were created
# before the determination_score field was added
+ # @TODO remove
self.determination_score = self.get_determination_score()
if not self.determination_score:
logger.warning(f"Could not determine score for {self}")
@@ -2175,16 +2274,16 @@ class Meta:
def update_occurrence_determination(
occurrence: Occurrence, current_determination: typing.Optional["Taxon"] = None, save=True
-):
+) -> bool:
"""
Update the determination of the occurrence based on the identifications & predictions.
If there are identifications, set the determination to the latest identification.
If there are no identifications, set the determination to the top prediction.
- The `current_determination` is the determination curently saved in the database.
+ The `current_determination` is the determination currently saved in the database.
The `occurrence` object may already have a different un-saved determination set
- so it is neccessary to retrieve the current determination from the database, but
+ so it is necessary to retrieve the current determination from the database, but
this can also be passed in as an argument to avoid an extra database query.
@TODO Add tests for this important method!
@@ -2196,6 +2295,8 @@ def update_occurrence_determination(
del occurrence.best_identification
if hasattr(occurrence, "best_prediction"):
del occurrence.best_prediction
+ if hasattr(occurrence, "best_identification"):
+ del occurrence.best_identification
current_determination = (
current_determination
@@ -2217,18 +2318,29 @@ def update_occurrence_determination(
new_score = top_prediction.score
if new_determination and new_determination != current_determination:
- logger.info(f"Changing det. of {occurrence} from {current_determination} to {new_determination}")
+ logger.debug(f"Changing det. of {occurrence} from {current_determination} to {new_determination}")
occurrence.determination = new_determination
needs_update = True
if new_score and new_score != occurrence.determination_score:
- logger.info(f"Changing det. score of {occurrence} from {occurrence.determination_score} to {new_score}")
+ logger.debug(f"Changing det. score of {occurrence} from {occurrence.determination_score} to {new_score}")
occurrence.determination_score = new_score
needs_update = True
+ if not needs_update:
+ if logger.getEffectiveLevel() <= logging.DEBUG:
+ all_predictions = occurrence.predictions()
+ all_preds_print = ", ".join([str(p) for p in all_predictions])
+ logger.debug(
+ f"No update needed for determination of {occurrence}. Best prediction: {occurrence.best_prediction}. "
+ f"All preds: {all_preds_print}"
+ )
+
if save and needs_update:
occurrence.save(update_determination=False)
+ return needs_update
+
@final
class TaxaManager(models.Manager):
@@ -2457,6 +2569,7 @@ class Taxon(BaseModel):
active = models.BooleanField(default=True)
synonym_of = models.ForeignKey("self", on_delete=models.SET_NULL, null=True, blank=True, related_name="synonyms")
+ search_names = ArrayField(models.CharField(max_length=255), null=True, blank=True)
gbif_taxon_key = models.BigIntegerField("GBIF taxon key", blank=True, null=True)
bold_taxon_bin = models.CharField("BOLD taxon BIN", max_length=255, blank=True, null=True)
inat_taxon_id = models.BigIntegerField("iNaturalist taxon ID", blank=True, null=True)
diff --git a/ami/ml/admin.py b/ami/ml/admin.py
index 3f4784d1b..54d6f0185 100644
--- a/ami/ml/admin.py
+++ b/ami/ml/admin.py
@@ -2,7 +2,7 @@
from ami.main.admin import AdminBase
-from .models.algorithm import Algorithm
+from .models.algorithm import Algorithm, AlgorithmCategoryMap
from .models.pipeline import Pipeline
from .models.processing_service import ProcessingService
@@ -11,8 +11,10 @@
class AlgorithmAdmin(AdminBase):
list_display = [
"name",
+ "key",
"version",
"version_name",
+ "task_type",
"created_at",
"updated_at",
]
@@ -26,6 +28,7 @@ class AlgorithmAdmin(AdminBase):
]
list_filter = [
"pipelines",
+ "task_type",
]
@@ -68,3 +71,33 @@ class ProcessingServiceAdmin(AdminBase):
"endpoint_url",
"created_at",
]
+
+
+@admin.register(AlgorithmCategoryMap)
+class AlgorithmCategoryMapAdmin(AdminBase):
+ list_display = [
+ "version",
+ "uri",
+ "created_at",
+ "num_data_items",
+ "num_labels",
+ ]
+ search_fields = [
+ "version",
+ ]
+ ordering = [
+ "version",
+ ]
+ list_filter = [
+ "algorithms",
+ ]
+ formfield_overrides = {
+ # See https://pypi.org/project/django-json-widget/
+ # models.JSONField: {"widget": JSONInput},
+ }
+
+ def num_data_items(self, obj):
+ return len(obj.data) if obj.data else 0
+
+ def num_labels(self, obj):
+ return len(obj.labels) if obj.labels else 0
diff --git a/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py b/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py
new file mode 100644
index 000000000..6dd501558
--- /dev/null
+++ b/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py
@@ -0,0 +1,46 @@
+# Generated by Django 4.2.10 on 2024-11-26 18:52
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0006_alter_pipeline_endpoint_url_alter_pipeline_projects"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="AlgorithmCategoryMap",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("data", models.JSONField()),
+ (
+ "labels",
+ django.contrib.postgres.fields.ArrayField(
+ base_field=models.CharField(max_length=255), default=list, size=None
+ ),
+ ),
+ ("version", models.CharField(blank=True, max_length=255, null=True)),
+ ("url", models.URLField(blank=True, null=True)),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ migrations.AddField(
+ model_name="algorithm",
+ name="category_map",
+ field=models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="algorithms",
+ to="ml.algorithmcategorymap",
+ ),
+ ),
+ ]
diff --git a/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py b/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py
new file mode 100644
index 000000000..a51bbb5cc
--- /dev/null
+++ b/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py
@@ -0,0 +1,35 @@
+# Generated by Django 4.2.10 on 2024-12-05 00:24
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0007_algorithmcategorymap_algorithm_category_map"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="algorithmcategorymap",
+ name="description",
+ field=models.TextField(blank=True, null=True),
+ ),
+ migrations.AlterField(
+ model_name="algorithmcategorymap",
+ name="data",
+ field=models.JSONField(
+ help_text="Complete metadata for each label, such as id, gbif_key, explicit index, source, etc."
+ ),
+ ),
+ migrations.AlterField(
+ model_name="algorithmcategorymap",
+ name="labels",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.CharField(max_length=255),
+ default=list,
+ help_text="A simple list of string labels in the correct index order used by the model.",
+ size=None,
+ ),
+ ),
+ ]
diff --git a/ami/ml/migrations/0009_algorithm_task_type.py b/ami/ml/migrations/0009_algorithm_task_type.py
new file mode 100644
index 000000000..6bdfda3b3
--- /dev/null
+++ b/ami/ml/migrations/0009_algorithm_task_type.py
@@ -0,0 +1,38 @@
+# Generated by Django 4.2.10 on 2024-12-05 01:23
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0008_algorithmcategorymap_description_and_more"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="algorithm",
+ name="task_type",
+ field=models.CharField(
+ blank=True,
+ choices=[
+ ("detection", "Detection"),
+ ("segmentation", "Segmentation"),
+ ("classification", "Classification"),
+ ("embedding", "Embedding"),
+ ("tracking", "Tracking"),
+ ("tagging", "Tagging"),
+ ("regression", "Regression"),
+ ("captioning", "Captioning"),
+ ("generation", "Generation"),
+ ("translation", "Translation"),
+ ("summarization", "Summarization"),
+ ("question_answering", "Question Answering"),
+ ("depth_estimation", "Depth Estimation"),
+ ("pose_estimation", "Pose Estimation"),
+ ("size_estimation", "Size Estimation"),
+ ("other", "Other"),
+ ],
+ max_length=255,
+ ),
+ ),
+ ]
diff --git a/ami/ml/migrations/0010_alter_algorithm_version.py b/ami/ml/migrations/0010_alter_algorithm_version.py
new file mode 100644
index 000000000..84d42928e
--- /dev/null
+++ b/ami/ml/migrations/0010_alter_algorithm_version.py
@@ -0,0 +1,19 @@
+# Generated by Django 4.2.10 on 2024-12-05 01:49
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0009_algorithm_task_type"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="algorithm",
+ name="version",
+ field=models.IntegerField(
+ default=1, help_text="An internal, sortable and incrementable version number for the model."
+ ),
+ ),
+ ]
diff --git a/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py b/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py
new file mode 100644
index 000000000..9f084ca80
--- /dev/null
+++ b/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py
@@ -0,0 +1,50 @@
+# Generated by Django 4.2.10 on 2024-12-05 04:08
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0010_alter_algorithm_version"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="algorithm",
+ name="task_type",
+ field=models.CharField(
+ choices=[
+ ("detection", "Detection"),
+ ("segmentation", "Segmentation"),
+ ("classification", "Classification"),
+ ("embedding", "Embedding"),
+ ("tracking", "Tracking"),
+ ("tagging", "Tagging"),
+ ("regression", "Regression"),
+ ("captioning", "Captioning"),
+ ("generation", "Generation"),
+ ("translation", "Translation"),
+ ("summarization", "Summarization"),
+ ("question_answering", "Question Answering"),
+ ("depth_estimation", "Depth Estimation"),
+ ("pose_estimation", "Pose Estimation"),
+ ("size_estimation", "Size Estimation"),
+ ("other", "Other"),
+ ("unknown", "Unknown"),
+ ],
+ default="unknown",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AlterField(
+ model_name="algorithm",
+ name="url",
+ field=models.URLField(blank=True, null=True),
+ ),
+ migrations.AlterField(
+ model_name="algorithm",
+ name="version_name",
+ field=models.CharField(blank=True, max_length=255, null=True),
+ ),
+ ]
diff --git a/ami/ml/migrations/0012_alter_algorithm_unique_together.py b/ami/ml/migrations/0012_alter_algorithm_unique_together.py
new file mode 100644
index 000000000..1ccf7fffd
--- /dev/null
+++ b/ami/ml/migrations/0012_alter_algorithm_unique_together.py
@@ -0,0 +1,16 @@
+# Generated by Django 4.2.10 on 2024-12-05 04:20
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0011_alter_algorithm_task_type_alter_algorithm_url_and_more"),
+ ]
+
+ operations = [
+ migrations.AlterUniqueTogether(
+ name="algorithm",
+ unique_together=set(),
+ ),
+ ]
diff --git a/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py b/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py
new file mode 100644
index 000000000..36454c7dd
--- /dev/null
+++ b/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py
@@ -0,0 +1,54 @@
+# Generated by Django 4.2.10 on 2024-12-06 21:02
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0012_alter_algorithm_unique_together"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="algorithm",
+ name="url",
+ ),
+ migrations.RemoveField(
+ model_name="algorithmcategorymap",
+ name="url",
+ ),
+ migrations.AddField(
+ model_name="algorithm",
+ name="uri",
+ field=models.CharField(
+ blank=True,
+ help_text="A URI to the weights or model details. Could be a public web URL or object store path.",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AddField(
+ model_name="algorithmcategorymap",
+ name="labels_hash",
+ field=models.BigIntegerField(
+ help_text="A hash of the labels for faster comparison of label sets. Created on save.", null=True
+ ),
+ ),
+ migrations.AddField(
+ model_name="algorithmcategorymap",
+ name="uri",
+ field=models.CharField(
+ blank=True,
+ help_text="A URI to the category map file. Could be a public web URL or object store path.",
+ max_length=255,
+ null=True,
+ ),
+ ),
+ migrations.AlterField(
+ model_name="algorithmcategorymap",
+ name="data",
+ field=models.JSONField(
+ help_text="Complete metadata for each label, such as id, gbif_key, lookup value, source, etc."
+ ),
+ ),
+ ]
diff --git a/ami/ml/migrations/0014_rename_model_keys.py b/ami/ml/migrations/0014_rename_model_keys.py
new file mode 100644
index 000000000..1ddb104f3
--- /dev/null
+++ b/ami/ml/migrations/0014_rename_model_keys.py
@@ -0,0 +1,39 @@
+# Generated by Django 4.2.10 on 2024-12-06 21:15
+
+import random
+import string
+import logging
+from django.db import migrations
+
+logger = logging.getLogger(__name__)
+
+
+def rename_algorithm_keys(apps, schema_editor):
+ """
+ The current live ML backend API uses keys with underscores, but the Antenna database
+ uses keys with dashes. This migration renames all existing algorithm keys to use underscores
+ so that the keys match for new detections.
+ """
+ Algorithm = apps.get_model("ml", "Algorithm")
+ algorithms = Algorithm.objects.all()
+ for algorithm in algorithms:
+ new_key = algorithm.key.replace("-", "_")
+ if Algorithm.objects.filter(key=new_key).exclude(id=algorithm.pk).exists():
+ # Add random 6 char suffix to avoid clashes
+ new_key += "_" + "".join(random.choices(string.ascii_lowercase, k=6))
+ if algorithm.key != new_key:
+ logger.info(f"Renaming algorithm key {algorithm.key} to {new_key} for algorithm {algorithm}")
+ algorithm.key = new_key
+ algorithm.save()
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more"),
+ ]
+
+ # Rename all existing algorithm keys to use underscores instead of dashes. Ensure there are no clashes as we go.
+
+ operations = [
+ migrations.RunPython(rename_algorithm_keys, reverse_code=migrations.RunPython.noop),
+ ]
diff --git a/ami/ml/migrations/0015_update_existing_intermediate_classifications.py b/ami/ml/migrations/0015_update_existing_intermediate_classifications.py
new file mode 100644
index 000000000..a61e5cb25
--- /dev/null
+++ b/ami/ml/migrations/0015_update_existing_intermediate_classifications.py
@@ -0,0 +1,46 @@
+"""
+The Classification model has had a terminal field for a while, but it has not been used until now.
+We now need to filter classifications on this field, so this migration updates existing classifications
+that came from a binary / intermediate classifier (moth / non-moth) to terminal=False.
+"""
+
+import logging
+
+from django.db import migrations, models
+
+logger = logging.getLogger(__name__)
+
+
+MOTH_NONMOTH_LABELS = [
+ "moth",
+ "non-moth",
+ "nonmoth",
+]
+
+
+def update_classification_labels(apps, schema_editor):
+ Classification = apps.get_model("main", "Classification")
+
+ # Create regex pattern from labels to make a case-insensitive match
+ pattern = r"^(" + "|".join(MOTH_NONMOTH_LABELS) + ")$"
+
+ # Log number of updated classifications
+ logger.info(f"\nUpdating classifications with labels matching pattern: {pattern} (case insensitive)")
+
+ # Update only matching classifications
+ updated = Classification.objects.filter(taxon__name__iregex=pattern).update(terminal=False)
+
+ logger.info(f"\nUpdated {updated} moth/non-moth classifications to terminal=False")
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0014_rename_model_keys"),
+ ]
+
+ operations = [
+ migrations.RunPython(
+ update_classification_labels,
+ reverse_code=migrations.RunPython.noop,
+ ),
+ ]
diff --git a/ami/ml/migrations/0016_merge_20250117_2101.py b/ami/ml/migrations/0016_merge_20250117_2101.py
new file mode 100644
index 000000000..63bc4e283
--- /dev/null
+++ b/ami/ml/migrations/0016_merge_20250117_2101.py
@@ -0,0 +1,12 @@
+# Generated by Django 4.2.10 on 2025-01-17 21:01
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("ml", "0007_add_processing_service"),
+ ("ml", "0015_update_existing_intermediate_classifications"),
+ ]
+
+ operations = []
diff --git a/ami/ml/models/__init__.py b/ami/ml/models/__init__.py
index a5e716372..a4f28d671 100644
--- a/ami/ml/models/__init__.py
+++ b/ami/ml/models/__init__.py
@@ -1,9 +1,10 @@
-from ami.ml.models.algorithm import Algorithm
+from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap
from ami.ml.models.pipeline import Pipeline
from ami.ml.models.processing_service import ProcessingService
__all__ = [
"Algorithm",
+ "AlgorithmCategoryMap",
"Pipeline",
"ProcessingService",
]
diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py
index 9cc56d4bd..7a63a0861 100644
--- a/ami/ml/models/algorithm.py
+++ b/ami/ml/models/algorithm.py
@@ -8,22 +8,142 @@
import typing
+from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.utils.text import slugify
from ami.base.models import BaseModel
+@typing.final
+class AlgorithmCategoryMap(BaseModel):
+ """
+ A list of classification labels for a given algorithm version
+ """
+
+ data = models.JSONField(
+ help_text="Complete metadata for each label, such as id, gbif_key, lookup value, source, etc."
+ )
+ labels = ArrayField(
+ models.CharField(max_length=255),
+ default=list,
+ help_text="A simple list of string labels in the correct index order used by the model.",
+ )
+ labels_hash = models.BigIntegerField(
+ help_text="A hash of the labels for faster comparison of label sets. Created on save.",
+ null=True,
+ )
+ version = models.CharField(max_length=255, blank=True, null=True)
+ description = models.TextField(blank=True, null=True)
+ uri = models.CharField(
+ max_length=255,
+ blank=True,
+ null=True,
+ help_text=("A URI to the category map file. " "Could be a public web URL or object store path."),
+ )
+
+ algorithms: models.QuerySet[Algorithm]
+
+ def __str__(self):
+ return f"#{self.pk} with {len(self.labels)} classes ({self.version or 'unknown version'})"
+
+ @classmethod
+ def make_labels_hash(cls, labels):
+ """
+ Create a hash from the labels for faster comparison of unique label sets
+ """
+ return hash("".join(labels))
+
+ def get_category(self, label, label_field="label"):
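+ # Return the index of the first category entry whose `label_field` matches the given label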
+ # Can use JSON containment operators
+ return self.data.index(next(category for category in self.data if category[label_field] == label))
+
+ def with_taxa(self, category_field="label", only_indexes: list[int] | None = None):
+ """
+ Add Taxon objects to the category map, or None if no match
+
+ :param category_field: The field in the category data to match against the Taxon name
+ :param only_indexes: Optionally limit the lookup to the category entries at these indexes
+ :return: The category data with a Taxon object (or None) added to each entry
+
+ @TODO need a top_n parameter to limit the number of taxa to fetch
+ @TODO consider creating missing taxa?
+ """
+
+ from ami.main.models import Taxon
+
+ if only_indexes:
+ labels_data = [self.data[i] for i in only_indexes]
+ labels_label = [self.labels[i] for i in only_indexes]
+ else:
+ labels_data = self.data
+ labels_label = self.labels
+
+ taxa = Taxon.objects.filter(models.Q(name__in=labels_label) | models.Q(search_names__overlap=labels_label))
+ taxon_map = {taxon.name: taxon for taxon in taxa}
+
+ for category in labels_data:
+ taxon = taxon_map.get(category[category_field])
+ category["taxon"] = taxon
+
+ return labels_data
+
+ def save(self, *args, **kwargs):
+ if not self.labels_hash:
+ self.labels_hash = self.make_labels_hash(self.labels)
+ super().save(*args, **kwargs)
+
+
@typing.final
class Algorithm(BaseModel):
"""A machine learning algorithm"""
name = models.CharField(max_length=255)
key = models.SlugField(max_length=255, unique=True)
+ task_type = models.CharField(
+ max_length=255,
+ default="unknown",
+ null=True,
+ choices=[
+ ("detection", "Detection"),
+ ("segmentation", "Segmentation"),
+ ("classification", "Classification"),
+ ("embedding", "Embedding"),
+ ("tracking", "Tracking"),
+ ("tagging", "Tagging"),
+ ("regression", "Regression"),
+ ("captioning", "Captioning"),
+ ("generation", "Generation"),
+ ("translation", "Translation"),
+ ("summarization", "Summarization"),
+ ("question_answering", "Question Answering"),
+ ("depth_estimation", "Depth Estimation"),
+ ("pose_estimation", "Pose Estimation"),
+ ("size_estimation", "Size Estimation"),
+ ("other", "Other"),
+ ("unknown", "Unknown"),
+ ],
+ )
description = models.TextField(blank=True)
- version = models.IntegerField(default=1)
- version_name = models.CharField(max_length=255, blank=True)
- url = models.URLField(blank=True) # URL to the model homepage, origin or docs (huggingface, wandb, etc.)
+ version = models.IntegerField(
+ default=1,
+ help_text="An internal, sortable and incrementable version number for the model.",
+ )
+ version_name = models.CharField(max_length=255, blank=True, null=True)
+ uri = models.CharField(
+ max_length=255,
+ blank=True,
+ null=True,
+ help_text=("A URI to the weights or model details. Could be a public web URL or object store path."),
+ )
+
+ category_map = models.ForeignKey(
+ AlgorithmCategoryMap,
+ on_delete=models.CASCADE,
+ blank=True,
+ null=True,
+ related_name="algorithms",
+ default=None,
+ )
# api_base_url = models.URLField(blank=True)
# api = models.CharField(max_length=255, blank=True)
@@ -31,6 +151,9 @@ class Algorithm(BaseModel):
pipelines: models.QuerySet[Pipeline]
classifications: models.QuerySet[Classification]
+ def __str__(self):
+ return f'#{self.pk} "{self.name}" ({self.key}) v{self.version}'
+
class Meta:
ordering = ["name", "version"]
@@ -38,10 +161,9 @@ class Meta:
["name", "version"],
]
- def __str__(self):
- return f'#{self.pk} "{self.name}" ({self.key}) v{self.version}'
-
def save(self, *args, **kwargs):
+ if not self.version_name:
+ self.version_name = f"{self.version}"
if not self.key:
- self.key = slugify(self.name)
+ self.key = f"{slugify(self.name)}-{self.version}"
super().save(*args, **kwargs)
diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py
index c830f69bf..4a598d8e6 100644
--- a/ami/ml/models/pipeline.py
+++ b/ami/ml/models/pipeline.py
@@ -5,16 +5,19 @@
if TYPE_CHECKING:
from ami.ml.models import ProcessingService
+import collections
+import dataclasses
import logging
+import time
import typing
+import uuid
from urllib.parse import urljoin
import requests
-from django.db import models, transaction
+from django.db import models
from django.utils.text import slugify
from django.utils.timezone import now
from django_pydantic_field import SchemaField
-from rich import print
from ami.base.models import BaseModel
from ami.base.schemas import ConfigurableStage, default_stages
@@ -29,10 +32,20 @@
Taxon,
TaxonRank,
update_calculated_fields_for_events,
+ update_occurrence_determination,
+)
+from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap
+from ami.ml.schemas import (
+ AlgorithmConfigResponse,
+ ClassificationResponse,
+ DetectionResponse,
+ PipelineRequest,
+ PipelineResultsResponse,
+ SourceImageRequest,
+ SourceImageResponse,
)
-from ami.ml.models.algorithm import Algorithm
-from ami.ml.schemas import PipelineRequest, PipelineResponse, SourceImageRequest, SourceImageResponse
from ami.ml.tasks import celery_app, create_detection_images
+from ami.utils.requests import create_session
logger = logging.getLogger(__name__)
@@ -50,20 +63,23 @@ def filter_processed_images(
for image in images:
existing_detections = image.detections.filter(detection_algorithm__in=pipeline_algorithms)
if not existing_detections.exists():
- logger.debug(f"Image {image} has no existing detections from pipeline {pipeline}")
+ logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector")
# If there are no existing detections from this pipeline, send the image
yield image
elif existing_detections.filter(classifications__isnull=True).exists():
# Check if there are detections with no classifications
- logger.debug(f"Image {image} has existing detections with no classifications from pipeline {pipeline}")
+ logger.debug(
+ f"Image {image} needs processing: has existing detections with no classifications "
+ "from pipeline {pipeline}"
+ )
yield image
else:
# If there are existing detections with classifications,
# Compare their classification algorithms to the current pipeline's algorithms
- detections_needing_classification = existing_detections.exclude(
- classifications__algorithm__in=pipeline_algorithms
- )
- if detections_needing_classification.exists():
+ pipeline_algorithm_ids = pipeline_algorithms.values_list("id", flat=True)
+ detection_algorithm_ids = existing_detections.values_list("classifications__algorithm_id", flat=True)
+
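+ # Reprocess the image if any of the pipeline's algorithms have not yet been applied to its detections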
+ if not set(pipeline_algorithm_ids).issubset(set(detection_algorithm_ids)):
logger.debug(
f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}"
)
@@ -134,7 +150,7 @@ def process_images(
endpoint_url: str,
images: typing.Iterable[SourceImage],
job_id: int | None = None,
-) -> PipelineResponse:
+) -> PipelineResultsResponse:
"""
Process images using ML pipeline API.
@@ -158,7 +174,7 @@ def process_images(
if not images:
task_logger.info("No images to process")
- return PipelineResponse(
+ return PipelineResultsResponse(
pipeline=pipeline.slug,
source_images=[],
detections=[],
@@ -181,18 +197,20 @@ def process_images(
source_images=source_images,
)
- resp = requests.post(endpoint_url, json=request_data.dict())
+ session = create_session()
+ resp = session.post(endpoint_url, json=request_data.dict())
if not resp.ok:
try:
msg = resp.json()["detail"]
- except Exception:
+ except (ValueError, KeyError):
msg = str(resp.content)
if job:
job.logger.error(msg)
else:
logger.error(msg)
+ raise requests.HTTPError(msg)
- results = PipelineResponse(
+ results = PipelineResultsResponse(
pipeline=pipeline.slug,
total_time=0,
source_images=[
@@ -204,199 +222,633 @@ def process_images(
return results
results = resp.json()
- results = PipelineResponse(**results)
+ results = PipelineResultsResponse(**results)
if job:
job.logger.debug(f"Results: {results}")
detections = results.detections
classifications = [classification for detection in detections for classification in detection.classifications]
- if len(detections):
- job.logger.info(f"Found {len(detections)} detections")
- if len(classifications):
- job.logger.info(f"Found {len(classifications)} classifications")
+ job.logger.info(
+ f"Pipeline results returned {len(results.source_images)} images, {len(detections)} detections, "
+ f"{len(classifications)} classifications"
+ )
return results
-@celery_app.task(soft_time_limit=60 * 4, time_limit=60 * 5)
-def save_results(results: PipelineResponse | None = None, results_json: str | None = None, job_id: int | None = None):
+def get_or_create_algorithm_and_category_map(
+ algorithm_config: AlgorithmConfigResponse,
+ logger: logging.Logger = logger,
+) -> Algorithm:
"""
- Save results from ML pipeline API.
+ Create or update an Algorithm and its category map from an AlgorithmConfigResponse.
- @TODO break into task chunks.
- @TODO rewrite this!
+ :param algorithm_config: A single algorithm entry from the processing service's "/info" endpoint or pipeline results
+ :param logger: A logger instance from the parent function
+
+ :return: The Algorithm object, with its category map created or assigned if one was provided
+
+ @TODO this should be called when registering a pipeline, not when saving results.
+ But currently we don't have a way to register pipelines.
"""
- created_objects = []
- job = None
+ category_map = None
+ category_map_data = algorithm_config.category_map
+ if category_map_data:
+ labels_hash = AlgorithmCategoryMap.make_labels_hash(category_map_data.labels)
+ category_map, _created = AlgorithmCategoryMap.objects.get_or_create(
+ # @TODO this is creating a new category map every time
+ # Will create a new category map if the labels are different
+ labels_hash=labels_hash,
+ version=category_map_data.version,
+ defaults={
+ "data": category_map_data.data,
+ "labels": category_map_data.labels,
+ "description": category_map_data.description,
+ "uri": category_map_data.uri,
+ },
+ )
+ if _created:
+ logger.info(f"Registered new category map {category_map}")
+ else:
+ logger.info(f"Assigned existing category map {category_map}")
+ else:
+ logger.warning(
+ f"No category map found for algorithm {algorithm_config.key} in response."
+ " Will attempt to create one from the classification results."
+ )
- if results_json:
- results = PipelineResponse.parse_raw(results_json)
- assert results, "No results data passed to save_results task"
+ algo, _created = Algorithm.objects.get_or_create(
+ key=algorithm_config.key,
+ version=algorithm_config.version,
+ defaults={
+ "name": algorithm_config.name,
+ "task_type": algorithm_config.task_type,
+ "version_name": algorithm_config.version_name,
+ "uri": algorithm_config.uri,
+ "category_map": category_map or None,
+ },
+ )
+
+ # Update fields that may have changed in the processing service, with a warning
+ fields_to_update = {
+ "task_type": algorithm_config.task_type,
+ "uri": algorithm_config.uri,
+ "category_map": category_map,
+ }
+ fields_updated = []
+ for field in fields_to_update:
+ new_value = fields_to_update[field]
+ if getattr(algo, field) != new_value:
+ logger.warning(f"Field '{field}' changed for algorithm {algo} from {getattr(algo, field)} to {new_value}")
+ setattr(algo, field, new_value)
+ fields_updated.append(field)
+ algo.save(update_fields=fields_updated)
+
+ if not algo.category_map or len(algo.category_map.data) == 0:
+ # Update existing algorithm that is missing a category map
+ algo.category_map = category_map
+ algo.save()
- pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline})
if _created:
- logger.warning(f"Pipeline choice returned by the Processing Service was not recognized! {pipeline}")
- created_objects.append(pipeline)
- algorithms_used = set()
+ logger.info(f"Registered new algorithm {algo}")
+ else:
+ logger.info(f"Assigned algorithm {algo}")
- if job_id:
- from ami.jobs.models import Job
+ return algo
- job = Job.objects.get(pk=job_id)
- job.logger.info("Saving results...")
-
- # collection_name = f"Images processed by {results.pipeline} pipeline"
- # if job_id:
- # from ami.jobs.models import Job
-
- # job = Job.objects.get(pk=job_id)
- # collection_name = f"Images processed by {results.pipeline} pipeline for job {job.name}"
-
- # collection = SourceImageCollection.objects.create(name=collection_name)
- # source_image_ids = [source_image.id for source_image in results.source_images]
- # source_images = SourceImage.objects.filter(pk__in=source_image_ids)
- # collection.images.set(source_images)
- source_images = set()
-
- for detection_resp in results.detections:
- # @TODO use bulk create, or optimize this in some way
- print(detection_resp)
- assert detection_resp.algorithm, "No detection algorithm was specified in the returned results."
- detection_algo, _created = Algorithm.objects.get_or_create(
- name=detection_resp.algorithm,
+
+def get_or_create_detection(
+ source_image: SourceImage,
+ detection_resp: DetectionResponse,
+ algorithms_used: dict[str, Algorithm],
+ save: bool = True,
+ logger: logging.Logger = logger,
+) -> tuple[Detection, bool]:
+ """
+ Create a Detection object from a DetectionResponse, or update an existing one.
+
+ :param source_image: The SourceImage that the detection belongs to
+ :param detection_resp: A DetectionResponse object
+ :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key
+ :param save: Whether to save the new Detection immediately (set to False when bulk creating)
+
+ :return: A tuple of the Detection object and a boolean indicating whether it was created
+ """
+ serialized_bbox = list(detection_resp.bbox.dict().values())
+ detection_repr = f"Detection {detection_resp.source_image_id} {serialized_bbox}"
+
+ assert detection_resp.algorithm, f"No detection algorithm was specified for detection {detection_repr}"
+ try:
+ detection_algo = algorithms_used[detection_resp.algorithm.key]
+ except KeyError:
+ raise ValueError(
+ f"Detection algorithm {detection_resp.algorithm.key} is not a known algorithm. "
+ "The processing service must declare it in the /info endpoint. "
+ f"Known algorithms: {list(algorithms_used.keys())}"
)
- algorithms_used.add(detection_algo)
- if _created:
- created_objects.append(detection_algo)
- # @TODO hmmmm what to do
- source_image = SourceImage.objects.get(pk=detection_resp.source_image_id)
- source_images.add(source_image)
- existing_detection = Detection.objects.filter(
+ assert str(detection_resp.source_image_id) == str(
+ source_image.pk
+ ), f"Detection belongs to a different source image: {detection_repr}"
+
+ existing_detection = Detection.objects.filter(
+ source_image=source_image,
+ detection_algorithm=detection_algo,
+ bbox=serialized_bbox,
+ ).first()
+
+ # A detection may have a pre-existing crop image URL or not.
+ # If not, a new one will be created in a periodic background task.
+ if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"):
+ # Ensure that the crop image URL is not empty or only a slash. None is fine.
+ crop_url = detection_resp.crop_image_url
+ else:
+ crop_url = None
+
+ if existing_detection:
+ if not existing_detection.path:
+ existing_detection.path = crop_url
+ existing_detection.save()
+ logger.debug(f"Updated crop_url of existing detection {existing_detection}")
+ detection = existing_detection
+
+ else:
+ new_detection = Detection(
source_image=source_image,
+ bbox=serialized_bbox,
+ timestamp=source_image.timestamp,
+ path=crop_url,
+ detection_time=detection_resp.timestamp,
detection_algorithm=detection_algo,
- bbox=list(detection_resp.bbox.dict().values()),
- ).first()
- # Ensure that the crop image URL is not empty or only a slash. None is fine.
- if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"):
- crop_url = detection_resp.crop_image_url
+ )
+ if save:
+ new_detection.save()
+ logger.debug(f"Created new detection {new_detection}")
else:
- crop_url = None
- if existing_detection:
- if not existing_detection.path:
- existing_detection.path = crop_url
- existing_detection.save()
- print("Updated existing detection", existing_detection)
- detection = existing_detection
+ logger.debug(f"Initialized new detection {new_detection} (not saved)")
+
+ detection = new_detection
+
+ created = not existing_detection
+ return detection, created
+
+
+def create_detections(
+ detections: list[DetectionResponse],
+ algorithms_used: dict[str, Algorithm],
+ logger: logging.Logger = logger,
+) -> list[Detection]:
+ """
+ Efficiently create multiple Detection objects from a list of DetectionResponse objects,
+ grouped by source image and using bulk create.
+
+ :param detections: A list of DetectionResponse objects
+ :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key
+ :param logger: A logger instance (e.g. a job logger) to use for progress messages
+
+ :return: A list of Detection objects, in the same order as the responses (responses for missing source images are skipped)
+ """
+ source_image_ids = {detection.source_image_id for detection in detections}
+ source_images = SourceImage.objects.filter(pk__in=source_image_ids)
+ source_image_map = {str(source_image.pk): source_image for source_image in source_images}
+
+ # Keep detections in the same order as the responses so callers can zip them back together
+ all_detections: list[Detection] = []
+ existing_detections: list[Detection] = []
+ new_detections: list[Detection] = []
+ for detection_resp in detections:
+ source_image = source_image_map.get(detection_resp.source_image_id)
+ if not source_image:
+ logger.error(f"Source image {detection_resp.source_image_id} not found, skipping Detection creation")
+ continue
+
+ detection, created = get_or_create_detection(
+ source_image=source_image,
+ detection_resp=detection_resp,
+ algorithms_used=algorithms_used,
+ save=False,
+ logger=logger,
+ )
+ all_detections.append(detection)
+ if created:
+ new_detections.append(detection)
else:
- new_detection = Detection.objects.create(
- source_image=source_image,
- bbox=list(detection_resp.bbox.dict().values()),
- timestamp=source_image.timestamp,
- path=crop_url,
- detection_time=detection_resp.timestamp,
- detection_algorithm=detection_algo,
- )
- new_detection.save()
- print("Created new detection", new_detection)
- created_objects.append(new_detection)
- detection = new_detection
+ existing_detections.append(detection)
- for classification in detection_resp.classifications:
- print(classification)
+ Detection.objects.bulk_create(new_detections)
+ # logger.info(f"Created {len(new_detections)} new detections for {len(source_image_ids)} source image(s)")
+ logger.info(
+ f"Created {len(new_detections)} new detections, updated {len(existing_detections)} existing detections, "
+ f"for {len(source_image_ids)} source image(s)"
+ )
- assert classification.algorithm, "No classification algorithm was specified in the returned results."
- classification_algo, _created = Algorithm.objects.get_or_create(
- name=classification.algorithm,
- )
- algorithms_used.add(classification_algo)
- if _created:
- created_objects.append(classification_algo)
+ return all_detections
- taxa_list, _created = TaxaList.objects.get_or_create(
- name=f"Taxa returned by {classification_algo.name}",
- )
- if _created:
- created_objects.append(taxa_list)
- taxon, _created = Taxon.objects.get_or_create(
- name=classification.classification,
- defaults={"name": classification.classification, "rank": TaxonRank.UNKNOWN},
- )
- if _created:
- created_objects.append(taxon)
+def create_category_map_for_classification(
+ classification_resp: ClassificationResponse,
+ logger: logging.Logger = logger,
+) -> AlgorithmCategoryMap:
+ """
+ Create a simple category map from a ClassificationResponse.
+ The complete category map should be created when registering the algorithm before processing images.
+
+ :param classification_resp: A ClassificationResponse object
+
+ :return: The AlgorithmCategoryMap object
+ """
+ labels = classification_resp.labels or list(map(str, range(len(classification_resp.scores))))
+ category_map_data = [
+ {
+ "label": label,
+ "index": i,
+ }
+ for i, label in enumerate(labels)
+ ]
+ logger.info(f"Creating placeholder category map with data: {category_map_data}")
+ category_map = AlgorithmCategoryMap.objects.create(
+ data=category_map_data,
+ version=classification_resp.timestamp.isoformat(),
+ description="Placeholder category map automatically created from classification data",
+ labels=labels,
+ )
+ return category_map
+
+
+def get_or_create_taxon_for_classification(
+ algorithm: Algorithm,
+ classification_resp: ClassificationResponse,
+ logger: logging.Logger = logger,
+) -> Taxon:
+ """
+ Create a Taxon object from a ClassificationResponse and add it to a TaxaList.
+
+ :param algorithm: The classification Algorithm whose category map is used to look up the predicted label
+ :param classification_resp: A ClassificationResponse object
+
+ :return: The Taxon object
+ """
+ taxa_list, _created = TaxaList.objects.get_or_create(
+ name=f"Taxa returned by {algorithm.name}",
+ )
+ if _created:
+ logger.info(f"Created new taxa list {taxa_list}")
+ else:
+ logger.debug(f"Using existing taxa list {taxa_list}")
+
+ # Get top label from classification scores
+ assert algorithm.category_map, f"No category map found for algorithm {algorithm}"
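+ # Look up the category map entry at the index of the highest score to get the predicted label's metadata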
+ label_data: dict = algorithm.category_map.data[classification_resp.scores.index(max(classification_resp.scores))]
+ taxon, _created = Taxon.objects.get_or_create(
+ name=classification_resp.classification,
+ defaults={
+ "name": classification_resp.classification,
+ "rank": label_data.get("taxon_rank", TaxonRank.UNKNOWN),
+ },
+ )
+ if _created:
+ logger.info(f"Registered new taxon {taxon}")
+
+ taxa_list.taxa.add(taxon)
+ return taxon
+
+
+def create_classification(
+ detection: Detection,
+ classification_resp: ClassificationResponse,
+ algorithms_used: dict[str, Algorithm],
+ save: bool = True,
+ logger: logging.Logger = logger,
+) -> tuple[Classification, bool]:
+ """
+ Create a Classification object from a ClassificationResponse, or update an existing one.
+
+ :param detection: A Detection object
+ :param classification_resp: A ClassificationResponse object
+ :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key
+ :param save: Whether to save the new Classification immediately (set to False when bulk creating)
+
+ :return: A tuple of the Classification object and a boolean indicating whether it was created
+ """
+ assert (
+ classification_resp.algorithm
+ ), f"No classification algorithm was specified for classification {classification_resp}"
+ logger.debug(f"Processing classification {classification_resp}")
+
+ try:
+ classification_algo = algorithms_used[classification_resp.algorithm.key]
+ except KeyError:
+ raise ValueError(
+ f"Classification algorithm {classification_resp.algorithm.key} is not a known algorithm. "
+ "The processing service must declare it in the /info endpoint. "
+ f"Known algorithms: {list(algorithms_used.keys())}"
+ )
+
+ if not classification_algo.category_map:
+ logger.warning(
+ f"Classification algorithm {classification_algo} "
+ "has no category map! "
+ "Creating one from data in the first classification if possible."
+ )
+ category_map = create_category_map_for_classification(classification_resp, logger=logger)
+ classification_algo.category_map = category_map
+ classification_algo.save()
+ classification_algo.refresh_from_db()
+
+ taxon = get_or_create_taxon_for_classification(
+ algorithm=classification_algo,
+ classification_resp=classification_resp,
+ logger=logger,
+ )
+
+ existing_classification = Classification.objects.filter(
+ detection=detection,
+ taxon=taxon,
+ algorithm=classification_algo,
+ score=max(classification_resp.scores),
+ ).first()
+
+ if existing_classification:
+ # @TODO remove this after all existing classifications have been updated (added 2024-12-20)
+ NEW_FIELDS = ["logits", "scores", "terminal", "category_map"]
+ logger.debug(
+ "Duplicate classification found: "
+ f"{existing_classification.taxon} from {existing_classification.algorithm}, "
+ f"not creating a new one, but updating new fields if they are None ({NEW_FIELDS})"
+ )
+ fields_to_update = []
+ for field in NEW_FIELDS:
+ # update new fields if they are None
+ if getattr(existing_classification, field) is None:
+ fields_to_update.append(field)
+ if fields_to_update:
+ logger.info(f"Updating fields {fields_to_update} for existing classification {existing_classification}")
+ for field in fields_to_update:
+ if field == "category_map":
+ # Use the foreign key from the classification algorithm
+ setattr(existing_classification, field, classification_algo.category_map)
+ else:
+ # Get the value from the classification response
+ setattr(existing_classification, field, getattr(classification_resp, field))
+ existing_classification.save(update_fields=fields_to_update)
+ logger.info(f"Updated existing classification {existing_classification}")
+
+ classification = existing_classification
+
+ else:
+ new_classification = Classification(
+ detection=detection,
+ taxon=taxon,
+ algorithm=classification_algo,
+ score=max(classification_resp.scores),
+ timestamp=classification_resp.timestamp or now(),
+ logits=classification_resp.logits,
+ scores=classification_resp.scores,
+ terminal=classification_resp.terminal,
+ category_map=classification_algo.category_map,
+ )
+ classification = new_classification
+
+ if save:
+ new_classification.save()
+ logger.debug(f"Created new classification {new_classification}")
+ else:
+ logger.debug(f"Initialized new classification {new_classification} (not saved)")
- taxa_list.taxa.add(taxon)
+ return classification, not existing_classification
- # @TODO this is asking for trouble
- # shouldn't we be able to get the detection from the classification?
- # also should filter by the correct detection algorithm
- # or do we use the bbox as a unique identifier?
- # then it doesn't matter what detection algorithm was used
- new_classification, created = Classification.objects.get_or_create(
+def create_classifications(
+ detections: list[Detection],
+ detection_responses: list[DetectionResponse],
+ algorithms_used: dict[str, Algorithm],
+ logger: logging.Logger = logger,
+ save: bool = True,
+) -> list[Classification]:
+ """
+ Efficiently create multiple Classification objects from a list of ClassificationResponse objects,
+ grouped by detection.
+
+ :param detections: A list of Detection objects, in the same order as detection_responses
+ :param detection_responses: A list of DetectionResponse objects containing the classifications to save
+ :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key
+
+ :return: A list of Classification objects
+
+ @TODO bulk create all classifications for all detections in request
+ """
+ existing_classifications: list[Classification] = []
+ new_classifications: list[Classification] = []
+
+ for detection, detection_resp in zip(detections, detection_responses):
+ for classification_resp in detection_resp.classifications:
+ classification, created = create_classification(
detection=detection,
- taxon=taxon,
- algorithm=classification_algo,
- score=max(classification.scores),
- defaults={"timestamp": classification.timestamp or now()},
+ classification_resp=classification_resp,
+ algorithms_used=algorithms_used,
+ save=False,
+ logger=logger,
)
-
if created:
- # Optionally add reference to job or pipeline here
- created_objects.append(new_classification)
+ new_classifications.append(classification)
else:
- # Optionally handle the case where a duplicate is found
- logger.warn("Duplicate classification found, not creating a new one.")
+ # @TODO consider adding logits, scores and terminal state to existing classifications (new fields)
+ existing_classifications.append(classification)
+
+ Classification.objects.bulk_create(new_classifications)
+ logger.info(
+ f"Created {len(new_classifications)} new classifications, updated {len(existing_classifications)} existing "
+ f"classifications for {len(detections)} detections."
+ )
+
+ return existing_classifications + new_classifications
- # Create a new occurrence for each detection (no tracking yet)
- # @TODO remove when we implement tracking
+
+def create_and_update_occurrences_for_detections(
+ detections: list[Detection],
+ logger: logging.Logger = logger,
+):
+ """
+ Create an Occurrence object for each Detection, and update the occurrence determination.
+
+ Select the best terminal classification for the occurrence determination.
+
+ :param detections: A list of Detection objects to create or update occurrences for
+ :param logger: A logger instance (e.g. a job logger) to use for progress messages
+ """
+
+ # Group detections by source image id so we don't create duplicate occurrences
+ detections_by_source_image = collections.defaultdict(list)
+ for detection in detections:
+ detections_by_source_image[detection.source_image_id].append(detection)
+
+ for source_image_id, detections in detections_by_source_image.items():
+ logger.info(f"Determining occurrences for {len(detections)} detections for source image {source_image_id}")
+
+ occurrences_to_create = []
+ detections_to_update = []
+
+ for detection in detections:
if not detection.occurrence:
- occurrence = Occurrence.objects.create(
- event=source_image.event,
- deployment=source_image.deployment,
- project=source_image.project,
- determination=taxon,
- determination_score=new_classification.score,
+ occurrence = Occurrence(
+ event=detection.source_image.event,
+ deployment=detection.source_image.deployment,
+ project=detection.source_image.project,
)
+ occurrences_to_create.append(occurrence)
+ logger.debug(f"Created new occurrence {occurrence} for detection {detection}")
detection.occurrence = occurrence # type: ignore
- detection.save()
- detection.occurrence.save()
+ detections_to_update.append(detection)
+
+ occurrences = Occurrence.objects.bulk_create(occurrences_to_create)
+ logger.info(f"Created {len(occurrences)} new occurrences")
+ Detection.objects.bulk_update(detections_to_update, ["occurrence"])
+ logger.info(f"Updated {len(detections_to_update)} detections with occurrences")
+
+ occurrences_to_update = []
+ occurrences_to_leave = []
+ for detection in detections:
+ assert detection.occurrence, f"No occurrence found for detection {detection}"
+ needs_update = update_occurrence_determination(
+ detection.occurrence,
+ current_determination=detection.occurrence.determination,
+ save=False,
+ )
+ if needs_update:
+ occurrences_to_update.append(detection.occurrence)
+ else:
+ occurrences_to_leave.append(detection.occurrence)
+
+ Occurrence.objects.bulk_update(occurrences_to_update, ["determination", "determination_score"])
+ logger.info(
+ f"Updated the determination of {len(occurrences_to_update)} occurrences, "
+ f"left {len(occurrences_to_leave)} unchanged"
+ )
+
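+ # Re-save the source image to refresh its pre-calculated counts now that detections and occurrences exist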
+ SourceImage.objects.get(pk=source_image_id).save()
+
+
+@dataclasses.dataclass
+class PipelineSaveResults:
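+ """Objects created or updated while saving a single pipeline results response."""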
+ pipeline: Pipeline
+ source_images: list[SourceImage]
+ detections: list[Detection]
+ classifications: list[Classification]
+ algorithms: dict[str, Algorithm]
+ total_time: float
+
+
+@celery_app.task(soft_time_limit=60 * 4, time_limit=60 * 5)
+def save_results(
+ results: PipelineResultsResponse | None = None,
+ results_json: str | None = None,
+ job_id: int | None = None,
+ return_created=False,
+) -> PipelineSaveResults | None:
+ """
+ Save results from ML pipeline API.
+
+ @TODO Continue improving bulk create. Group everything / all loops by source image.
+ """
+ job = None
+
+ if results_json:
+ results = PipelineResultsResponse.parse_raw(results_json)
+ assert results, "No results data passed to save_results task"
+
+ pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline})
+ if _created:
+ logger.warning(f"Pipeline choice returned by the Processing Service was not recognized! {pipeline}")
+ algorithms_used = set()
+
+ job_logger = logger
+ start_time = time.time()
+
+ if job_id:
+ from ami.jobs.models import Job
+
+ job = Job.objects.get(pk=job_id)
+ job_logger = job.logger
+
+ # @TODO set this level back to INFO
+ # job_logger.setLevel(logging.DEBUG)
+
+ if results_json:
+ results = PipelineResultsResponse.parse_raw(results_json)
+ assert results, "No results data passed to save_results task"
+ job_logger.info(f"Saving results from pipeline {results.pipeline}")
+
+ results = PipelineResultsResponse.parse_obj(results.dict())
+ assert results, "No results from pipeline to save"
+ source_images = SourceImage.objects.filter(pk__in=[int(img.id) for img in results.source_images]).distinct()
+
+ pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline})
+ if _created:
+ job_logger.warning(
+ f"The pipeline returned by the ML backend was not recognized, created a placeholder: {pipeline}"
+ )
+
+ # Algorithms and category maps should be created in advance when registering the pipeline & processing service
+ # however they are also currently available in each pipeline results response as well.
+ # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of the results
+ algorithms_used = {
+ algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger)
+ for algo_key, algo_config in results.algorithms.items()
+ }
+
+ detections = create_detections(
+ detections=results.detections,
+ algorithms_used=algorithms_used,
+ logger=job_logger,
+ )
+
+ classifications = create_classifications(
+ detections=detections,
+ detection_responses=results.detections,
+ algorithms_used=algorithms_used,
+ logger=job_logger,
+ )
+
+ # Create a new occurrence for each detection (no tracking yet)
+ # @TODO remove when we implement tracking!
+ create_and_update_occurrences_for_detections(
+ detections=detections,
+ logger=job_logger,
+ )
# Update precalculated counts on source images and events
- with transaction.atomic():
- for source_image in source_images:
- source_image.save()
+ source_images = list(source_images)
+ logger.info(f"Updating calculated fields for {len(source_images)} source images")
+ for source_image in source_images:
+ source_image.save()
image_cropping_task = create_detection_images.delay(
source_image_ids=[source_image.pk for source_image in source_images],
)
- if job:
- job.logger.info(f"Creating detection images in sub-task {image_cropping_task.id}")
+ job_logger.info(f"Creating detection images in sub-task {image_cropping_task.id}")
- event_ids = [img.event_id for img in source_images]
+ event_ids = [img.event_id for img in source_images] # type: ignore
update_calculated_fields_for_events(pks=event_ids)
registered_algos = pipeline.algorithms.all()
- for algo in algorithms_used:
+ for algo in algorithms_used.values():
# This is important for tracking what objects were processed by which algorithms
# to avoid reprocessing, and for tracking provenance.
if algo not in registered_algos:
pipeline.algorithms.add(algo)
- logger.warning(f"Added unregistered algorithm {algo} to pipeline {pipeline}")
+ job_logger.debug(f"Added algorithm {algo} to pipeline {pipeline}")
- if job:
- if len(created_objects):
- job.logger.info(f"Created {len(created_objects)} objects")
- try:
- previously_created = int(job.progress.get_stage_param("results", "objects_created").value)
- job.progress.update_stage(
- "results",
- objects_created=previously_created + len(created_objects),
- )
- except ValueError:
- pass
- else:
- job.update_progress()
+ total_time = time.time() - start_time
+ job_logger.info(f"Saved results from pipeline {pipeline} in {total_time:.2f} seconds")
+
+ # By default, return None because celery tasks need special handling to return objects.
+ if return_created:
+ return PipelineSaveResults(
+ pipeline=pipeline,
+ source_images=source_images,
+ detections=detections,
+ classifications=classifications,
+ algorithms=algorithms_used,
+ total_time=total_time,
+ )
class PipelineStage(ConfigurableStage):
@@ -501,22 +953,24 @@ def process_images(self, images: typing.Iterable[SourceImage], job_id: int | Non
)
return process_images(
- endpoint_url=urljoin(processing_service.endpoint_url, "/process_images"),
+ endpoint_url=urljoin(processing_service.endpoint_url, "/process"),
pipeline=self,
images=images,
job_id=job_id,
)
- def save_results(self, results: PipelineResponse, job_id: int | None = None):
+ def save_results(self, results: PipelineResultsResponse, job_id: int | None = None):
return save_results(results=results, job_id=job_id)
- def save_results_async(self, results: PipelineResponse, job_id: int | None = None):
+ def save_results_async(self, results: PipelineResultsResponse, job_id: int | None = None):
# Returns an AsyncResult
results_json = results.json()
return save_results.delay(results_json=results_json, job_id=job_id)
def save(self, *args, **kwargs):
if not self.slug:
- # @TODO slug may only need to be unique per project
- self.slug = slugify(self.name)
+ # @TODO find a better way to generate unique identifiers
+ # consider hashing the pipeline config or using a UUID -- but both sides need to agree on the same UUID.
+ unique_suffix = str(uuid.uuid4())[:8]
+ self.slug = f"{slugify(self.name)}-v{self.version}-{unique_suffix}"
return super().save(*args, **kwargs)
diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py
index fb34c271c..3d66143d2 100644
--- a/ami/ml/models/processing_service.py
+++ b/ami/ml/models/processing_service.py
@@ -8,9 +8,8 @@
from django.db import models
from ami.base.models import BaseModel
-from ami.ml.models.algorithm import Algorithm
-from ami.ml.models.pipeline import Pipeline
-from ami.ml.schemas import PipelineRegistrationResponse, ProcessingServiceStatusResponse
+from ami.ml.models.pipeline import Pipeline, get_or_create_algorithm_and_category_map
+from ami.ml.schemas import PipelineRegistrationResponse, ProcessingServiceInfoResponse, ProcessingServiceStatusResponse
logger = logging.getLogger(__name__)
@@ -38,8 +37,6 @@ class Meta:
def create_pipelines(self):
# Call the status endpoint and get the pipelines/algorithms
resp = self.get_status()
- if resp.error:
- resp.raise_for_status()
pipelines_to_add = resp.pipeline_configs
pipelines = []
@@ -47,12 +44,19 @@ def create_pipelines(self):
algorithms_created = []
for pipeline_data in pipelines_to_add:
- pipeline, created = Pipeline.objects.get_or_create(
- name=pipeline_data.name,
- slug=pipeline_data.slug,
- version=pipeline_data.version,
- description=pipeline_data.description or "",
- )
+ pipeline = Pipeline.objects.filter(
+ models.Q(slug=pipeline_data.slug) | models.Q(name=pipeline_data.name, version=pipeline_data.version)
+ ).first()
+ created = False
+ if not pipeline:
+ pipeline = Pipeline.objects.create(
+ slug=pipeline_data.slug,
+ name=pipeline_data.name,
+ version=pipeline_data.version,
+ description=pipeline_data.description or "",
+ )
+ created = True
+
pipeline.projects.add(*self.projects.all())
self.pipelines.add(pipeline)
@@ -62,13 +66,13 @@ def create_pipelines(self):
else:
logger.info(f"Using existing pipeline {pipeline.name}.")
+ existing_algorithms = pipeline.algorithms.all()
for algorithm_data in pipeline_data.algorithms:
- algorithm, created = Algorithm.objects.get_or_create(name=algorithm_data.name, key=algorithm_data.key)
- pipeline.algorithms.add(algorithm)
-
- if created:
- logger.info(f"Successfully created algorithm {algorithm.name}.")
- algorithms_created.append(algorithm.name)
+ algorithm = get_or_create_algorithm_and_category_map(algorithm_data, logger=logger)
+ if algorithm not in existing_algorithms:
+ logger.info(f"Registered new algorithm {algorithm.name} to pipeline {pipeline.name}.")
+ pipeline.algorithms.add(algorithm)
+ algorithms_created.append(algorithm.key)
else:
logger.info(f"Using existing algorithm {algorithm.name}.")
@@ -94,28 +98,33 @@ def get_status(self):
resp = requests.get(info_url)
resp.raise_for_status()
except requests.exceptions.RequestException as e:
+ latency = time.time() - start_time
self.last_checked_live = False
+ self.last_checked_latency = latency
self.save()
error = f"Error connecting to {info_url}: {e}"
logger.error(error)
- first_response_time = time.time()
- latency = first_response_time - start_time
-
return ProcessingServiceStatusResponse(
+ error=error,
timestamp=timestamp,
request_successful=False,
+ server_live=False,
+ pipelines_online=[],
+ pipeline_configs=[],
endpoint_url=self.endpoint_url,
- error=error,
latency=latency,
)
- pipeline_configs = resp.json()
- server_live = requests.get(urljoin(self.endpoint_url, "livez")).json().get("status")
- pipelines_online = requests.get(urljoin(self.endpoint_url, "readyz")).json().get("status")
+ info_data = ProcessingServiceInfoResponse.parse_obj(resp.json())
+ pipeline_configs = info_data.pipelines
+
+ # @TODO these are likely extra requests that could be avoided
+ # @TODO add schemas for these if we keep them
+ server_live: bool = requests.get(urljoin(self.endpoint_url, "livez")).json().get("status", False)
+ pipelines_online: list[str] = requests.get(urljoin(self.endpoint_url, "readyz")).json().get("status", [])
- first_response_time = time.time()
- latency = first_response_time - start_time
+ latency = time.time() - start_time
self.last_checked_live = server_live
self.last_checked_latency = latency
self.save()
diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py
index e700e8076..97e0ca480 100644
--- a/ami/ml/schemas.py
+++ b/ami/ml/schemas.py
@@ -1,8 +1,12 @@
import datetime
+import logging
import typing
import pydantic
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
class BoundingBox(pydantic.BaseModel):
x1: float
@@ -11,25 +15,108 @@ class BoundingBox(pydantic.BaseModel):
y2: float
@classmethod
- def from_coords(cls, coords: list[float]):
+ def from_coords(cls, coords: list[float]) -> "BoundingBox":
return cls(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3])
+ def to_string(self) -> str:
+ return f"{self.x1},{self.y1},{self.x2},{self.y2}"
+
+ def to_path(self) -> str:
+ return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]])
+
+ def to_tuple(self):
+ return (self.x1, self.y1, self.x2, self.y2)
+
+
+class AlgorithmReference(pydantic.BaseModel):
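+ """A lightweight reference to an algorithm by name and key, as returned with each detection and classification."""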
+ name: str
+ key: str
+
+
+class AlgorithmCategoryMapResponse(pydantic.BaseModel):
+ data: list[dict] = pydantic.Field(
+ default_factory=list,
+ description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.",
+ examples=[
+ [
+ {"label": "Moth", "index": 0, "gbif_key": 1234},
+ {"label": "Not a moth", "index": 1, "gbif_key": 5678},
+ ]
+ ],
+ )
+ labels: list[str] = pydantic.Field(
+ default_factory=list,
+ description="A simple list of string labels, in the correct index order used by the model.",
+ examples=[["Moth", "Not a moth"]],
+ )
+ version: str | None = pydantic.Field(
+ default=None,
+ description="The version of the category map. Can be a descriptive string or a version number.",
+ examples=["LepNet2021-with-2023-mods"],
+ )
+ description: str | None = pydantic.Field(
+ default=None,
+ description="A description of the category map used to train. e.g. source, purpose and modifications.",
+ examples=["LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."],
+ )
+ uri: str | None = pydantic.Field(
+ default=None,
+ description="A URI to the category map file, could be a public web URL or object store path.",
+ )
+
+
+class AlgorithmConfigResponse(pydantic.BaseModel):
+ name: str
+ key: str = pydantic.Field(
+ description=("A unique key for an algorithm to lookup the category map (class list) and other metadata."),
+ )
+ description: str | None = None
+ task_type: str | None = pydantic.Field(
+ default=None,
+ description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.",
+ examples=["detection", "classification", "segmentation", "embedding"],
+ )
+ version: int = pydantic.Field(
+ default=1,
+ description="A sortable version number for the model. Increment this number when the model is updated.",
+ )
+ version_name: str | None = pydantic.Field(
+ default=None,
+ description="A complete version name e.g. '2021-01-01', 'LepNet2021'.",
+ )
+ uri: str | None = pydantic.Field(
+ default=None,
+ description="A URI to the weight or model details, could be a public web URL or object store path.",
+ )
+ category_map: AlgorithmCategoryMapResponse | None = None
+
+ class Config:
+ extra = "ignore"
+
class ClassificationResponse(pydantic.BaseModel):
classification: str
- labels: list[str] = []
+ labels: list[str] | None = pydantic.Field(
+ default=None,
+ description=(
+ "A list of all possible labels for the model, in the correct order. "
+ "Omitted if the model has too many labels to include for each classification in the response. "
+ "Use the category map from the algorithm to get the full list of labels and metadata."
+ ),
+ )
scores: list[float] = []
+ logits: list[float] | None = None
inference_time: float | None = None
- algorithm: str | None = None
- timestamp: datetime.datetime
+ algorithm: AlgorithmReference
terminal: bool = True
+ timestamp: datetime.datetime
class DetectionResponse(pydantic.BaseModel):
source_image_id: str
bbox: BoundingBox
inference_time: float | None = None
- algorithm: str | None = None
+ algorithm: AlgorithmReference
timestamp: datetime.datetime
crop_image_url: str | None = None
classifications: list[ClassificationResponse] = []
@@ -46,6 +133,9 @@ class SourceImageResponse(pydantic.BaseModel):
id: str
url: str
+ class Config:
+ extra = "ignore"
+
KnownPipelineChoices = typing.Literal[
"panama_moths_2023",
@@ -59,10 +149,15 @@ class PipelineRequest(pydantic.BaseModel):
source_images: list[SourceImageRequest]
-class PipelineResponse(pydantic.BaseModel):
+class PipelineResultsResponse(pydantic.BaseModel):
# pipeline: PipelineChoice
pipeline: str
- total_time: float | None
+ algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field(
+ default_factory=dict,
+ description="A dictionary of all algorithms used in the pipeline, including their class list and other "
+ "metadata, keyed by the algorithm key.",
+ )
+ total_time: float
source_images: list[SourceImageResponse]
detections: list[DetectionResponse]
errors: list | str | None = None
@@ -85,11 +180,6 @@ class PipelineStage(pydantic.BaseModel):
description: str | None = None
-class AlgorithmConfig(pydantic.BaseModel):
- name: str
- key: str
-
-
class PipelineConfig(pydantic.BaseModel):
"""A configurable pipeline."""
@@ -97,17 +187,32 @@ class PipelineConfig(pydantic.BaseModel):
slug: str
version: int
description: str | None = None
- algorithms: list[AlgorithmConfig] = []
+ algorithms: list[AlgorithmConfigResponse] = []
stages: list[PipelineStage] = []
+class ProcessingServiceInfoResponse(pydantic.BaseModel):
+ """
+ Information about the processing service returned from the Processing Service backend.
+ """
+
+ name: str
+ description: str | None = None
+ pipelines: list[PipelineConfig] = []
+ algorithms: list[AlgorithmConfigResponse] = []
+
+
class ProcessingServiceStatusResponse(pydantic.BaseModel):
+ """
+ Status response returned by the Antenna API about the Processing Service.
+ """
+
timestamp: datetime.datetime
request_successful: bool
pipeline_configs: list[PipelineConfig] = []
error: str | None = None
server_live: bool | None = None
- pipelines_online: list[str] | str = "pipelines unavailable"
+ pipelines_online: list[str] = []
endpoint_url: str
latency: float
diff --git a/ami/ml/serializers.py b/ami/ml/serializers.py
index 24e08a064..1611b3ea4 100644
--- a/ami/ml/serializers.py
+++ b/ami/ml/serializers.py
@@ -4,11 +4,26 @@
from ami.main.api.serializers import DefaultSerializer
from ami.main.models import Project
-from .models.algorithm import Algorithm
+from .models.algorithm import Algorithm, AlgorithmCategoryMap
from .models.pipeline import Pipeline, PipelineStage
from .models.processing_service import ProcessingService
+class AlgorithmCategoryMapSerializer(DefaultSerializer):
+ class Meta:
+ model = AlgorithmCategoryMap
+ fields = [
+ "id",
+ "labels",
+ "data",
+ "algorithms",
+ "version",
+ "uri",
+ "created_at",
+ "updated_at",
+ ]
+
+
class AlgorithmSerializer(DefaultSerializer):
class Meta:
model = Algorithm
@@ -18,9 +33,11 @@ class Meta:
"name",
"key",
"description",
- "url",
+ "uri",
"version",
"version_name",
+ "task_type",
+ "category_map",
"created_at",
"updated_at",
]
diff --git a/ami/ml/tasks.py b/ami/ml/tasks.py
index a38847c9d..972233251 100644
--- a/ami/ml/tasks.py
+++ b/ami/ml/tasks.py
@@ -1,4 +1,5 @@
import logging
+import time
from ami.ml.media import create_detection_images_from_source_image
from ami.tasks import default_soft_time_limit, default_time_limit
@@ -42,6 +43,8 @@ def process_source_images_async(pipeline_choice: str, endpoint_url: str, image_i
def create_detection_images(source_image_ids: list[int]):
from ami.main.models import SourceImage
+ start_time = time.time()
+
logger.debug(f"Creating detection images for {len(source_image_ids)} capture(s)")
for source_image in SourceImage.objects.filter(pk__in=source_image_ids):
@@ -51,6 +54,9 @@ def create_detection_images(source_image_ids: list[int]):
except Exception as e:
logger.error(f"Error creating detection images for SourceImage {source_image.pk}: {str(e)}")
+ total_time = time.time() - start_time
+ logger.info(f"Created detection images for {len(source_image_ids)} capture(s) in {total_time:.2f} seconds")
+
@celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit)
def remove_duplicate_classifications(project_id: int | None = None, dry_run: bool = False) -> int:
diff --git a/ami/ml/tests.py b/ami/ml/tests.py
index b2dc5fb70..2fbd5e723 100644
--- a/ami/ml/tests.py
+++ b/ami/ml/tests.py
@@ -1,21 +1,24 @@
import datetime
+import unittest
from django.test import TestCase
from rest_framework.test import APIRequestFactory, APITestCase
-from rich import print
from ami.base.serializers import reverse_with_params
from ami.main.models import Classification, Detection, Project, SourceImage, SourceImageCollection
from ami.ml.models import Algorithm, Pipeline, ProcessingService
-from ami.ml.models.pipeline import collect_images, save_results
+from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results
from ami.ml.schemas import (
+ AlgorithmConfigResponse,
+ AlgorithmReference,
BoundingBox,
ClassificationResponse,
DetectionResponse,
- PipelineResponse,
+ PipelineResultsResponse,
SourceImageResponse,
)
from ami.tests.fixtures.main import create_captures_from_files, create_processing_service, setup_test_project
+from ami.tests.fixtures.ml import ALGORITHM_CHOICES
from ami.users.models import User
@@ -104,7 +107,8 @@ def setUp(self):
self.test_images = [image for image, frame in self.captures]
self.processing_service_instance = create_processing_service(self.project)
self.processing_service = self.processing_service_instance
- self.pipeline = self.processing_service_instance.pipelines.all().filter(slug="constant").first()
+ assert self.processing_service_instance.pipelines.exists()
+ self.pipeline = self.processing_service_instance.pipelines.all().get(slug="constant")
def test_run_pipeline(self):
# Send images to Processing Service to process and return detections
@@ -112,6 +116,72 @@ def test_run_pipeline(self):
pipeline_response = self.pipeline.process_images(self.test_images, job_id=None)
assert pipeline_response.detections
+ def test_created_category_maps(self):
+ # Send images to ML backend to process and return detections
+ assert self.pipeline
+ pipeline_response = self.pipeline.process_images(self.test_images)
+ save_results(pipeline_response, return_created=True)
+
+ source_images = SourceImage.objects.filter(pk__in=[image.id for image in pipeline_response.source_images])
+ detections = Detection.objects.filter(source_image__in=source_images).select_related(
+ "detection_algorithm",
+ "detection_algorithm__category_map",
+ )
+ assert detections.count() > 0
+ for detection in detections:
+ # No detection algorithm should have a category map at this time (but this may change!)
+ assert detection.detection_algorithm
+ assert detection.detection_algorithm.category_map is None
+
+ # Ensure that all classification algorithms have a category map
+ classification_taxa = set()
+ for classification in detection.classifications.all().select_related(
+ "algorithm",
+ "algorithm__category_map",
+ ):
+ assert classification.algorithm is not None
+ assert classification.category_map is not None
+ assert classification.algorithm.category_map == classification.category_map
+
+ _, top_score = list(classification.predictions(sort=True))[0]
+ assert top_score == classification.score
+
+ top_taxon, top_taxon_score = list(classification.predictions_with_taxa(sort=True))[0]
+ assert top_taxon == classification.taxon
+ assert top_taxon_score == classification.score
+
+ classification_taxa.add(top_taxon)
+
+ # Check the occurrence determination taxon
+ assert detection.occurrence
+ assert detection.occurrence.determination in classification_taxa
+
+ def test_alignment_of_predictions_and_category_map(self):
+ # Ensure that the scores and labels are aligned
+ pipeline = self.processing_service_instance.pipelines.all().get(slug="random")
+ pipeline_response = pipeline.process_images(self.test_images)
+ results = save_results(pipeline_response, return_created=True)
+ assert results is not None, "Expected results to be returned in a PipelineSaveResults object"
+ assert results.classifications, "Expected classifications to be returned in the results"
+ for classification in results.classifications:
+ assert classification.scores
+ taxa_with_scores = list(classification.predictions_with_taxa(sort=True))
+ assert taxa_with_scores
+ assert classification.score == taxa_with_scores[0][1]
+ assert classification.taxon == taxa_with_scores[0][0]
+
+ def test_top_n_alignment(self):
+ # Ensure that the top_n parameter works
+ pipeline = self.processing_service_instance.pipelines.all().get(slug="random")
+ pipeline_response = pipeline.process_images(self.test_images)
+ results = save_results(pipeline_response, return_created=True)
+ assert results is not None, "Expected results to be returned in a PipelineSaveResults object"
+ assert results.classifications, "Expected classifications to be returned in the results"
+ for classification in results.classifications:
+ top_n = classification.top_n(n=3)
+ assert classification.score == top_n[0]["score"]
+ assert classification.taxon == top_n[0]["taxon"]
+
class TestPipeline(TestCase):
def setUp(self):
@@ -131,39 +201,72 @@ def setUp(self):
self.pipeline = Pipeline.objects.create(
name="Test Pipeline",
)
+
self.algorithms = {
- "detector": Algorithm.objects.create(name="Test Object Detector"),
- "binary_classifier": Algorithm.objects.create(name="Test Filter"),
- "species_classifier": Algorithm.objects.create(name="Test Classifier"),
+ key: get_or_create_algorithm_and_category_map(val) for key, val in ALGORITHM_CHOICES.items()
}
- self.pipeline.algorithms.set(self.algorithms.values())
+ self.pipeline.algorithms.set(self.algorithms.values())
def test_create_pipeline(self):
- self.assertEqual(self.pipeline.slug, "test-pipeline")
- self.assertEqual(self.pipeline.algorithms.count(), 3)
+ assert self.pipeline.slug.startswith("test-pipeline")
+ self.assertEqual(self.pipeline.algorithms.count(), len(ALGORITHM_CHOICES))
for algorithm in self.pipeline.algorithms.all():
- self.assertIn(algorithm.key, ["test-object-detector", "test-filter", "test-classifier"])
+ assert isinstance(algorithm, Algorithm)
+ self.assertIn(algorithm.key, [algo.key for algo in ALGORITHM_CHOICES.values()])
def test_collect_images(self):
images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
assert len(images) == 2
- def fake_pipeline_results(self, source_images: list[SourceImage], pipeline: Pipeline):
+ def fake_pipeline_results(
+ self,
+ source_images: list[SourceImage],
+ pipeline: Pipeline,
+ alt_species_classifier: AlgorithmConfigResponse | None = None,
+ ):
+ # @TODO use the pipeline passed in to get the algorithms
source_image_results = [SourceImageResponse(id=image.pk, url=image.path) for image in source_images]
+ detector = ALGORITHM_CHOICES["random-detector"]
+ binary_classifier = ALGORITHM_CHOICES["random-binary-classifier"]
+ assert binary_classifier.category_map
+
+ if alt_species_classifier is None:
+ species_classifier = ALGORITHM_CHOICES["random-species-classifier"]
+ else:
+ species_classifier = alt_species_classifier
+ assert species_classifier.category_map
+
detection_results = [
DetectionResponse(
source_image_id=image.pk,
bbox=BoundingBox(x1=0.0, y1=0.0, x2=1.0, y2=1.0),
inference_time=0.4,
- algorithm=self.algorithms["detector"].name,
+ algorithm=AlgorithmReference(
+ name=detector.name,
+ key=detector.key,
+ ),
timestamp=datetime.datetime.now(),
classifications=[
ClassificationResponse(
- classification="Test taxon",
- labels=["Test taxon"],
+ classification=binary_classifier.category_map.labels[0],
+ labels=None,
+ scores=[0.9213],
+ algorithm=AlgorithmReference(
+ name=binary_classifier.name,
+ key=binary_classifier.key,
+ ),
+ timestamp=datetime.datetime.now(),
+ terminal=False,
+ ),
+ ClassificationResponse(
+ classification=species_classifier.category_map.labels[0],
+ labels=None,
scores=[0.64333],
- algorithm=self.algorithms["species_classifier"].name,
+ algorithm=AlgorithmReference(
+ name=species_classifier.name,
+ key=species_classifier.key,
+ ),
timestamp=datetime.datetime.now(),
terminal=True,
),
@@ -171,22 +274,27 @@ def fake_pipeline_results(self, source_images: list[SourceImage], pipeline: Pipe
)
for image in self.test_images
]
- fake_results = PipelineResponse(
- pipeline=self.pipeline.slug,
- total_time=0.0,
+ fake_results = PipelineResultsResponse(
+ pipeline=pipeline.slug,
+ algorithms={
+ detector.key: detector,
+ binary_classifier.key: binary_classifier,
+ species_classifier.key: species_classifier,
+ },
+ total_time=0.01,
source_images=source_image_results,
detections=detection_results,
)
return fake_results
def test_save_results(self):
- saved_objects = save_results(self.fake_pipeline_results(self.test_images, self.pipeline))
+ results = self.fake_pipeline_results(self.test_images, self.pipeline)
+ save_results(results)
for image in self.test_images:
image.save()
self.assertEqual(image.detections_count, 1)
- print(saved_objects)
# @TODO test the cached counts for detections, etc are updated on Events, Deployments, etc.
def no_test_skip_existing_results(self):
@@ -196,7 +304,7 @@ def no_test_skip_existing_results(self):
total_images = len(images)
self.assertEqual(total_images, self.image_collection.images.count())
- save_results(self.fake_pipeline_results(images, self.pipeline))
+ save_results(self.fake_pipeline_results(images, self.pipeline), return_created=True)
images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
@@ -221,17 +329,22 @@ def test_skip_existing_with_new_detector(self):
images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
total_images = len(images)
self.assertEqual(total_images, self.image_collection.images.count())
- save_results(self.fake_pipeline_results(images, self.pipeline))
+ pipeline_response = self.fake_pipeline_results(images, self.pipeline)
+ save_results(pipeline_response)
+ # Find the first algorithm used where task_type is classification
+ classifiers = [algo for algo in pipeline_response.algorithms.values() if algo.task_type == "classification"]
+ last_classifier = Algorithm.objects.get(key=classifiers[-1].key)
self.pipeline.algorithms.set(
[
Algorithm.objects.create(name="NEW Object Detector 2.0"),
- self.algorithms["species_classifier"], # Same classifier
+ last_classifier,
]
)
images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
remaining_images_to_process = len(images_again)
self.assertEqual(remaining_images_to_process, total_images)
+ @unittest.skip("Not implemented yet")
def test_skip_existing_with_new_classifier(self):
"""
@TODO add support for skipping the detection model if only the classifier has changed.
@@ -239,10 +352,16 @@ def test_skip_existing_with_new_classifier(self):
images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
total_images = len(images)
self.assertEqual(total_images, self.image_collection.images.count())
- save_results(self.fake_pipeline_results(images, self.pipeline))
+ pipeline_response = self.fake_pipeline_results(images, self.pipeline)
+ # Find the first algorithm used where task_type is detection
+ first_detector_in_response = next(
+ algo for algo in pipeline_response.algorithms.values() if algo.task_type == "detection"
+ )
+ first_detector = Algorithm.objects.get(key=first_detector_in_response.key)
+ save_results(pipeline_response)
self.pipeline.algorithms.set(
[
- self.algorithms["detector"], # Same object detector
+ first_detector,
Algorithm.objects.create(name="NEW Classifier 2.0"),
]
)
@@ -259,14 +378,21 @@ def _test_skip_existing_per_batch_during_processing(self):
def test_unknown_algorithm_returned_by_processing_service(self):
fake_results = self.fake_pipeline_results(self.test_images, self.pipeline)
- new_detector_name = "Unknown Detector 5.1b-mobile"
- new_classifier_name = "Unknown Classifier 3.0b-mega"
+ new_detector = AlgorithmConfigResponse(
+ name="Unknown Detector 5.1b-mobile", key="unknown-detector", task_type="detection"
+ )
+ new_classifier = AlgorithmConfigResponse(
+ name="Unknown Classifier 3.0b-mega", key="unknown-classifier", task_type="classification"
+ )
+
+ fake_results.algorithms[new_detector.key] = new_detector
+ fake_results.algorithms[new_classifier.key] = new_classifier
for detection in fake_results.detections:
- detection.algorithm = new_detector_name
+ detection.algorithm = AlgorithmReference(name=new_detector.name, key=new_detector.key)
for classification in detection.classifications:
- classification.algorithm = new_classifier_name
+ classification.algorithm = AlgorithmReference(name=new_classifier.name, key=new_classifier.key)
current_total_algorithm_count = Algorithm.objects.count()
@@ -278,39 +404,50 @@ def test_unknown_algorithm_returned_by_processing_service(self):
self.assertEqual(new_algorithm_count, current_total_algorithm_count + 2)
# Ensure new algorithms were also added to the pipeline
- self.assertTrue(self.pipeline.algorithms.filter(name=new_detector_name).exists())
- self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier_name).exists())
+ self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name, key=new_detector.key).exists())
+ self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name, key=new_classifier.key).exists())
- def no_test_reprocessing_after_unknown_algorithm_added(self):
+ @unittest.skip("Not implemented yet")
+ def test_reprocessing_after_unknown_algorithm_added(self):
# @TODO fix issue with "None" algorithm on some detections
images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
- saved_objects = save_results(self.fake_pipeline_results(images, self.pipeline))
+ save_results(self.fake_pipeline_results(images, self.pipeline))
+
+ new_detector = AlgorithmConfigResponse(
+ name="Unknown Detector 5.1b-mobile", key="unknown-detector", task_type="detection"
+ )
+ new_classifier = AlgorithmConfigResponse(
+ name="Unknown Classifier 3.0b-mega", key="unknown-classifier", task_type="classification"
+ )
- new_detector_name = "Unknown Detector 5.1b-mobile"
- new_classifier_name = "Unknown Classifier 3.0b-mega"
fake_results = self.fake_pipeline_results(images, self.pipeline)
+
# Change the algorithm names to unknown ones
for detection in fake_results.detections:
- detection.algorithm = new_detector_name
+ detection.algorithm = AlgorithmReference(name=new_detector.name, key=new_detector.key)
for classification in detection.classifications:
- classification.algorithm = new_classifier_name
+ classification.algorithm = AlgorithmReference(name=new_classifier.name, key=new_classifier.key)
+
+ fake_results.algorithms[new_detector.key] = new_detector
+ fake_results.algorithms[new_classifier.key] = new_classifier
# print("FAKE RESULTS")
# print(fake_results)
# print("END FAKE RESULTS")
- saved_objects = save_results(fake_results) or []
- saved_detections = [obj for obj in saved_objects if isinstance(obj, Detection)]
- saved_classifications = [obj for obj in saved_objects if isinstance(obj, Classification)]
+ saved_objects = save_results(fake_results, return_created=True)
+ assert saved_objects is not None
+ saved_detections = saved_objects.detections
+ saved_classifications = saved_objects.classifications
for obj in saved_detections:
assert obj.detection_algorithm # For type checker, not the test
# Ensure the new detector was used for the detection
- self.assertEqual(obj.detection_algorithm.name, new_detector_name)
+ self.assertEqual(obj.detection_algorithm.name, new_detector.name)
# Ensure each detection has classification objects
self.assertTrue(obj.classifications.exists())
@@ -323,21 +460,21 @@ def no_test_reprocessing_after_unknown_algorithm_added(self):
assert obj.algorithm # For type checker, not the test
# Ensure the new classifier was used for the classification
- self.assertEqual(obj.algorithm.name, new_classifier_name)
+ self.assertEqual(obj.algorithm.name, new_classifier.name)
# Ensure each classification has the correct detection object
self.assertIn(obj.detection, saved_detections, "Wrong detection object for classification object.")
# Ensure new algorithms were added to the pipeline
- self.assertTrue(self.pipeline.algorithms.filter(name=new_detector_name).exists())
- self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier_name).exists())
+ self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name).exists())
+ self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name).exists())
detection_algos_used = Detection.objects.all().values_list("detection_algorithm__name", flat=True).distinct()
- self.assertTrue(new_detector_name in detection_algos_used)
+ self.assertTrue(new_detector.name in detection_algos_used)
# Ensure None is not in the list
self.assertFalse(None in detection_algos_used)
classification_algos_used = Classification.objects.all().values_list("algorithm__name", flat=True)
- self.assertTrue(new_classifier_name in classification_algos_used)
+ self.assertTrue(new_classifier.name in classification_algos_used)
# Ensure None is not in the list
self.assertFalse(None in classification_algos_used)
@@ -345,3 +482,93 @@ def no_test_reprocessing_after_unknown_algorithm_added(self):
images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
remaining_images_to_process = len(images_again)
self.assertEqual(remaining_images_to_process, 0)
+
+ def test_yes_reprocess_if_new_terminal_algorithm_same_intermediate(self):
+ """
+ Test two pipelines with the same detector and same moth/non-moth classifier, but a new species classifier.
+
+ The first pipeline should process the images and save the results.
+ The second pipeline should reprocess the images.
+ """
+
+ images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline))
+ assert len(images), "No images to process"
+
+ detector = Algorithm.objects.get(key="random-detector")
+ binary_classifier = Algorithm.objects.get(key="random-binary-classifier")
+ old_species_classifier = Algorithm.objects.get(key="random-species-classifier")
+
+ Detection.objects.all().delete()
+ results = save_results(self.fake_pipeline_results(images, self.pipeline), return_created=True)
+ assert results is not None, "Expected results to be returned in a PipelineSaveResults object"
+
+ for raw_detection in results.detections:
+ self.assertEqual(raw_detection.detection_algorithm, detector)
+
+ # Ensure all results have the binary classifier and the old species classifier
+ for saved_detection in Detection.objects.all():
+ self.assertEqual(saved_detection.detection_algorithm, detector)
+ # Assert that the binary classifier was used
+ self.assertTrue(
+ saved_detection.classifications.filter(algorithm=binary_classifier).exists(),
+ "Binary classifier not used in first run",
+ )
+ # Assert that the old species classifier was used
+ self.assertTrue(
+ saved_detection.classifications.filter(algorithm=old_species_classifier).exists(),
+ "Old species classifier not used in first run",
+ )
+
+ # Get another species classifier
+ new_species_classifier_key = "constant-species-classifier"
+ new_species_classifier = Algorithm.objects.get(key=new_species_classifier_key)
+ # new_species_classifier_response = ALGORITHM_CHOICES[new_species_classifier_key]
+
+ # Create a new pipeline with the same detector and the new species classifier
+ new_pipeline = Pipeline.objects.create(
+ name="New Pipeline",
+ )
+
+ new_pipeline.algorithms.set(
+ [
+ detector,
+ binary_classifier,
+ new_species_classifier,
+ ]
+ )
+
+ # Process the images with the new pipeline
+ images_again = list(collect_images(collection=self.image_collection, pipeline=new_pipeline))
+ remaining_images_to_process = len(images_again)
+ self.assertEqual(remaining_images_to_process, len(images), "Images not re-processed with new pipeline")
+
+
+class TestAlgorithmCategoryMaps(TestCase):
+ def setUp(self):
+ self.algorithm_responses = {
+ key: get_or_create_algorithm_and_category_map(val) for key, val in ALGORITHM_CHOICES.items()
+ }
+ self.algorithms = {key: Algorithm.objects.get(key=key) for key in ALGORITHM_CHOICES.keys()}
+
+ def test_create_algorithms_and_category_map(self):
+ assert len(self.algorithms) > 0
+ assert (
+ Algorithm.objects.filter(
+ key__in=self.algorithms.keys(),
+ )
+ .exclude(category_map=None)
+ .count()
+ ) > 0
+
+ def test_algorithm_category_maps(self):
+ for algorithm in Algorithm.objects.filter(
+ key__in=self.algorithms.keys(),
+ ).exclude(category_map=None):
+ assert algorithm.category_map # For type checker, not the test
+ assert algorithm.category_map.labels
+ assert algorithm.category_map.labels_hash
+ assert algorithm.category_map.data
+
+ # Ensure the full labels in the data match the simple, ordered list of labels
+ sorted_data = sorted(algorithm.category_map.data, key=lambda x: x["index"])
+ assert [category["label"] for category in sorted_data] == algorithm.category_map.labels
diff --git a/ami/ml/views.py b/ami/ml/views.py
index 0ffafa525..8f4851f83 100644
--- a/ami/ml/views.py
+++ b/ami/ml/views.py
@@ -12,10 +12,15 @@
from ami.main.models import SourceImage
from ami.utils.requests import get_active_project, project_id_doc_param
-from .models.algorithm import Algorithm
+from .models.algorithm import Algorithm, AlgorithmCategoryMap
from .models.pipeline import Pipeline
from .models.processing_service import ProcessingService
-from .serializers import AlgorithmSerializer, PipelineSerializer, ProcessingServiceSerializer
+from .serializers import (
+ AlgorithmCategoryMapSerializer,
+ AlgorithmSerializer,
+ PipelineSerializer,
+ ProcessingServiceSerializer,
+)
logger = logging.getLogger(__name__)
@@ -37,6 +42,22 @@ class AlgorithmViewSet(DefaultViewSet):
search_fields = ["name"]
+class AlgorithmCategoryMapViewSet(DefaultViewSet):
+ """
+ API endpoint that allows algorithm category maps to be viewed or edited.
+ """
+
+ queryset = AlgorithmCategoryMap.objects.all()
+ serializer_class = AlgorithmCategoryMapSerializer
+ filterset_fields = ["algorithms"]
+ ordering_fields = [
+ "algorithms",
+ "created_at",
+ "updated_at",
+ "version",
+ ]
+
+
class PipelineViewSet(DefaultViewSet):
"""
API endpoint that allows pipelines to be viewed or edited.
diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py
index 9f4a0613e..185b1095c 100644
--- a/ami/tests/fixtures/main.py
+++ b/ami/tests/fixtures/main.py
@@ -42,7 +42,8 @@ def create_processing_service(project):
processing_service_to_add = {
"name": "Test Processing Service",
"projects": [{"name": project.name}],
- "endpoint_url": "http://processing_service:2000",
+ # "endpoint_url": "http://processing_service:2000",
+ "endpoint_url": "http://ml_backend:2000",
}
processing_service, created = ProcessingService.objects.get_or_create(
diff --git a/ami/tests/fixtures/ml.py b/ami/tests/fixtures/ml.py
new file mode 100644
index 000000000..f4bd7cb54
--- /dev/null
+++ b/ami/tests/fixtures/ml.py
@@ -0,0 +1,81 @@
+from ami.ml.schemas import AlgorithmCategoryMapResponse, AlgorithmConfigResponse
+
+RANDOM_DETECTOR = AlgorithmConfigResponse(
+ name="Random Detector",
+ key="random-detector",
+ task_type="detection",
+ description="Return bounding boxes at random locations within the image bounds.",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-detector",
+ category_map=None,
+)
+
+RANDOM_BINARY_CLASSIFIER = AlgorithmConfigResponse(
+ name="Random binary classifier",
+ key="random-binary-classifier",
+ task_type="classification",
+ description="Randomly return a classification of 'Moth' or 'Not a moth'",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-binary-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {"index": 0, "gbif_key": "1234", "label": "Moth", "source": "manual"},
+ {"index": 1, "gbif_key": "4543", "label": "Not a moth", "source": "manual"},
+ ],
+ labels=["Moth", "Not a moth"],
+ version="v1",
+ description="Class mapping for a simple binary classifier",
+ uri="https://huggingface.co/RolnickLab/random-binary-classifier/classes.txt",
+ ),
+)
+
+RANDOM_SPECIES_CLASSIFIER = AlgorithmConfigResponse(
+ name="Random species classifier",
+ key="random-species-classifier",
+ task_type="classification",
+ description="A random species classifier",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-species-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {"index": 0, "gbif_key": "1234", "label": "Vanessa atalanta", "source": "manual"},
+ {"index": 1, "gbif_key": "4543", "label": "Vanessa cardui", "source": "manual"},
+ {"index": 2, "gbif_key": "7890", "label": "Vanessa itea", "source": "manual"},
+ ],
+ labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"],
+ version="v1",
+ description="",
+ uri="https://huggigface.co/RolnickLab/random-species-classifier/classes.txt",
+ ),
+)
+
+
+CONSTANT_SPECIES_CLASSIFIER = AlgorithmConfigResponse(
+ name="Constant species classifier",
+ key="constant-species-classifier",
+ task_type="classification",
+ description="A species classifier that always returns the same species",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/constant-species-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {"index": 0, "gbif_key": "1234", "label": "Vanessa atalanta", "source": "manual"},
+ {"index": 1, "gbif_key": "4543", "label": "Vanessa cardui", "source": "manual"},
+ {"index": 2, "gbif_key": "7890", "label": "Vanessa itea", "source": "manual"},
+ ],
+ labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"],
+ version="v1",
+ description="",
+ uri="https://huggigface.co/RolnickLab/constant-species-classifier/classes.txt",
+ ),
+)
+ALGORITHM_CHOICES = {
+ RANDOM_DETECTOR.key: RANDOM_DETECTOR,
+ RANDOM_BINARY_CLASSIFIER.key: RANDOM_BINARY_CLASSIFIER,
+ RANDOM_SPECIES_CLASSIFIER.key: RANDOM_SPECIES_CLASSIFIER,
+ CONSTANT_SPECIES_CLASSIFIER.key: CONSTANT_SPECIES_CLASSIFIER,
+}
diff --git a/ami/tests/fixtures/storage.py b/ami/tests/fixtures/storage.py
index 52675919e..d53c782ff 100644
--- a/ami/tests/fixtures/storage.py
+++ b/ami/tests/fixtures/storage.py
@@ -42,7 +42,7 @@ def populate_bucket(
config: s3.S3Config,
subdir: str = "deployment_1",
num_nights: int = 2,
- images_per_day: int = 6,
+ images_per_day: int = 3,
minutes_interval: int = 45,
minutes_interval_variation: int = 10,
skip_existing: bool = True,
diff --git a/ami/utils/requests.py b/ami/utils/requests.py
index 50edb0652..969c0d0aa 100644
--- a/ami/utils/requests.py
+++ b/ami/utils/requests.py
@@ -1,11 +1,53 @@
+import requests
from django.forms import FloatField
from drf_spectacular.utils import OpenApiParameter
+from requests.adapters import HTTPAdapter
from rest_framework.request import Request
+from urllib3.util import Retry
from ami.main.models import Project
+def create_session(
+ retries: int = 3,
+ backoff_factor: int = 2,
+ status_forcelist: tuple[int, ...] = (500, 502, 503, 504),
+) -> requests.Session:
+ """
+ Create a requests Session with retry capabilities.
+
+ Args:
+ retries: Maximum number of retries
+ backoff_factor: Backoff factor for retries
+ status_forcelist: HTTP status codes to retry on
+
+ Returns:
+ Session configured with retry behavior
+ """
+ session = requests.Session()
+ retry = Retry(
+ total=retries,
+ read=retries,
+ connect=retries,
+ backoff_factor=backoff_factor,
+ status_forcelist=status_forcelist,
+ )
+ adapter = HTTPAdapter(max_retries=retry)
+ session.mount("http://", adapter)
+ session.mount("https://", adapter)
+ return session
+
+
def get_active_classification_threshold(request: Request) -> float:
+ """
+ Get the active classification threshold from request parameters.
+
+ Args:
+ request: The incoming request object
+
+ Returns:
+ The classification threshold value, defaulting to 0 if not specified
+ """
# Look for a query param to filter by score
classification_threshold = request.query_params.get("classification_threshold")
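
A minimal sketch of how the new retrying session might be used when calling a processing service (the /info path and timeout are illustrative assumptions, not part of this change):

from ami.utils.requests import create_session

# Retries on 500/502/503/504 with exponential backoff, per the defaults above.
session = create_session(retries=3, backoff_factor=2)
response = session.get("http://ml_backend:2000/info", timeout=10)
response.raise_for_status()
print(response.json())
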
diff --git a/config/api_router.py b/config/api_router.py
index 0d45dcb50..d2a776fbc 100644
--- a/config/api_router.py
+++ b/config/api_router.py
@@ -28,6 +28,7 @@
router.register(r"occurrences", views.OccurrenceViewSet)
router.register(r"taxa", views.TaxonViewSet)
router.register(r"ml/algorithms", ml_views.AlgorithmViewSet)
+router.register(r"ml/labels", ml_views.AlgorithmCategoryMapViewSet)
router.register(r"ml/pipelines", ml_views.PipelineViewSet)
router.register(r"ml/processing_services", ml_views.ProcessingServiceViewSet)
router.register(r"classifications", views.ClassificationViewSet)
diff --git a/config/settings/base.py b/config/settings/base.py
index acd2baf5f..ebb86c76d 100644
--- a/config/settings/base.py
+++ b/config/settings/base.py
@@ -90,7 +90,7 @@
"drf_spectacular",
"django_filters",
"anymail",
- "cachalot",
+ # "cachalot",
]
LOCAL_APPS = [
@@ -242,6 +242,22 @@
SENDGRID_SANDBOX_MODE_IN_DEBUG = False
SENDGRID_ECHO_TO_STDOUT = True
+# CACHES
+# ------------------------------------------------------------------------------
+# https://docs.djangoproject.com/en/dev/ref/settings/#caches
+CACHES = {
+ "default": {
+ "BACKEND": "django_redis.cache.RedisCache",
+ "LOCATION": env("REDIS_URL", default=None),
+ "OPTIONS": {
+ "CLIENT_CLASS": "django_redis.client.DefaultClient",
+ # Mimicking memcached behavior.
+ # https://github.com/jazzband/django-redis#memcached-exceptions-behavior
+ "IGNORE_EXCEPTIONS": True,
+ },
+ }
+}
+
# ADMIN
# ------------------------------------------------------------------------------
# Django Admin URL.
diff --git a/config/settings/local.py b/config/settings/local.py
index d20889160..c2f58afa0 100644
--- a/config/settings/local.py
+++ b/config/settings/local.py
@@ -23,15 +23,6 @@
"django",
] + env.list("DJANGO_ALLOWED_HOSTS", default=[])
-# CACHES
-# ------------------------------------------------------------------------------
-# https://docs.djangoproject.com/en/dev/ref/settings/#caches
-CACHES = {
- "default": {
- "BACKEND": "django.core.cache.backends.locmem.LocMemCache",
- "LOCATION": "",
- }
-}
# EMAIL
# ------------------------------------------------------------------------------
@@ -60,6 +51,11 @@
INSTALLED_APPS = ["whitenoise.runserver_nostatic"] + INSTALLED_APPS # noqa: F405
+# Long queries can be a problem in development; this stops them after 30 seconds
+database_options = DATABASES["default"].get("OPTIONS", {}) # noqa: F405
+database_options["options"] = "-c statement_timeout=30s"
+DATABASES["default"]["OPTIONS"] = database_options # noqa: F405
+
# django-debug-toolbar
# ------------------------------------------------------------------------------
# https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#prerequisites
diff --git a/config/settings/production.py b/config/settings/production.py
index 9f072b01c..e4abdf44e 100644
--- a/config/settings/production.py
+++ b/config/settings/production.py
@@ -22,20 +22,6 @@
# ------------------------------------------------------------------------------
DATABASES["default"]["CONN_MAX_AGE"] = env.int("CONN_MAX_AGE", default=60) # noqa: F405
-# CACHES
-# ------------------------------------------------------------------------------
-CACHES = {
- "default": {
- "BACKEND": "django_redis.cache.RedisCache",
- "LOCATION": env("REDIS_URL"),
- "OPTIONS": {
- "CLIENT_CLASS": "django_redis.client.DefaultClient",
- # Mimicing memcache behavior.
- # https://github.com/jazzband/django-redis#memcached-exceptions-behavior
- "IGNORE_EXCEPTIONS": True,
- },
- }
-}
# SECURITY
# ------------------------------------------------------------------------------
diff --git a/docker-compose.yml b/docker-compose.yml
index 347c39248..478f85101 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -23,7 +23,6 @@ services:
- postgres
- redis
- minio-init
- - processing_service
volumes:
- .:/app:z
env_file:
@@ -140,10 +139,10 @@ services:
- ./compose/local/minio/init.sh:/etc/minio/init.sh
entrypoint: /etc/minio/init.sh
- processing_service:
+ ml_backend:
build:
- context: ./processing_services/example
+ context: ./ml_backends/example
volumes:
- - ./processing_services/example/:/app:processing_service
+ - ./ml_backends/example/:/app
ports:
- "2005:2000"
diff --git a/processing_services/example/api/algorithms.py b/processing_services/example/api/algorithms.py
new file mode 100644
index 000000000..4636083aa
--- /dev/null
+++ b/processing_services/example/api/algorithms.py
@@ -0,0 +1,119 @@
+from .schemas import AlgorithmCategoryMapResponse, AlgorithmConfigResponse
+
+RANDOM_DETECTOR = AlgorithmConfigResponse(
+ name="Random Detector",
+ key="random-detector",
+ task_type="detection",
+ description="Return bounding boxes at random locations within the image bounds.",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-detector",
+ category_map=None,
+)
+
+CONSTANT_DETECTOR = AlgorithmConfigResponse(
+ name="Constant Detector",
+ key="constant-detector",
+ task_type="detection",
+ description="Return a fixed bounding box at a fixed location within the image bounds.",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/constant-detector",
+ category_map=None,
+)
+
+RANDOM_BINARY_CLASSIFIER = AlgorithmConfigResponse(
+ name="Random binary classifier",
+ key="random-binary-classifier",
+ task_type="classification",
+ description="Randomly return a classification of 'Moth' or 'Not a moth'",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-binary-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {
+ "index": 0,
+ "gbif_key": "1234",
+ "label": "Moth",
+ "source": "manual",
+ "taxon_rank": "SUPERFAMILY",
+ },
+ {
+ "index": 1,
+ "gbif_key": "4543",
+ "label": "Not a moth",
+ "source": "manual",
+ "taxon_rank": "ORDER",
+ },
+ ],
+ labels=["Moth", "Not a moth"],
+ version="v1",
+ description="A simple binary classifier",
+ uri="https://huggingface.co/RolnickLab/random-binary-classifier",
+ ),
+)
+
+CONSTANT_CLASSIFIER = AlgorithmConfigResponse(
+ name="Constant classifier",
+ key="constant-classifier",
+ task_type="classification",
+ description="Always return a classification of 'Moth'",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/constant-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {
+ "index": 0,
+ "gbif_key": "1234",
+ "label": "Moth",
+ "source": "manual",
+ "taxon_rank": "SUPERFAMILY",
+ }
+ ],
+ labels=["Moth"],
+ version="v1",
+ description="A classifier that always returns 'Moth'",
+ uri="https://huggingface.co/RolnickLab/constant-classifier",
+ ),
+)
+
+RANDOM_SPECIES_CLASSIFIER = AlgorithmConfigResponse(
+ name="Random species classifier",
+ key="random-species-classifier",
+ task_type="classification",
+ description="A random species classifier",
+ version=1,
+ version_name="v1",
+ uri="https://huggingface.co/RolnickLab/random-species-classifier",
+ category_map=AlgorithmCategoryMapResponse(
+ data=[
+ {
+ "index": 0,
+ "gbif_key": "1234",
+ "label": "Vanessa atalanta",
+ "source": "manual",
+ "taxon_rank": "SPECIES",
+ },
+ {
+ "index": 1,
+ "gbif_key": "4543",
+ "label": "Vanessa cardui",
+ "source": "manual",
+ "taxon_rank": "SPECIES",
+ },
+ {
+ "index": 2,
+ "gbif_key": "7890",
+ "label": "Vanessa itea",
+ "source": "manual",
+ "taxon_rank": "SPECIES",
+ },
+ ],
+ labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"],
+ version="v1",
+ description="A simple species classifier",
+ uri="https://huggigface.co/RolnickLab/random-species-classifier",
+ ),
+)
diff --git a/processing_services/example/api/api.py b/processing_services/example/api/api.py
index e7c0d27a8..e69de29bb 100644
--- a/processing_services/example/api/api.py
+++ b/processing_services/example/api/api.py
@@ -1,111 +0,0 @@
-"""
-Fast API interface for processing images through the localization and classification pipelines.
-"""
-
-import logging
-import time
-
-import fastapi
-
-from .pipeline import ConstantPipeline, DummyPipeline
-from .schemas import (
- AlgorithmConfig,
- PipelineConfig,
- PipelineRequest,
- PipelineResponse,
- SourceImage,
- SourceImageResponse,
-)
-
-logger = logging.getLogger(__name__)
-
-app = fastapi.FastAPI()
-
-pipeline1 = PipelineConfig(
- name="ML Dummy Pipeline",
- slug="dummy",
- version=1,
- algorithms=[
- AlgorithmConfig(name="Dummy Detector", key="1"),
- AlgorithmConfig(name="Random Detector", key="2"),
- AlgorithmConfig(name="Always Moth Classifier", key="3"),
- ],
-)
-
-pipeline2 = PipelineConfig(
- name="ML Constant Pipeline",
- slug="constant",
- version=1,
- algorithms=[
- AlgorithmConfig(name="Dummy Detector", key="1"),
- AlgorithmConfig(name="Random Detector", key="2"),
- AlgorithmConfig(name="Always Moth Classifier", key="3"),
- ],
-)
-
-pipelines = [pipeline1, pipeline2]
-
-
-@app.get("/")
-async def root():
- return fastapi.responses.RedirectResponse("/docs")
-
-
-@app.get("/info", tags=["services"])
-async def info() -> list[PipelineConfig]:
- return pipelines
-
-
-# Check if the server is online
-@app.get("/livez", tags=["health checks"])
-async def livez():
- return fastapi.responses.JSONResponse(status_code=200, content={"status": True})
-
-
-# Check if the pipelines are ready to process data
-@app.get("/readyz", tags=["health checks"])
-async def readyz():
- if pipelines:
- return fastapi.responses.JSONResponse(
- status_code=200, content={"status": [pipeline.slug for pipeline in pipelines]}
- )
- else:
- return fastapi.responses.JSONResponse(status_code=503, content={"status": "pipelines unavailable"})
-
-
-@app.post("/process_images", tags=["services"])
-async def process(data: PipelineRequest) -> PipelineResponse:
- pipeline_slug = data.pipeline
-
- source_image_results = [SourceImageResponse(**image.model_dump()) for image in data.source_images]
- source_images = [SourceImage(**image.model_dump()) for image in data.source_images]
-
- start_time = time.time()
-
- if pipeline_slug == "constant":
- pipeline = ConstantPipeline(source_images=source_images) # returns same detections
- else:
- pipeline = DummyPipeline(source_images=source_images) # returns random detections
-
- try:
- results = pipeline.run()
- except Exception as e:
- logger.error(f"Error running pipeline: {e}")
- raise fastapi.HTTPException(status_code=422, detail=f"{e}")
-
- end_time = time.time()
- seconds_elapsed = float(end_time - start_time)
-
- response = PipelineResponse(
- pipeline=data.pipeline,
- source_images=source_image_results,
- detections=results,
- total_time=seconds_elapsed,
- )
- return response
-
-
-if __name__ == "__main__":
- import uvicorn
-
- uvicorn.run(app, host="0.0.0.0", port=2000)
diff --git a/processing_services/example/api/pipeline.py b/processing_services/example/api/pipeline.py
deleted file mode 100644
index a19076949..000000000
--- a/processing_services/example/api/pipeline.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import datetime
-import math
-import random
-
-from .schemas import BoundingBox, Classification, Detection, SourceImage
-
-
-def make_random_bbox(source_image_width: int, source_image_height: int):
- # Make a random box.
- # Ensure that the box is within the image bounds and the bottom right corner is greater than the top left corner.
- x1 = random.randint(0, source_image_width)
- x2 = random.randint(0, source_image_width)
- y1 = random.randint(0, source_image_height)
- y2 = random.randint(0, source_image_height)
-
- return BoundingBox(
- x1=min(x1, x2),
- y1=min(y1, y2),
- x2=max(x1, x2),
- y2=max(y1, y2),
- )
-
-
-def generate_adaptive_grid_bounding_boxes(image_width: int, image_height: int, num_boxes: int) -> list[BoundingBox]:
- # Estimate grid size based on num_boxes
- grid_size: int = math.ceil(math.sqrt(num_boxes))
-
- cell_width: float = image_width / grid_size
- cell_height: float = image_height / grid_size
-
- boxes: list[BoundingBox] = []
-
- for _ in range(num_boxes):
- # Select a random cell
- row: int = random.randint(0, grid_size - 1)
- col: int = random.randint(0, grid_size - 1)
-
- # Calculate the cell's boundaries
- cell_x1: float = col * cell_width
- cell_y1: float = row * cell_height
-
- # Generate a random box within the cell
- # Ensure the box is between 50% and 100% of the cell size
- box_width: float = random.uniform(cell_width * 0.5, cell_width)
- box_height: float = random.uniform(cell_height * 0.5, cell_height)
-
- x1: float = cell_x1 + random.uniform(0, cell_width - box_width)
- y1: float = cell_y1 + random.uniform(0, cell_height - box_height)
- x2: float = x1 + box_width
- y2: float = y1 + box_height
-
- boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2))
-
- return boxes
-
-
-def make_fake_detections(source_image: SourceImage, num_detections: int = 10):
- source_image.open(raise_exception=True)
- assert source_image.width is not None and source_image.height is not None
- bboxes = generate_adaptive_grid_bounding_boxes(source_image.width, source_image.height, num_detections)
- timestamp = datetime.datetime.now()
-
- return [
- Detection(
- source_image_id=source_image.id,
- bbox=bbox,
- timestamp=timestamp,
- algorithm="Random Detector",
- classifications=[
- Classification(
- classification="moth",
- labels=["moth"],
- scores=[random.random()],
- timestamp=timestamp,
- algorithm="Always Moth Classifier",
- )
- ],
- )
- for bbox in bboxes
- ]
-
-
-def make_constant_detections(source_image: SourceImage, num_detections: int = 10):
- source_image.open(raise_exception=True)
- assert source_image.width is not None and source_image.height is not None
-
- # Define a fixed bounding box size and position relative to image size
- box_width, box_height = source_image.width // 4, source_image.height // 4
- start_x, start_y = source_image.width // 8, source_image.height // 8
- bboxes = [BoundingBox(x1=start_x, y1=start_y, x2=start_x + box_width, y2=start_y + box_height)]
- timestamp = datetime.datetime.now()
-
- return [
- Detection(
- source_image_id=source_image.id,
- bbox=bbox,
- timestamp=timestamp,
- algorithm="Fixed Detector",
- classifications=[
- Classification(
- classification="moth",
- labels=["moth"],
- scores=[0.9], # Constant score for each detection
- timestamp=timestamp,
- algorithm="Always Moth Classifier",
- )
- ],
- )
- for bbox in bboxes
- ]
-
-
-class DummyPipeline:
- source_images: list[SourceImage]
-
- def __init__(self, source_images: list[SourceImage]):
- self.source_images = source_images
-
- def run(self) -> list[Detection]:
- results = [make_fake_detections(source_image) for source_image in self.source_images]
- # Flatten the list of lists
- return [item for sublist in results for item in sublist]
-
-
-class ConstantPipeline:
- source_images: list[SourceImage]
-
- def __init__(self, source_images: list[SourceImage]):
- self.source_images = source_images
-
- def run(self) -> list[Detection]:
- results = [make_constant_detections(source_image) for source_image in self.source_images]
- # Flatten the list of lists
- return [item for sublist in results for item in sublist]
diff --git a/processing_services/example/api/pipelines.py b/processing_services/example/api/pipelines.py
new file mode 100644
index 000000000..0d955b417
--- /dev/null
+++ b/processing_services/example/api/pipelines.py
@@ -0,0 +1,216 @@
+import datetime
+import math
+import random
+
+from . import algorithms
+from .schemas import (
+ AlgorithmConfigResponse,
+ AlgorithmReference,
+ BoundingBox,
+ ClassificationResponse,
+ DetectionResponse,
+ PipelineConfigResponse,
+ SourceImage,
+)
+
+
+def make_random_bbox(source_image_width: int, source_image_height: int):
+ # Make a random box.
+ # Ensure that the box is within the image bounds and the bottom right corner is greater than the
+ # top left corner.
+ x1 = random.randint(0, source_image_width)
+ x2 = random.randint(0, source_image_width)
+ y1 = random.randint(0, source_image_height)
+ y2 = random.randint(0, source_image_height)
+
+ return BoundingBox(
+ x1=min(x1, x2),
+ y1=min(y1, y2),
+ x2=max(x1, x2),
+ y2=max(y1, y2),
+ )
+
+
+def generate_adaptive_grid_bounding_boxes(image_width: int, image_height: int, num_boxes: int) -> list[BoundingBox]:
+ # Estimate grid size based on num_boxes
+ grid_size: int = math.ceil(math.sqrt(num_boxes))
+
+ cell_width: float = image_width / grid_size
+ cell_height: float = image_height / grid_size
+
+ boxes: list[BoundingBox] = []
+
+ for _ in range(num_boxes):
+ # Select a random cell
+ row: int = random.randint(0, grid_size - 1)
+ col: int = random.randint(0, grid_size - 1)
+
+ # Calculate the cell's boundaries
+ cell_x1: float = col * cell_width
+ cell_y1: float = row * cell_height
+
+ # Generate a random box within the cell
+ # Ensure the box is between 50% and 100% of the cell size
+ box_width: float = random.uniform(cell_width * 0.5, cell_width)
+ box_height: float = random.uniform(cell_height * 0.5, cell_height)
+
+ x1: float = cell_x1 + random.uniform(0, cell_width - box_width)
+ y1: float = cell_y1 + random.uniform(0, cell_height - box_height)
+ x2: float = x1 + box_width
+ y2: float = y1 + box_height
+
+ boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2))
+
+ return boxes
+
+
+def make_random_prediction(
+ algorithm: AlgorithmConfigResponse,
+ terminal: bool = True,
+ max_labels: int = 2,
+) -> ClassificationResponse:
+ assert algorithm.category_map is not None
+ category_labels = algorithm.category_map.labels
+ logits = [random.random() for _ in category_labels]
+ softmax = [math.exp(logit) / sum([math.exp(logit) for logit in logits]) for logit in logits]
+ top_class = category_labels[softmax.index(max(softmax))]
+ return ClassificationResponse(
+ classification=top_class,
+ labels=category_labels if len(category_labels) <= max_labels else None,
+ scores=softmax,
+ logits=logits,
+ timestamp=datetime.datetime.now(),
+ algorithm=AlgorithmReference(name=algorithm.name, key=algorithm.key),
+ terminal=terminal,
+ )
+
+
+def make_random_detections(source_image: SourceImage, num_detections: int = 10):
+ source_image.open(raise_exception=True)
+ assert source_image.width is not None and source_image.height is not None
+ bboxes = generate_adaptive_grid_bounding_boxes(source_image.width, source_image.height, num_detections)
+ timestamp = datetime.datetime.now()
+
+ return [
+ DetectionResponse(
+ source_image_id=source_image.id,
+ bbox=bbox,
+ timestamp=timestamp,
+ algorithm=AlgorithmReference(
+ name=algorithms.RANDOM_DETECTOR.name,
+ key=algorithms.RANDOM_DETECTOR.key,
+ ),
+ classifications=[
+ make_random_prediction(
+ algorithm=algorithms.RANDOM_BINARY_CLASSIFIER,
+ terminal=False,
+ ),
+ make_random_prediction(
+ algorithm=algorithms.RANDOM_SPECIES_CLASSIFIER,
+ terminal=True,
+ ),
+ ],
+ )
+ for bbox in bboxes
+ ]
+
+
+def make_constant_detections(source_image: SourceImage, num_detections: int = 10):
+ source_image.open(raise_exception=True)
+ assert source_image.width is not None and source_image.height is not None
+
+ # Define a fixed bounding box size and position relative to image size
+ box_width, box_height = source_image.width // 4, source_image.height // 4
+ start_x, start_y = source_image.width // 8, source_image.height // 8
+ bboxes = [BoundingBox(x1=start_x, y1=start_y, x2=start_x + box_width, y2=start_y + box_height)]
+ timestamp = datetime.datetime.now()
+
+ assert algorithms.CONSTANT_CLASSIFIER.category_map is not None
+ labels = algorithms.CONSTANT_CLASSIFIER.category_map.labels
+
+ return [
+ DetectionResponse(
+ source_image_id=source_image.id,
+ bbox=bbox,
+ timestamp=timestamp,
+ algorithm=AlgorithmReference(name=algorithms.CONSTANT_DETECTOR.name, key=algorithms.CONSTANT_DETECTOR.key),
+ classifications=[
+ ClassificationResponse(
+ classification=labels[0],
+ labels=labels,
+ scores=[0.9], # Constant score for each detection
+ timestamp=timestamp,
+ algorithm=AlgorithmReference(
+ name=algorithms.CONSTANT_CLASSIFIER.name, key=algorithms.CONSTANT_CLASSIFIER.key
+ ),
+ )
+ ],
+ )
+ for bbox in bboxes
+ ]
+
+
+class Pipeline:
+ source_images: list[SourceImage]
+
+ def __init__(self, source_images: list[SourceImage]):
+ self.source_images = source_images
+
+ def run(self) -> list[DetectionResponse]:
+ raise NotImplementedError("Subclasses must implement the run method")
+
+ config = PipelineConfigResponse(
+ name="Base Pipeline",
+ slug="base",
+ description="A base class for all pipelines.",
+ version=1,
+ algorithms=[],
+ )
+
+
+class RandomPipeline(Pipeline):
+ """
+ A pipeline that returns detections in random positions within the image bounds with random classifications.
+ """
+
+ def run(self) -> list[DetectionResponse]:
+ results = [make_random_detections(source_image) for source_image in self.source_images]
+ # Flatten the list of lists
+ return [item for sublist in results for item in sublist]
+
+ config = PipelineConfigResponse(
+ name="Random Pipeline",
+ slug="random",
+ description=(
+ "A pipeline that returns detections in random positions within the image bounds "
+ "with random classifications."
+ ),
+ version=1,
+ algorithms=[
+ algorithms.RANDOM_DETECTOR,
+ algorithms.RANDOM_BINARY_CLASSIFIER,
+ algorithms.RANDOM_SPECIES_CLASSIFIER,
+ ],
+ )
+
+
+class ConstantPipeline(Pipeline):
+ """
+ A pipeline that always returns a detection in the same position with a fixed classification.
+ """
+
+ def run(self) -> list[DetectionResponse]:
+ results = [make_constant_detections(source_image) for source_image in self.source_images]
+ # Flatten the list of lists
+ return [item for sublist in results for item in sublist]
+
+ config = PipelineConfigResponse(
+ name="Constant Pipeline",
+ slug="constant",
+ description="A pipeline that always returns a detection in the same position with a fixed classification.",
+ version=1,
+ algorithms=[
+ algorithms.CONSTANT_DETECTOR,
+ algorithms.CONSTANT_CLASSIFIER,
+ ],
+ )
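
Since the old api.py is emptied by this change and its replacement is not shown in this diff, here is a minimal, hypothetical sketch of how these pipeline classes and the new response schemas could be wired into a FastAPI /process endpoint (an illustration under those assumptions, not the shipped implementation):

import time

import fastapi

from .pipelines import ConstantPipeline, RandomPipeline
from .schemas import PipelineRequest, PipelineResultsResponse, SourceImage, SourceImageResponse

app = fastapi.FastAPI()


@app.post("/process", tags=["services"])
async def process(data: PipelineRequest) -> PipelineResultsResponse:
    source_images = [SourceImage(**image.model_dump()) for image in data.source_images]
    pipeline_cls = ConstantPipeline if data.pipeline == "constant" else RandomPipeline
    pipeline = pipeline_cls(source_images=source_images)

    start_time = time.time()
    detections = pipeline.run()

    return PipelineResultsResponse(
        pipeline=data.pipeline,
        # Report every algorithm the pipeline declares, keyed by algorithm key
        algorithms={algo.key: algo for algo in pipeline.config.algorithms},
        total_time=time.time() - start_time,
        source_images=[SourceImageResponse(**image.model_dump()) for image in data.source_images],
        detections=detections,
    )
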
diff --git a/processing_services/example/api/schemas.py b/processing_services/example/api/schemas.py
index adb4a16ee..3396e1337 100644
--- a/processing_services/example/api/schemas.py
+++ b/processing_services/example/api/schemas.py
@@ -29,6 +29,9 @@ def to_string(self):
def to_path(self):
return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]])
+ def to_tuple(self):
+ return (self.x1, self.y1, self.x2, self.y2)
+
class SourceImage(pydantic.BaseModel):
model_config = pydantic.ConfigDict(extra="ignore", arbitrary_types_allowed=True)
@@ -65,34 +68,52 @@ def open(self, raise_exception=False) -> PIL.Image.Image | None:
return self._pil
-class Classification(pydantic.BaseModel):
+class AlgorithmReference(pydantic.BaseModel):
+ name: str
+ key: str
+
+
+class ClassificationResponse(pydantic.BaseModel):
classification: str
- labels: list[str] = []
- scores: list[float] = []
+ labels: list[str] | None = pydantic.Field(
+ default=None,
+ description=(
+ "A list of all possible labels for the model, in the correct order. "
+ "Omitted if the model has too many labels to include for each classification in the response. "
+ "Use the category map from the algorithm to get the full list of labels and metadata."
+ ),
+ )
+ scores: list[float] = pydantic.Field(
+ default_factory=list,
+ description="The calibrated probabilities for each class label, most commonly the softmax output.",
+ )
+ logits: list[float] = pydantic.Field(
+ default_factory=list,
+ description="The raw logits output by the model, before any calibration or normalization.",
+ )
inference_time: float | None = None
- algorithm: str
+ algorithm: AlgorithmReference
terminal: bool = True
timestamp: datetime.datetime
-class Detection(pydantic.BaseModel):
+class DetectionResponse(pydantic.BaseModel):
source_image_id: str
bbox: BoundingBox
inference_time: float | None = None
- algorithm: str
+ algorithm: AlgorithmReference
timestamp: datetime.datetime
crop_image_url: str | None = None
- classifications: list[Classification] = []
+ classifications: list[ClassificationResponse] = []
class SourceImageRequest(pydantic.BaseModel):
- # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
+ model_config = pydantic.ConfigDict(extra="ignore")
+
id: str
url: str
# b64: str | None = None
-
- class Config:
- extra = "ignore"
+ # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
class SourceImageResponse(pydantic.BaseModel):
@@ -102,7 +123,68 @@ class SourceImageResponse(pydantic.BaseModel):
url: str
-PipelineChoice = typing.Literal["dummy", "constant"]
+class AlgorithmCategoryMapResponse(pydantic.BaseModel):
+ data: list[dict] = pydantic.Field(
+ default_factory=dict,
+ description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.",
+ examples=[
+ [
+ {"label": "Moth", "index": 0, "gbif_key": 1234},
+ {"label": "Not a moth", "index": 1, "gbif_key": 5678},
+ ]
+ ],
+ )
+ labels: list[str] = pydantic.Field(
+ default_factory=list,
+ description="A simple list of string labels, in the correct index order used by the model.",
+ examples=[["Moth", "Not a moth"]],
+ )
+ version: str | None = pydantic.Field(
+ default=None,
+ description="The version of the category map. Can be a descriptive string or a version number.",
+ examples=["LepNet2021-with-2023-mods"],
+ )
+ description: str | None = pydantic.Field(
+ default=None,
+ description="A description of the category map used to train. e.g. source, purpose and modifications.",
+ examples=["LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."],
+ )
+ uri: str | None = pydantic.Field(
+ default=None,
+ description="A URI to the category map file, could be a public web URL or object store path.",
+ )
+
+
+class AlgorithmConfigResponse(pydantic.BaseModel):
+ name: str
+ key: str = pydantic.Field(
+ description=("A unique key for an algorithm to lookup the category map (class list) and other metadata."),
+ )
+ description: str | None = None
+ task_type: str | None = pydantic.Field(
+ default=None,
+ description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.",
+ examples=["detection", "classification", "segmentation", "embedding"],
+ )
+ version: int = pydantic.Field(
+ default=1,
+ description="A sortable version number for the model. Increment this number when the model is updated.",
+ )
+ version_name: str | None = pydantic.Field(
+ default=None,
+ description="A complete version name e.g. '2021-01-01', 'LepNet2021'.",
+ )
+ uri: str | None = pydantic.Field(
+ default=None,
+ description="A URI to the weights or model details, could be a public web URL or object store path.",
+ )
+ category_map: AlgorithmCategoryMapResponse | None = None
+
+ class Config:
+ extra = "ignore"
+
+
+PipelineChoice = typing.Literal["random", "constant"]
class PipelineRequest(pydantic.BaseModel):
@@ -117,18 +199,23 @@ class Config:
"source_images": [
{
"id": "123",
- "url": "https://example.com/image.jpg",
+ "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg",
}
],
}
}
-class PipelineResponse(pydantic.BaseModel):
+class PipelineResultsResponse(pydantic.BaseModel):
pipeline: PipelineChoice
+ algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field(
+ default_factory=dict,
+ description="A dictionary of all algorithms used in the pipeline, including their class list and other "
+ "metadata, keyed by the algorithm key.",
+ )
total_time: float
source_images: list[SourceImageResponse]
- detections: list[Detection]
+ detections: list[DetectionResponse]
class PipelineStageParam(pydantic.BaseModel):
@@ -148,17 +235,34 @@ class PipelineStage(pydantic.BaseModel):
description: str | None = None
-class AlgorithmConfig(pydantic.BaseModel):
- name: str
- key: str
-
-
-class PipelineConfig(pydantic.BaseModel):
- """A configurable pipeline."""
+class PipelineConfigResponse(pydantic.BaseModel):
+ """Details about a pipeline, its algorithms and category maps."""
name: str
slug: str
version: int
description: str | None = None
- algorithms: list[AlgorithmConfig] = []
+ algorithms: list[AlgorithmConfigResponse] = []
stages: list[PipelineStage] = []
+
+
+class ProcessingServiceInfoResponse(pydantic.BaseModel):
+ """Information about the processing service."""
+
+ name: str = pydantic.Field(examples=["Mila Research Lab - Moth AI Services"])
+ description: str | None = pydantic.Field(
+ default=None,
+ examples=["Algorithms developed by the Mila Research Lab for analysis of moth images."],
+ )
+ pipelines: list[PipelineConfigResponse] = pydantic.Field(
+ default_factory=list,
+ examples=[
+ [
+ PipelineConfigResponse(name="Random Pipeline", slug="random", version=1, algorithms=[]),
+ ]
+ ],
+ )
+ # algorithms: list[AlgorithmConfigResponse] = pydantic.Field(
+ # default=list,
+ # examples=[RANDOM_BINARY_CLASSIFIER],
+ # )
diff --git a/processing_services/example/api/test.py b/processing_services/example/api/test.py
index 717367314..187298094 100644
--- a/processing_services/example/api/test.py
+++ b/processing_services/example/api/test.py
@@ -3,7 +3,7 @@
from fastapi.testclient import TestClient
from .api import app
-from .pipeline import DummyPipeline
+from .pipelines import RandomPipeline
from .schemas import PipelineRequest, SourceImage, SourceImageRequest
@@ -13,7 +13,7 @@ def test_dummy_pipeline(self):
SourceImage(id="1", url="https://example.com/image1.jpg"),
SourceImage(id="2", url="https://example.com/image2.jpg"),
]
- pipeline = DummyPipeline(source_images=source_images)
+ pipeline = RandomPipeline(source_images=source_images)
detections = pipeline.run()
self.assertEqual(len(detections), 20)
@@ -45,7 +45,7 @@ def test_process(self):
]
source_image_requests = [SourceImageRequest(**image.dict()) for image in source_images]
request = PipelineRequest(pipeline="random", source_images=source_image_requests)
- response = self.client.post("/pipeline/process", json=request.dict())
+ response = self.client.post("/process", json=request.dict())
self.assertEqual(response.status_code, 200)
data = response.json()
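
A request against the renamed endpoint might look like the sketch below; the image URL and the asserted keys are assumptions based on PipelineRequest and PipelineResultsResponse above, not part of this diff.

# Usage sketch; the URL and asserted response keys are assumptions.
from fastapi.testclient import TestClient

from .api import app
from .schemas import PipelineRequest, SourceImageRequest

client = TestClient(app)

request = PipelineRequest(
    pipeline="random",
    source_images=[SourceImageRequest(id="1", url="https://example.com/image1.jpg")],
)
response = client.post("/process", json=request.dict())
assert response.status_code == 200

results = response.json()
# Fields defined on PipelineResultsResponse in the schema changes above
assert "algorithms" in results
assert "detections" in results
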
diff --git a/processing_services/example/api/utils.py b/processing_services/example/api/utils.py
index c70120215..119723ae5 100644
--- a/processing_services/example/api/utils.py
+++ b/processing_services/example/api/utils.py
@@ -5,15 +5,23 @@
import pathlib
import re
import tempfile
-import urllib.request
from urllib.parse import urlparse
import PIL.Image
+import PIL.ImageFile
+import requests
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
+PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# This is polite and required by some hosts
+# see: https://foundation.wikimedia.org/wiki/Policy:User-Agent_policy
+USER_AGENT = "AntennaInsectDataPlatform/1.0 (https://insectai.org)"
+
+
def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:
"""
Fetch a file from a URL or local path. If the path is a URL, download the file.
@@ -44,17 +52,20 @@ def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:
else:
logger.info(f"Downloading {path_or_url} to {local_filepath}")
- try:
- resulting_filepath, _headers = urllib.request.urlretrieve(url=path_or_url, filename=local_filepath)
- except Exception as e:
- raise Exception(f"Could not retrieve {path_or_url}: {e}")
+ headers = {"User-Agent": USER_AGENT}
+ response = requests.get(path_or_url, stream=True, headers=headers)
+ response.raise_for_status() # Raise an exception for HTTP errors
+
+ with open(local_filepath, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
- resulting_filepath = pathlib.Path(resulting_filepath)
+ resulting_filepath = pathlib.Path(local_filepath).resolve()
logger.info(f"Downloaded to {resulting_filepath}")
return resulting_filepath
-def open_image(fp: str | bytes | pathlib.Path, raise_exception: bool = True) -> PIL.Image.Image | None:
+def open_image(fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True) -> PIL.Image.Image | None:
"""
Wrapper from PIL.Image.open that handles errors and converts to RGB.
"""
diff --git a/ui/src/data-services/models/algorithm.ts b/ui/src/data-services/models/algorithm.ts
index 1ab7c0f55..d27f2e02c 100644
--- a/ui/src/data-services/models/algorithm.ts
+++ b/ui/src/data-services/models/algorithm.ts
@@ -27,8 +27,12 @@ export class Algorithm {
return this._algorithm.name
}
- get url(): string {
- return this._algorithm.url
+ get key(): string {
+ return this._algorithm.key
+ }
+
+ get uri(): string {
+ return this._algorithm.uri
}
get updatedAt(): string | undefined {
diff --git a/ui/src/data-services/models/job-details.ts b/ui/src/data-services/models/job-details.ts
index 33180d082..e993cf91e 100644
--- a/ui/src/data-services/models/job-details.ts
+++ b/ui/src/data-services/models/job-details.ts
@@ -16,11 +16,11 @@ export class JobDetails extends Job {
}
get errors(): string[] {
- return this._job.progress.errors ?? []
+ return this._job.logs.stderr ?? []
}
get logs(): string[] {
- return this._job.progress.logs ?? []
+ return this._job.logs.stdout ?? []
}
get stages(): {
diff --git a/ui/src/data-services/models/occurrence-details.ts b/ui/src/data-services/models/occurrence-details.ts
index 7376f1971..516b6ab5b 100644
--- a/ui/src/data-services/models/occurrence-details.ts
+++ b/ui/src/data-services/models/occurrence-details.ts
@@ -14,6 +14,8 @@ export interface Identification {
taxon: Taxon
comment?: string
algorithm?: Algorithm
+ score?: number
+ terminal?: boolean
userPermissions: UserPermission[]
createdAt: string
}
@@ -30,6 +32,7 @@ export interface HumanIdentification extends Identification {
export interface MachinePrediction extends Identification {
algorithm: Algorithm
score: number
+ terminal: boolean
}
export class OccurrenceDetails extends Occurrence {
@@ -83,6 +86,7 @@ export class OccurrenceDetails extends Occurrence {
overridden,
taxon,
score: p.score,
+ terminal: p.terminal,
algorithm: p.algorithm,
userPermissions: p.user_permissions,
createdAt: p.created_at,
diff --git a/ui/src/data-services/models/pipeline.ts b/ui/src/data-services/models/pipeline.ts
index aa5f62701..537948b82 100644
--- a/ui/src/data-services/models/pipeline.ts
+++ b/ui/src/data-services/models/pipeline.ts
@@ -96,6 +96,11 @@ export class Pipeline {
get processingServicesOnlineLastChecked(): string | undefined {
const processingServices = this._pipeline.processing_services
+
+ if (!processingServices.length) {
+ return undefined
+ }
+
const last_checked_times = []
for (const processingService of processingServices) {
last_checked_times.push(
diff --git a/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx b/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx
index 19f8b9111..fde4ff9e7 100644
--- a/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx
+++ b/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx
@@ -49,6 +49,22 @@ export const IdentificationSummary = ({
{identification.algorithm && (