diff --git a/ami/jobs/migrations/0013_add_job_logs.py b/ami/jobs/migrations/0013_add_job_logs.py new file mode 100644 index 000000000..5a8ab2e9d --- /dev/null +++ b/ami/jobs/migrations/0013_add_job_logs.py @@ -0,0 +1,53 @@ +# Generated by Django 4.2.10 on 2024-12-17 20:01 + +import django_pydantic_field.fields +from django.db import migrations + +import ami.jobs.models + + +def migrate_logs_forward(apps, schema_editor): + """Move logs from Job.progress to Job.logs""" + Job = apps.get_model("jobs", "Job") + jobs_to_update = [] + for job in Job.objects.filter(progress__isnull=False): + if job.progress.logs or job.progress.errors: + # Move logs from progress to the new logs field + job.logs.stdout = job.progress.logs + job.logs.stderr = job.progress.errors + jobs_to_update.append(job) + # Update all jobs in a single query + Job.objects.bulk_update(jobs_to_update, ["logs"]) + + +def migrate_logs_backward(apps, schema_editor): + """Move logs from Job.logs back to Job.progress""" + Job = apps.get_model("jobs", "Job") + jobs_to_update = [] + for job in Job.objects.filter(logs__isnull=False): + # Move logs back to progress + job.progress.logs = job.logs.stdout + job.progress.errors = job.logs.stderr + jobs_to_update.append(job) + # Update all jobs in a single query + Job.objects.bulk_update(jobs_to_update, ["progress"]) + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0012_alter_job_limit"), + ] + + operations = [ + migrations.AddField( + model_name="job", + name="logs", + field=django_pydantic_field.fields.PydanticSchemaField( + config=None, default={"stderr": [], "stdout": []}, schema=ami.jobs.models.JobLogs + ), + ), + migrations.RunPython( + migrate_logs_forward, + migrate_logs_backward, + ), + ] diff --git a/ami/jobs/migrations/0014_alter_job_progress.py b/ami/jobs/migrations/0014_alter_job_progress.py new file mode 100644 index 000000000..09c9b756a --- /dev/null +++ b/ami/jobs/migrations/0014_alter_job_progress.py @@ -0,0 +1,23 @@ +# 
Generated by Django 4.2.10 on 2024-12-17 20:13 + +import ami.jobs.models +from django.db import migrations +import django_pydantic_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0013_add_job_logs"), + ] + + operations = [ + migrations.AlterField( + model_name="job", + name="progress", + field=django_pydantic_field.fields.PydanticSchemaField( + config=None, + default={"errors": [], "logs": [], "stages": [], "summary": {"progress": 0.0, "status": "CREATED"}}, + schema=ami.jobs.models.JobProgress, + ), + ), + ] diff --git a/ami/jobs/migrations/0015_merge_20250117_2100.py b/ami/jobs/migrations/0015_merge_20250117_2100.py new file mode 100644 index 000000000..9b04fd33a --- /dev/null +++ b/ami/jobs/migrations/0015_merge_20250117_2100.py @@ -0,0 +1,12 @@ +# Generated by Django 4.2.10 on 2025-01-17 21:00 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0013_merge_0011_alter_job_limit_0012_alter_job_limit"), + ("jobs", "0014_alter_job_progress"), + ] + + operations = [] diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 3bd192106..e839c7737 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -85,14 +85,10 @@ class JobProgressSummary(pydantic.BaseModel): status: JobState = JobState.CREATED progress: float = 0 - status_label: str = "" - @pydantic.validator("status_label", always=True) - def serialize_status_label(cls, value, values) -> str: - if "status" not in values or "progress" not in values: - # Does this happen if status label gets initialized before status and progress? 
- return "" - return get_status_label(values["status"], values["progress"]) + @property + def status_label(self) -> str: + return get_status_label(self.status, self.progress) class Config: use_enum_values = True @@ -112,15 +108,15 @@ class JobProgress(pydantic.BaseModel): summary: JobProgressSummary stages: list[JobProgressStageDetail] - errors: list[str] = [] - logs: list[str] = [] + errors: list[str] = [] # Deprecated, @TODO remove in favor of logs.stderr + logs: list[str] = [] # Deprecated, @TODO remove in favor of logs.stdout - def get_stage_key(self, name: str) -> str: + def make_key(self, name: str) -> str: """Generate a key for a stage or param based on its name""" return python_slugify(name) def add_stage(self, name: str, key: str | None = None) -> JobProgressStageDetail: - key = key or self.get_stage_key(name) + key = key or self.make_key(name) try: return self.get_stage(key) except ValueError: @@ -146,26 +142,28 @@ def get_stage_param(self, stage_key: str, param_key: str) -> ConfigurableStagePa return param raise ValueError(f"Job stage parameter with key '{param_key}' not found in stage '{stage_key}'") - def add_stage_param(self, stage_key: str, name: str, value: typing.Any = None) -> ConfigurableStageParam: + def add_stage_param(self, stage_key: str, param_name: str, value: typing.Any = None) -> ConfigurableStageParam: stage = self.get_stage(stage_key) try: - return self.get_stage_param(stage_key, self.get_stage_key(name)) + return self.get_stage_param(stage_key, self.make_key(param_name)) except ValueError: param = ConfigurableStageParam( - name=name, - key=self.get_stage_key(name), + name=param_name, + key=self.make_key(param_name), value=value, ) stage.params.append(param) return param - def add_or_update_stage_param(self, stage_key: str, name: str, value: typing.Any = None) -> ConfigurableStageParam: + def add_or_update_stage_param( + self, stage_key: str, param_name: str, value: typing.Any = None + ) -> ConfigurableStageParam: try: - param = 
self.get_stage_param(stage_key, self.get_stage_key(name)) + param = self.get_stage_param(stage_key, self.make_key(param_name)) param.value = value return param except ValueError: - return self.add_stage_param(stage_key, name, value) + return self.add_stage_param(stage_key, param_name, value) def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgressStageDetail | None: """ " @@ -176,7 +174,7 @@ def update_stage(self, stage_key_or_name: str, **stage_parameters) -> JobProgres This is the preferred method to update a stage's parameters. """ - stage_key = self.get_stage_key(stage_key_or_name) # Allow both title or key to be used for lookup + stage_key = self.make_key(stage_key_or_name) # Allow both title or key to be used for lookup stage = self.get_stage(stage_key) if stage.key == stage_key: @@ -243,6 +241,11 @@ def default_ml_job_progress() -> JobProgress: ) +class JobLogs(pydantic.BaseModel): + stdout: list[str] = pydantic.Field(default_factory=list, alias="stdout", title="All messages") + stderr: list[str] = pydantic.Field(default_factory=list, alias="stderr", title="Error messages") + + class JobLogHandler(logging.Handler): """ Class for handling logs from a job and writing them to the job instance. 
@@ -254,24 +257,26 @@ def __init__(self, job: "Job", *args, **kwargs): self.job = job super().__init__(*args, **kwargs) - def emit(self, record): + def emit(self, record: logging.LogRecord): # Log to the current app logger logger.log(record.levelno, self.format(record)) # Write to the logs field on the job instance timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") msg = f"[{timestamp}] {record.levelname} {self.format(record)}" - if msg not in self.job.progress.logs: - self.job.progress.logs.insert(0, msg) + if msg not in self.job.logs.stdout: + self.job.logs.stdout.insert(0, msg) # Write a simpler copy of any errors to the errors field if record.levelno >= logging.ERROR: - if record.message not in self.job.progress.errors: - self.job.progress.errors.insert(0, record.message) + if record.message not in self.job.logs.stderr: + self.job.logs.stderr.insert(0, record.message) - if len(self.job.progress.logs) > self.max_log_length: - self.job.progress.logs = self.job.progress.logs[: self.max_log_length] - self.job.save() + if len(self.job.logs.stdout) > self.max_log_length: + self.job.logs.stdout = self.job.logs.stdout[: self.max_log_length] + + # @TODO consider saving logs to the database periodically rather than on every log + self.job.save(update_fields=["logs"], update_progress=False) @dataclass @@ -312,6 +317,10 @@ def run(cls, job: "Job"): job.finished_at = None job.save() + # Keep track of sub-tasks for saving results, pair with batch number + save_tasks: list[tuple[int, AsyncResult]] = [] + save_tasks_completed: list[tuple[int, AsyncResult]] = [] + if job.delay: update_interval_seconds = 2 last_update = time.time() @@ -337,91 +346,155 @@ def run(cls, job: "Job"): ) job.save() - if job.pipeline: - job.progress.update_stage( - "collect", - status=JobState.STARTED, - progress=0, - ) + if not job.pipeline: + raise ValueError("No pipeline specified to process images in ML job") - images = list( - # @TODO return generator plus image count - # @TODO pass 
to celery group chain? - job.pipeline.collect_images( - collection=job.source_image_collection, - deployment=job.deployment, - source_images=[job.source_image_single] if job.source_image_single else None, - job_id=job.pk, - skip_processed=True, - # shuffle=job.shuffle, - ) - ) - source_image_count = len(images) - job.progress.update_stage("collect", total_images=source_image_count) - - if job.shuffle and source_image_count > 1: - job.logger.info("Shuffling images") - random.shuffle(images) - - if job.limit and source_image_count > job.limit: - job.logger.warn(f"Limiting number of images to {job.limit} (out of {source_image_count})") - images = images[: job.limit] - image_count = len(images) - job.progress.add_stage_param("collect", "Limit", image_count) - else: - image_count = source_image_count + job.progress.update_stage( + "collect", + status=JobState.STARTED, + progress=0, + ) - job.progress.update_stage( - "collect", - status=JobState.SUCCESS, - progress=1, + images = list( + # @TODO return generator plus image count + # @TODO pass to celery group chain? 
+ job.pipeline.collect_images( + collection=job.source_image_collection, + deployment=job.deployment, + source_images=[job.source_image_single] if job.source_image_single else None, + job_id=job.pk, + skip_processed=True, + # shuffle=job.shuffle, ) + ) + source_image_count = len(images) + job.progress.update_stage("collect", total_images=source_image_count) + + if job.shuffle and source_image_count > 1: + job.logger.info("Shuffling images") + random.shuffle(images) + + if job.limit and source_image_count > job.limit: + job.logger.warn(f"Limiting number of images to {job.limit} (out of {source_image_count})") + images = images[: job.limit] + image_count = len(images) + job.progress.add_stage_param("collect", "Limit", image_count) + else: + image_count = source_image_count - total_detections = 0 - total_classifications = 0 + job.progress.update_stage( + "collect", + status=JobState.SUCCESS, + progress=1, + ) - CHUNK_SIZE = 2 # Keep it low to see more progress updates - chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa + # End image collection stage + job.save() - for i, chunk in enumerate(chunks): - try: - results = job.pipeline.process_images( - images=chunk, - job_id=job.pk, - ) - except Exception as e: - # Log error about image batch and continue - job.logger.error(f"Failed to process image batch {i} of {len(chunks)}: {e}") - continue + total_captures = 0 + total_detections = 0 + total_classifications = 0 + CHUNK_SIZE = 4 # Keep it low to see more progress updates + chunks = [images[i : i + CHUNK_SIZE] for i in range(0, image_count, CHUNK_SIZE)] # noqa + request_failed_images = [] + + for i, chunk in enumerate(chunks): + request_sent = time.time() + job.logger.info(f"Processing image batch {i+1} of {len(chunks)}") + try: + results = job.pipeline.process_images( + images=chunk, + job_id=job.pk, + ) + job.logger.info(f"Processed image batch {i+1} in {time.time() - request_sent:.2f}s") + except Exception as e: + # Log error about 
image batch and continue + job.logger.error(f"Failed to process image batch {i+1}: {e}") + request_failed_images.extend([img.pk for img in chunk]) + else: + total_captures += len(results.source_images) total_detections += len(results.detections) total_classifications += len([c for d in results.detections for c in d.classifications]) - job.progress.update_stage( - "process", - status=JobState.STARTED, - progress=(i + 1) / len(chunks), - processed=(i + 1) * CHUNK_SIZE, - remaining=image_count - (i + 1) * CHUNK_SIZE, - detections=total_detections, - classifications=total_classifications, - ) - job.save() if results.source_images or results.detections: - save_results_task = job.pipeline.save_results_async(results=results, job_id=job.pk) - job.logger.info(f"Saving results in sub-task {save_results_task.id}") + # @TODO add callback to report errors while saving results marking the job as failed + save_results_task: AsyncResult = job.pipeline.save_results_async(results=results, job_id=job.pk) + save_tasks.append((i + 1, save_results_task)) + job.logger.info(f"Saving results for batch {i+1} in sub-task {save_results_task.id}") job.progress.update_stage( "process", - status=JobState.SUCCESS, + status=JobState.STARTED, + progress=(i + 1) / len(chunks), + processed=min((i + 1) * CHUNK_SIZE, image_count), + failed=len(request_failed_images), + remaining=max(image_count - ((i + 1) * CHUNK_SIZE), 0), ) + + # count the completed, successful, and failed save_tasks: + save_tasks_completed = [t for t in save_tasks if t[1].ready()] + failed_save_tasks = [t for t in save_tasks_completed if not t[1].successful()] + + for failed_batch_num, failed_task in failed_save_tasks: + # First log all errors and update the job status. Then raise exception if any failed. 
+ job.logger.error(f"Failed to save results from batch {failed_batch_num} (sub-task {failed_task.id})") + job.progress.update_stage( "results", - status=JobState.SUCCESS, + status=JobState.FAILURE if failed_save_tasks else JobState.STARTED, + progress=len(save_tasks_completed) / len(chunks), + captures=total_captures, + detections=total_detections, + classifications=total_classifications, ) + job.save() + + # Stop processing if any save tasks have failed + # Otherwise, calculate the percent of images that have failed to save + throw_on_save_error = True + for failed_batch_num, failed_task in failed_save_tasks: + if throw_on_save_error: + failed_task.maybe_throw() + + percent_successful = 1 - len(request_failed_images) / image_count if image_count else 0 + job.logger.info(f"Processed {percent_successful:.0%} of images successfully.") + + # Check all Celery sub-tasks if they have completed saving results + save_tasks_remaining = set(save_tasks) - set(save_tasks_completed) + job.logger.info( + f"Checking the status of {len(save_tasks_remaining)} remaining sub-tasks that are still saving results." + ) + for batch_num, sub_task in save_tasks: + if not sub_task.ready(): + job.logger.info(f"Waiting for batch {batch_num} to finish saving results (sub-task {sub_task.id})") + # @TODO this is not recommended! Use a group or chain. But we need to refactor. + # https://docs.celeryq.dev/en/latest/userguide/tasks.html#avoid-launching-synchronous-subtasks + sub_task.wait(disable_sync_subtasks=False, timeout=60) + if not sub_task.successful(): + error: Exception = sub_task.result + job.logger.error(f"Failed to save results from batch {batch_num}! 
(sub-task {sub_task.id}): {error}") + sub_task.maybe_throw() + + job.logger.info(f"All tasks completed for job {job.pk}") + + FAILURE_THRESHOLD = 0.5 + if image_count and (percent_successful < FAILURE_THRESHOLD): + job.progress.update_stage("process", status=JobState.FAILURE) + job.save() + raise Exception(f"Failed to process more than {int(FAILURE_THRESHOLD * 100)}% of images") - job.update_status(JobState.SUCCESS) - job.update_progress() + job.progress.update_stage( + "process", + status=JobState.SUCCESS, + progress=1, + ) + job.progress.update_stage( + "results", + status=JobState.SUCCESS, + progress=1, + ) + job.update_status(JobState.SUCCESS, save=False) job.finished_at = datetime.datetime.now() job.save() @@ -445,7 +518,9 @@ def run(cls, job: "Job"): job.finished_at = None job.save() - if job.deployment: + if not job.deployment: + raise ValueError("No deployment provided for data storage sync job") + else: job.logger.info(f"Syncing captures for deployment {job.deployment}") job.progress.update_stage( cls.key, @@ -465,10 +540,7 @@ def run(cls, job: "Job"): ) job.update_status(JobState.SUCCESS) job.save() - else: - job.update_status(JobState.FAILURE) - job.update_progress() job.finished_at = datetime.datetime.now() job.save() @@ -492,11 +564,7 @@ def run(cls, job: "Job"): job.save() if not job.source_image_collection: - job.logger.error("No source image collection provided") - job.update_status(JobState.FAILURE) - job.finished_at = datetime.datetime.now() - job.save() - return + raise ValueError("No source image collection provided") job.logger.info(f"Populating source image collection {job.source_image_collection}") job.update_status(JobState.STARTED) @@ -508,7 +576,7 @@ def run(cls, job: "Job"): progress=0.10, captures_added=0, ) - job.update_progress(save=True) + job.save() job.source_image_collection.populate_sample(job=job) job.logger.info(f"Finished populating source image collection {job.source_image_collection}") @@ -525,7 +593,6 @@ def run(cls, job: 
"Job"): ) job.finished_at = datetime.datetime.now() job.update_status(JobState.SUCCESS, save=False) - job.update_progress(save=False) job.save() @@ -535,9 +602,7 @@ class UnknownJobType(JobType): @classmethod def run(cls, job: "Job"): - job.logger.error(f"Unknown job type '{job.job_type()}'") - job.update_status(JobState.UNKNOWN) - job.save() + raise ValueError(f"Unknown job type '{job.job_type()}'") VALID_JOB_TYPES = [MLJob, SourceImageCollectionPopulateJob, DataStorageSyncJob, UnknownJobType] @@ -580,6 +645,7 @@ class Job(BaseModel): # @TODO can we use an Enum or Pydantic model for status? status = models.CharField(max_length=255, default=JobState.CREATED.name, choices=JobState.choices()) progress: JobProgress = SchemaField(JobProgress, default=default_job_progress()) + logs: JobLogs = SchemaField(JobLogs, default=JobLogs()) result = models.JSONField(null=True, blank=True) task_id = models.CharField(max_length=255, null=True, blank=True) delay = models.IntegerField("Delay in seconds", default=0, help_text="Delay before running the job") @@ -676,11 +742,12 @@ def setup(self, save=True): pipeline_stage = self.progress.add_stage("Process") self.progress.add_stage_param(pipeline_stage.key, "Processed", "") self.progress.add_stage_param(pipeline_stage.key, "Remaining", "") - self.progress.add_stage_param(pipeline_stage.key, "Detections", "") - self.progress.add_stage_param(pipeline_stage.key, "Classifications", "") + self.progress.add_stage_param(pipeline_stage.key, "Failed", "") saving_stage = self.progress.add_stage("Results") - self.progress.add_stage_param(saving_stage.key, "Objects created", "") + self.progress.add_stage_param(saving_stage.key, "Captures", "") + self.progress.add_stage_param(saving_stage.key, "Detections", "") + self.progress.add_stage_param(saving_stage.key, "Classifications", "") if save: self.save() @@ -700,6 +767,7 @@ def retry(self, async_task=True): Retry the job. 
""" self.logger.info(f"Re-running job {self}") + self.finished_at = None self.progress.reset() self.status = JobState.RETRY self.save() @@ -733,11 +801,11 @@ def update_status(self, status=None, save=True): status = task.status if not status: - self.logger.warn(f"Could not determine status of job {self.pk}") + self.logger.warning(f"Could not determine status of job {self.pk}") return if status != self.status: - self.logger.info(f"Changing status of job {self.pk} to {status}") + self.logger.info(f"Changing status of job {self.pk} from {self.status} to {status}") self.status = status self.progress.summary.status = status @@ -757,10 +825,10 @@ def update_progress(self, save=True): if stage.progress > 0 and stage.status == JobState.CREATED: # Update any stages that have started but are still in the CREATED state stage.status = JobState.STARTED - elif stage.status == JobState.SUCCESS and stage.progress < 1: + elif stage.status in JobState.final_states() and stage.progress < 1: # Update any stages that are complete but have a progress less than 1 stage.progress = 1 - elif stage.progress == 1 and stage.status == JobState.STARTED: + elif stage.progress == 1 and stage.status not in JobState.final_states(): # Update any stages that are complete but are still in the STARTED state stage.status = JobState.SUCCESS total_progress = sum([stage.progress for stage in self.progress.stages]) / len(self.progress.stages) @@ -768,18 +836,18 @@ def update_progress(self, save=True): self.progress.summary.progress = total_progress if save: - self.save() + self.save(update_progress=False) def duration(self) -> datetime.timedelta | None: if self.started_at and self.finished_at: return self.finished_at - self.started_at return None - def save(self, *args, **kwargs): + def save(self, update_progress=True, *args, **kwargs): """ Create the job stages if they don't exist. 
""" - if self.pk and self.progress.stages: + if self.pk and self.progress.stages and update_progress: self.update_progress(save=False) else: self.setup(save=False) diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index ab07ef320..f1fef491d 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -11,7 +11,7 @@ from ami.ml.models import Pipeline from ami.ml.serializers import PipelineNestedSerializer -from .models import Job, JobProgress +from .models import Job, JobLogs, JobProgress, MLJob class JobProjectNestedSerializer(DefaultSerializer): @@ -37,7 +37,11 @@ class JobListSerializer(DefaultSerializer): source_image_collection = SourceImageCollectionNestedSerializer(read_only=True) source_image_single = SourceImageNestedSerializer(read_only=True) progress = SchemaField(schema=JobProgress, read_only=True) + logs = SchemaField(schema=JobLogs, read_only=True) job_type = JobTypeSerializer(read_only=True) + # All jobs created from the Jobs UI are ML jobs (datasync, etc. are created for the user) + # @TODO Remove this when the UI is updated pass a job type. This should be a required field. 
+ job_type_key = serializers.SlugField(write_only=True, default=MLJob.key) project_id = serializers.PrimaryKeyRelatedField( label="Project", @@ -109,7 +113,9 @@ class Meta: "finished_at", "duration", "progress", + "logs", "job_type", + "job_type_key", # "duration", # "duration_label", # "progress_label", diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 4c9d41447..59d3cfd59 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -100,8 +100,10 @@ def test_create_job_unauthenticated(self): jobs_create_url = reverse_with_params("api:job-list") job_data = { "project_id": self.project.pk, + "source_image_collection_id": self.source_image_collection.pk, "name": "Test job unauthenticated", "delay": 0, + "job_type_key": SourceImageCollectionPopulateJob.key, } self.client.force_authenticate(user=None) resp = self.client.post(jobs_create_url, job_data) @@ -113,9 +115,10 @@ def _create_job(self, name: str, start_now: bool = True): job_data = { "project_id": self.job.project.pk, "name": name, - "collection_id": self.source_image_collection.pk, + "source_image_collection_id": self.source_image_collection.pk, "delay": 0, "start_now": start_now, + "job_type_key": SourceImageCollectionPopulateJob.key, } resp = self.client.post(jobs_create_url, job_data) self.client.force_authenticate(user=None) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index e783ff9a5..49a87afc0 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -11,7 +11,7 @@ from ami.utils.fields import url_boolean_param from ami.utils.requests import get_active_project, project_id_doc_param -from .models import Job, JobState, MLJob +from .models import Job, JobState from .serializers import JobListSerializer, JobSerializer logger = logging.getLogger(__name__) @@ -117,11 +117,6 @@ def perform_create(self, serializer): If the ``start_now`` parameter is passed, enqueue the job immediately. """ - # All jobs created from the Jobs UI are ML jobs. 
- # @TODO Remove this when the UI is updated pass a job type - if not serializer.validated_data.get("job_type_key"): - serializer.validated_data["job_type_key"] = MLJob.key - job: Job = serializer.save() # type: ignore if url_boolean_param(self.request, "start_now", default=False): # job.run() diff --git a/ami/main/admin.py b/ami/main/admin.py index 19184a6fe..9429bf45c 100644 --- a/ami/main/admin.py +++ b/ami/main/admin.py @@ -13,6 +13,7 @@ from .models import ( BlogPost, + Classification, Deployment, Device, Event, @@ -236,11 +237,80 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: class OccurrenceAdmin(admin.ModelAdmin[Occurrence]): """Admin panel example for ``Occurrence`` model.""" - list_display = ("id", "determination", "project", "deployment", "event") + list_display = ( + "id", + "determination", + "project", + "deployment", + "event", + "detections_count", + "created_at", + "updated_at", + ) + + def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: + qs = super().get_queryset(request) + qs = qs.select_related("determination", "project", "deployment", "event") + # Add detections count to queryset + qs = qs.annotate(detections_count=models.Count("detections")) + # Add min, max and avg detection__classifications counts to queryset + # qs = qs.annotate( + # min_detection_classifications=models.Min("detections__classifications"), + # max_detection_classifications=models.Max("detections__classifications"), + # avg_detection_classifications=models.Avg("detections__classifications"), + # ) + return qs + + @admin.display( + description="Detections", + ordering="detections_count", + ) + def detections_count(self, obj) -> int: + return obj.detections_count + + ordering = ("-created_at",) + + +@admin.register(Classification) +class ClassificationAdmin(admin.ModelAdmin[Classification]): + list_display = ( + "__str__", + "taxon", + "algorithm", + "num_scores", + "num_logits", + "detection_date", + "timestamp", + "terminal", + "created_at", + ) 
+ + list_filter = ( + "algorithm", + "terminal", + "created_at", + "detection__source_image__project", + "taxon__rank", + ) def get_queryset(self, request: HttpRequest) -> QuerySet[Any]: qs = super().get_queryset(request) - return qs.select_related("determination", "project", "deployment", "event") + return qs.select_related( + "taxon", "detection", "detection__source_image", "detection__source_image__project" + ).annotate( + detection_date=models.F("detection__timestamp"), + ) + + @admin.display() + def detection_date(self, obj: Classification) -> str: + # This property comes from the annotation in get_queryset, not the model + return obj.detection_date # type: ignore + + def num_scores(self, obj: Classification) -> int: + return len(obj.scores) if obj.scores else 0 + + def num_logits(self, obj: Classification) -> int: + return len(obj.logits) if obj.logits else 0 class TaxonParentFilter(admin.SimpleListFilter): diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index b37221aa2..9ba46c27c 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -451,6 +451,7 @@ class Meta: "name", "rank", "details", + "gbif_taxon_key", ] @@ -720,6 +721,7 @@ class Meta: "detections_count", "events_count", "occurrences", + "gbif_taxon_key", ] @@ -738,10 +740,50 @@ class Meta: ] +class ClassificationPredictionItemSerializer(serializers.Serializer): + taxon = TaxonNestedSerializer(read_only=True) + score = serializers.FloatField(read_only=True) + logit = serializers.FloatField(read_only=True) + + class ClassificationSerializer(DefaultSerializer): taxon = TaxonNestedSerializer(read_only=True) algorithm = AlgorithmSerializer(read_only=True) + top_n = ClassificationPredictionItemSerializer(many=True, read_only=True) + + class Meta: + model = Classification + fields = [ + "id", + "details", + "taxon", + "score", + "algorithm", + "scores", + "logits", + "top_n", + "created_at", + "updated_at", + ] + + +class 
ClassificationWithTaxaSerializer(ClassificationSerializer): + """ + Return all possible taxa objects in the category map with the classification. + + This is slow for large category maps. + It's recommended to retrieve and cache the category map with taxa ahead of time. + """ + taxa = TaxonNestedSerializer(many=True, read_only=True) + + class Meta(ClassificationSerializer.Meta): + fields = ClassificationSerializer.Meta.fields + [ + "taxa", + ] + + +class ClassificationListSerializer(DefaultSerializer): class Meta: model = Classification fields = [ @@ -751,11 +793,22 @@ class Meta: "score", "algorithm", "created_at", + "updated_at", ] -class OccurrenceClassificationSerializer(ClassificationSerializer): - pass +class ClassificationNestedSerializer(ClassificationSerializer): + class Meta: + model = Classification + fields = [ + "id", + "details", + "taxon", + "score", + "terminal", + "algorithm", + "created_at", + ] class CaptureDetectionsSerializer(DefaultSerializer): @@ -767,6 +820,7 @@ class Meta: # queryset = Detection.objects.prefetch_related("classifications") fields = [ "id", + "details", "url", "width", "height", @@ -800,7 +854,7 @@ class Meta: class DetectionNestedSerializer(DefaultSerializer): - classifications = ClassificationSerializer(many=True, read_only=True) + classifications = ClassificationNestedSerializer(many=True, read_only=True) capture = DetectionCaptureNestedSerializer(read_only=True, source="source_image") class Meta: @@ -808,6 +862,7 @@ class Meta: # queryset = Detection.objects.prefetch_related("classifications") fields = [ "id", + "details", "timestamp", "url", "capture", @@ -842,6 +897,7 @@ class DetectionSerializer(DefaultSerializer): detection_algorithm_id = serializers.PrimaryKeyRelatedField( queryset=Algorithm.objects.all(), source="detection_algorithm", write_only=True ) + classifications = ClassificationNestedSerializer(many=True, read_only=True) class Meta: model = Detection @@ -849,6 +905,7 @@ class Meta: "source_image", 
"detection_algorithm", "detection_algorithm_id", + "classifications", ] @@ -1077,6 +1134,7 @@ class Meta: "determination_details", "identifications", "created_at", + "updated_at", ] def get_determination_details(self, obj: Occurrence): @@ -1099,7 +1157,7 @@ def get_determination_details(self, obj: Occurrence): if identification or not obj.best_prediction: prediction = None else: - prediction = OccurrenceClassificationSerializer(obj.best_prediction, context=context).data + prediction = ClassificationNestedSerializer(obj.best_prediction, context=context).data return dict( taxon=taxon, @@ -1112,7 +1170,8 @@ def get_determination_details(self, obj: Occurrence): class OccurrenceSerializer(OccurrenceListSerializer): determination = CaptureTaxonSerializer(read_only=True) detections = DetectionNestedSerializer(many=True, read_only=True) - predictions = OccurrenceClassificationSerializer(many=True, read_only=True) + identifications = OccurrenceIdentificationSerializer(many=True, read_only=True) + predictions = ClassificationNestedSerializer(many=True, read_only=True) deployment = DeploymentNestedSerializer(read_only=True) event = EventNestedSerializer(read_only=True) # first_appearance = TaxonSourceImageNestedSerializer(read_only=True) diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 724eebee9..74fd8860e 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -49,7 +49,9 @@ update_detection_counts, ) from .serializers import ( + ClassificationListSerializer, ClassificationSerializer, + ClassificationWithTaxaSerializer, DeploymentListSerializer, DeploymentSerializer, DetectionListSerializer, @@ -429,7 +431,13 @@ class SourceImageViewSet(DefaultViewSet): queryset = SourceImage.objects.all() serializer_class = SourceImageSerializer - filterset_fields = ["event", "deployment", "deployment__project", "collections"] + filterset_fields = [ + "event", + "deployment", + "deployment__project", + "collections", + "project", + ] ordering_fields = [ 
"created_at", "updated_at", @@ -732,9 +740,9 @@ class DetectionViewSet(DefaultViewSet): API endpoint that allows detections to be viewed or edited. """ - queryset = Detection.objects.all() + queryset = Detection.objects.all().select_related("source_image", "detection_algorithm") serializer_class = DetectionSerializer - filterset_fields = ["source_image", "detection_algorithm"] + filterset_fields = ["source_image", "detection_algorithm", "source_image__project"] ordering_fields = ["created_at", "updated_at", "detection_score", "timestamp"] def get_serializer_class(self): @@ -1262,15 +1270,34 @@ class ClassificationViewSet(DefaultViewSet): API endpoint for viewing and adding classification results from a model. """ - queryset = Classification.objects.all() + queryset = Classification.objects.all() # .select_related("taxon", "algorithm", "detection") serializer_class = ClassificationSerializer - filterset_fields = ["detection", "detection__occurrence", "taxon", "algorithm"] + filterset_fields = [ + "detection", + "detection__occurrence", + "taxon", + "algorithm", + "detection__source_image", + "detection__source_image__project", + ] ordering_fields = [ "created_at", "updated_at", "score", ] + def get_serializer_class(self): + """ + Return a different serializer for list and detail views. + If "with_taxa" is in the query params, return a different serializer. 
+ """ + if self.action == "list": + return ClassificationListSerializer + elif "with_taxa" in self.request.query_params: + return ClassificationWithTaxaSerializer + else: + return ClassificationSerializer + class SummaryView(GenericAPIView): permission_classes = [IsActiveStaffOrReadOnly] diff --git a/ami/main/charts.py b/ami/main/charts.py index 352e2e2ac..4d9809817 100644 --- a/ami/main/charts.py +++ b/ami/main/charts.py @@ -199,7 +199,7 @@ def detections_per_hour(project_pk: int): def occurrences_accumulated(project_pk: int): - # Line chart of the accumulated number of occurrnces over time throughout the season + # Line chart of the accumulated number of occurrences over time throughout the season Occurrence = apps.get_model("main", "Occurrence") occurrences_per_day = ( @@ -212,7 +212,8 @@ def occurrences_accumulated(project_pk: int): .order_by("event__start") ) - if occurrences_per_day.count(): + occurrences_exist = Occurrence.objects.filter(project=project_pk).exists() + if occurrences_exist: days, counts = list(zip(*occurrences_per_day)) # Accumulate the counts counts = list(itertools.accumulate(counts)) diff --git a/ami/main/migrations/0039_remove_classification_raw_output_and_more.py b/ami/main/migrations/0039_remove_classification_raw_output_and_more.py new file mode 100644 index 000000000..50d8cf7d0 --- /dev/null +++ b/ami/main/migrations/0039_remove_classification_raw_output_and_more.py @@ -0,0 +1,54 @@ +# Generated by Django 4.2.10 on 2024-11-26 18:52 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0007_algorithmcategorymap_algorithm_category_map"), + ("main", "0038_alter_detection_path_alter_sourceimage_event_and_more"), + ] + + operations = [ + migrations.RemoveField( + model_name="classification", + name="raw_output", + ), + migrations.RemoveField( + model_name="classification", + name="softmax_output", + ), + 
migrations.AddField( + model_name="classification", + name="category_map", + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.PROTECT, to="ml.algorithmcategorymap" + ), + ), + migrations.AddField( + model_name="classification", + name="logits", + field=django.contrib.postgres.fields.ArrayField(base_field=models.FloatField(), null=True, size=None), + ), + migrations.AddField( + model_name="classification", + name="scores", + field=django.contrib.postgres.fields.ArrayField(base_field=models.FloatField(), null=True, size=None), + ), + migrations.AddField( + model_name="classification", + name="terminal", + field=models.BooleanField( + default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?" + ), + ), + migrations.AddField( + model_name="taxon", + name="search_names", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(max_length=255), blank=True, null=True, size=None + ), + ), + ] diff --git a/ami/main/migrations/0040_alter_classification_logits_and_more.py b/ami/main/migrations/0040_alter_classification_logits_and_more.py new file mode 100644 index 000000000..8a2aff79f --- /dev/null +++ b/ami/main/migrations/0040_alter_classification_logits_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.10 on 2024-12-05 01:49 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0039_remove_classification_raw_output_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="classification", + name="logits", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), + help_text="The raw output of the last fully connected layer of the model", + null=True, + size=None, + ), + ), + migrations.AlterField( + model_name="classification", + name="scores", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), + 
help_text="The probabilities the model, calibrated by the model maker, likely the softmax output", + null=True, + size=None, + ), + ), + ] diff --git a/ami/main/migrations/0044_merge_20250124_2333.py b/ami/main/migrations/0044_merge_20250124_2333.py new file mode 100644 index 000000000..1cae6da87 --- /dev/null +++ b/ami/main/migrations/0044_merge_20250124_2333.py @@ -0,0 +1,12 @@ +# Generated by Django 4.2.10 on 2025-01-24 23:33 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0039_project_users_squashed_0043_rename_users_project_members"), + ("main", "0040_alter_classification_logits_and_more"), + ] + + operations = [] diff --git a/ami/main/models.py b/ami/main/models.py index a77b39335..41f887ebf 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -12,6 +12,7 @@ import pydantic from django.apps import apps from django.conf import settings +from django.contrib.postgres.fields import ArrayField from django.core.exceptions import ValidationError from django.core.files.storage import default_storage from django.db import IntegrityError, models @@ -153,7 +154,7 @@ def summary_data(self): plots.append(charts.captures_per_hour(project_pk=self.pk)) if self.occurrences.exists(): plots.append(charts.detections_per_hour(project_pk=self.pk)) - plots.append(charts.occurrences_accumulated(project_pk=self.pk)) + # plots.append(charts.occurrences_accumulated(project_pk=self.pk)) else: plots.append(charts.events_per_month(project_pk=self.pk)) # plots.append(charts.captures_per_month(project_pk=self.pk)) @@ -1248,6 +1249,7 @@ class SourceImage(BaseModel): blank=True, ) + event_id: int | None detections: models.QuerySet["Detection"] collections: models.QuerySet["SourceImageCollection"] jobs: models.QuerySet["Job"] @@ -1782,27 +1784,27 @@ class Classification(BaseModel): related_name="classifications", ) - # occurrence = models.ForeignKey( - # "Occurrence", - # on_delete=models.SET_NULL, - # null=True, - # 
related_name="predictions", - # ) - taxon = models.ForeignKey("Taxon", on_delete=models.SET_NULL, null=True, related_name="classifications") score = models.FloatField(null=True) timestamp = models.DateTimeField() - # terminal = models.BooleanField( - # default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?" - # ) - - softmax_output = models.JSONField(null=True) # scores for all classes - raw_output = models.JSONField(null=True) # raw output from the model + terminal = models.BooleanField( + default=True, help_text="Is this the final classification from a series of classifiers in a pipeline?" + ) + logits = ArrayField( + models.FloatField(), null=True, help_text="The raw output of the last fully connected layer of the model" + ) + scores = ArrayField( + models.FloatField(), + null=True, + help_text="The probabilities the model, calibrated by the model maker, likely the softmax output", + ) + category_map = models.ForeignKey("ml.AlgorithmCategoryMap", on_delete=models.PROTECT, null=True) algorithm = models.ForeignKey( "ml.Algorithm", on_delete=models.SET_NULL, null=True, + related_name="classifications", ) # job = models.CharField(max_length=255, null=True) @@ -1816,7 +1818,91 @@ class Meta: ordering = ["-created_at", "-score"] def __str__(self) -> str: - return f"#{self.pk} to Taxon #{self.taxon_id} ({self.score:.2f}) by Algorithm #{self.algorithm_id}" + terminal = "Terminal" if self.terminal else "Intermediate" + if logger.getEffectiveLevel() == logging.DEBUG: + # Query the related objects to get the names + return f"#{self.pk} to Taxon {self.taxon} ({self.score:.2f}) by Algorithm {self.algorithm} ({terminal})" + return ( + f"#{self.pk} to Taxon #{self.taxon_id} ({self.score:.2f}) by Algorithm #{self.algorithm_id} ({terminal})" + ) + + def top_scores_with_index(self, n: int | None = None) -> typing.Iterable[tuple[int, float]]: + """ + Return the scores with their index, but sorted by score. 
+ """ + if self.scores: + top_scores_by_index = sorted(enumerate(self.scores), key=lambda x: x[1], reverse=True)[:n] + return top_scores_by_index + else: + return [] + + def predictions(self, sort=True) -> typing.Iterable[tuple[str, float]]: + """ + Return all label-score pairs for this classification using the category map. + """ + if not self.category_map: + raise ValueError("Classification must have a category map to get predictions.") + scores = self.scores or [] + preds = zip(self.category_map.labels, scores) + if sort: + return sorted(preds, key=lambda x: x[1], reverse=True) + else: + return preds + + def predictions_with_taxa(self, sort=True) -> typing.Iterable[tuple["Taxon", float]]: + """ + Return taxa objects and their scores for this classification using the category map. + + @TODO make this more efficient with numpy and/or postgres array functions. especially when we only need + the top N out of thousands of taxa. + """ + if not self.category_map: + raise ValueError("Classification must have a category map to get predictions.") + scores = self.scores or [] + category_data_with_taxa = self.category_map.with_taxa() + taxa_sorted_by_index = [cat["taxon"] for cat in sorted(category_data_with_taxa, key=lambda cat: cat["index"])] + preds = zip(taxa_sorted_by_index, scores) + if sort: + return sorted(preds, key=lambda x: x[1], reverse=True) + else: + return preds + + def taxa(self) -> typing.Iterable["Taxon"]: + """ + Return the taxa objects for this classification using the category map. 
+ """ + if not self.category_map: + return [] + category_data_with_taxa = self.category_map.with_taxa() + taxa_sorted_by_index = [cat["taxon"] for cat in sorted(category_data_with_taxa, key=lambda cat: cat["index"])] + return taxa_sorted_by_index + + def top_n(self, n: int = 3) -> list[dict[str, "Taxon | float | None"]]: + """Return top N taxa and scores for this classification.""" + if not self.category_map: + raise ValueError("Classification must have a category map to get top N.") + + top_scored = self.top_scores_with_index(n) # (index, score) pairs + indexes = [idx for idx, _ in top_scored] + category_data = self.category_map.with_taxa(only_indexes=indexes) + index_to_taxon = {cat["index"]: cat["taxon"] for cat in category_data} + + return [ + { + "taxon": index_to_taxon[i], + "score": s, + "logit": self.logits[i] if self.logits else None, + } + for i, s in top_scored + ] + + def save(self, *args, **kwargs): + """ + Set the category map based on the algorithm. + """ + if self.algorithm and not self.category_map: + self.category_map = self.algorithm.category_map + super().save(*args, **kwargs) @final @@ -2107,10 +2193,22 @@ def best_detection(self): @functools.cached_property def best_prediction(self): - return self.predictions().first() + """ + Use the best prediction as the best identification if there are no human identifications. + + Uses the highest scoring classification (from any algorithm) as the best prediction. + Considers terminal classifications first, then non-terminal ones. + (Terminal classifications are the final classifications of a pipeline, non-terminal are intermediate models.) + """ + return self.predictions().order_by("-terminal", "-score").first() @functools.cached_property def best_identification(self): + """ + The most recent human identification is used as the best identification. + + @TODO this could use a confidence level chosen manually by the users/experts. 
+ """ return Identification.objects.filter(occurrence=self, withdrawn=False).order_by("-created_at").first() def get_determination_score(self) -> float | None: @@ -2163,6 +2261,7 @@ def save(self, update_determination=True, *args, **kwargs): if self.determination and not self.determination_score: # This may happen for legacy occurrences that were created # before the determination_score field was added + # @TODO remove self.determination_score = self.get_determination_score() if not self.determination_score: logger.warning(f"Could not determine score for {self}") @@ -2175,16 +2274,16 @@ class Meta: def update_occurrence_determination( occurrence: Occurrence, current_determination: typing.Optional["Taxon"] = None, save=True -): +) -> bool: """ Update the determination of the occurrence based on the identifications & predictions. If there are identifications, set the determination to the latest identification. If there are no identifications, set the determination to the top prediction. - The `current_determination` is the determination curently saved in the database. + The `current_determination` is the determination currently saved in the database. The `occurrence` object may already have a different un-saved determination set - so it is neccessary to retrieve the current determination from the database, but + so it is necessary to retrieve the current determination from the database, but this can also be passed in as an argument to avoid an extra database query. @TODO Add tests for this important method! 
@@ -2196,6 +2295,8 @@ def update_occurrence_determination( del occurrence.best_identification if hasattr(occurrence, "best_prediction"): del occurrence.best_prediction + if hasattr(occurrence, "best_identification"): + del occurrence.best_identification current_determination = ( current_determination @@ -2217,18 +2318,29 @@ def update_occurrence_determination( new_score = top_prediction.score if new_determination and new_determination != current_determination: - logger.info(f"Changing det. of {occurrence} from {current_determination} to {new_determination}") + logger.debug(f"Changing det. of {occurrence} from {current_determination} to {new_determination}") occurrence.determination = new_determination needs_update = True if new_score and new_score != occurrence.determination_score: - logger.info(f"Changing det. score of {occurrence} from {occurrence.determination_score} to {new_score}") + logger.debug(f"Changing det. score of {occurrence} from {occurrence.determination_score} to {new_score}") occurrence.determination_score = new_score needs_update = True + if not needs_update: + if logger.getEffectiveLevel() <= logging.DEBUG: + all_predictions = occurrence.predictions() + all_preds_print = ", ".join([str(p) for p in all_predictions]) + logger.debug( + f"No update needed for determination of {occurrence}. Best prediction: {occurrence.best_prediction}. 
" + f"All preds: {all_preds_print}" + ) + if save and needs_update: occurrence.save(update_determination=False) + return needs_update + @final class TaxaManager(models.Manager): @@ -2457,6 +2569,7 @@ class Taxon(BaseModel): active = models.BooleanField(default=True) synonym_of = models.ForeignKey("self", on_delete=models.SET_NULL, null=True, blank=True, related_name="synonyms") + search_names = ArrayField(models.CharField(max_length=255), null=True, blank=True) gbif_taxon_key = models.BigIntegerField("GBIF taxon key", blank=True, null=True) bold_taxon_bin = models.CharField("BOLD taxon BIN", max_length=255, blank=True, null=True) inat_taxon_id = models.BigIntegerField("iNaturalist taxon ID", blank=True, null=True) diff --git a/ami/ml/admin.py b/ami/ml/admin.py index 3f4784d1b..54d6f0185 100644 --- a/ami/ml/admin.py +++ b/ami/ml/admin.py @@ -2,7 +2,7 @@ from ami.main.admin import AdminBase -from .models.algorithm import Algorithm +from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline from .models.processing_service import ProcessingService @@ -11,8 +11,10 @@ class AlgorithmAdmin(AdminBase): list_display = [ "name", + "key", "version", "version_name", + "task_type", "created_at", "updated_at", ] @@ -26,6 +28,7 @@ class AlgorithmAdmin(AdminBase): ] list_filter = [ "pipelines", + "task_type", ] @@ -68,3 +71,33 @@ class ProcessingServiceAdmin(AdminBase): "endpoint_url", "created_at", ] + + +@admin.register(AlgorithmCategoryMap) +class AlgorithmCategoryMapAdmin(AdminBase): + list_display = [ + "version", + "uri", + "created_at", + "num_data_items", + "num_labels", + ] + search_fields = [ + "version", + ] + ordering = [ + "version", + ] + list_filter = [ + "algorithms", + ] + formfield_overrides = { + # See https://pypi.org/project/django-json-widget/ + # models.JSONField: {"widget": JSONInput}, + } + + def num_data_items(self, obj): + return len(obj.data) if obj.data else 0 + + def num_labels(self, obj): + return 
len(obj.labels) if obj.labels else 0 diff --git a/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py b/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py new file mode 100644 index 000000000..6dd501558 --- /dev/null +++ b/ami/ml/migrations/0007_algorithmcategorymap_algorithm_category_map.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.10 on 2024-11-26 18:52 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0006_alter_pipeline_endpoint_url_alter_pipeline_projects"), + ] + + operations = [ + migrations.CreateModel( + name="AlgorithmCategoryMap", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("data", models.JSONField()), + ( + "labels", + django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(max_length=255), default=list, size=None + ), + ), + ("version", models.CharField(blank=True, max_length=255, null=True)), + ("url", models.URLField(blank=True, null=True)), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="algorithm", + name="category_map", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + related_name="algorithms", + to="ml.algorithmcategorymap", + ), + ), + ] diff --git a/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py b/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py new file mode 100644 index 000000000..a51bbb5cc --- /dev/null +++ b/ami/ml/migrations/0008_algorithmcategorymap_description_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.10 on 2024-12-05 00:24 + +import django.contrib.postgres.fields +from django.db import migrations, 
models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0007_algorithmcategorymap_algorithm_category_map"), + ] + + operations = [ + migrations.AddField( + model_name="algorithmcategorymap", + name="description", + field=models.TextField(blank=True, null=True), + ), + migrations.AlterField( + model_name="algorithmcategorymap", + name="data", + field=models.JSONField( + help_text="Complete metadata for each label, such as id, gbif_key, explicit index, source, etc." + ), + ), + migrations.AlterField( + model_name="algorithmcategorymap", + name="labels", + field=django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(max_length=255), + default=list, + help_text="A simple list of string labels in the correct index order used by the model.", + size=None, + ), + ), + ] diff --git a/ami/ml/migrations/0009_algorithm_task_type.py b/ami/ml/migrations/0009_algorithm_task_type.py new file mode 100644 index 000000000..6bdfda3b3 --- /dev/null +++ b/ami/ml/migrations/0009_algorithm_task_type.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.10 on 2024-12-05 01:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0008_algorithmcategorymap_description_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="algorithm", + name="task_type", + field=models.CharField( + blank=True, + choices=[ + ("detection", "Detection"), + ("segmentation", "Segmentation"), + ("classification", "Classification"), + ("embedding", "Embedding"), + ("tracking", "Tracking"), + ("tagging", "Tagging"), + ("regression", "Regression"), + ("captioning", "Captioning"), + ("generation", "Generation"), + ("translation", "Translation"), + ("summarization", "Summarization"), + ("question_answering", "Question Answering"), + ("depth_estimation", "Depth Estimation"), + ("pose_estimation", "Pose Estimation"), + ("size_estimation", "Size Estimation"), + ("other", "Other"), + ], + max_length=255, + ), + 
), + ] diff --git a/ami/ml/migrations/0010_alter_algorithm_version.py b/ami/ml/migrations/0010_alter_algorithm_version.py new file mode 100644 index 000000000..84d42928e --- /dev/null +++ b/ami/ml/migrations/0010_alter_algorithm_version.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.10 on 2024-12-05 01:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0009_algorithm_task_type"), + ] + + operations = [ + migrations.AlterField( + model_name="algorithm", + name="version", + field=models.IntegerField( + default=1, help_text="An internal, sortable and incrementable version number for the model." + ), + ), + ] diff --git a/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py b/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py new file mode 100644 index 000000000..9f084ca80 --- /dev/null +++ b/ami/ml/migrations/0011_alter_algorithm_task_type_alter_algorithm_url_and_more.py @@ -0,0 +1,50 @@ +# Generated by Django 4.2.10 on 2024-12-05 04:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0010_alter_algorithm_version"), + ] + + operations = [ + migrations.AlterField( + model_name="algorithm", + name="task_type", + field=models.CharField( + choices=[ + ("detection", "Detection"), + ("segmentation", "Segmentation"), + ("classification", "Classification"), + ("embedding", "Embedding"), + ("tracking", "Tracking"), + ("tagging", "Tagging"), + ("regression", "Regression"), + ("captioning", "Captioning"), + ("generation", "Generation"), + ("translation", "Translation"), + ("summarization", "Summarization"), + ("question_answering", "Question Answering"), + ("depth_estimation", "Depth Estimation"), + ("pose_estimation", "Pose Estimation"), + ("size_estimation", "Size Estimation"), + ("other", "Other"), + ("unknown", "Unknown"), + ], + default="unknown", + max_length=255, + null=True, + ), 
+ ), + migrations.AlterField( + model_name="algorithm", + name="url", + field=models.URLField(blank=True, null=True), + ), + migrations.AlterField( + model_name="algorithm", + name="version_name", + field=models.CharField(blank=True, max_length=255, null=True), + ), + ] diff --git a/ami/ml/migrations/0012_alter_algorithm_unique_together.py b/ami/ml/migrations/0012_alter_algorithm_unique_together.py new file mode 100644 index 000000000..1ccf7fffd --- /dev/null +++ b/ami/ml/migrations/0012_alter_algorithm_unique_together.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.10 on 2024-12-05 04:20 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0011_alter_algorithm_task_type_alter_algorithm_url_and_more"), + ] + + operations = [ + migrations.AlterUniqueTogether( + name="algorithm", + unique_together=set(), + ), + ] diff --git a/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py b/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py new file mode 100644 index 000000000..36454c7dd --- /dev/null +++ b/ami/ml/migrations/0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more.py @@ -0,0 +1,54 @@ +# Generated by Django 4.2.10 on 2024-12-06 21:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0012_alter_algorithm_unique_together"), + ] + + operations = [ + migrations.RemoveField( + model_name="algorithm", + name="url", + ), + migrations.RemoveField( + model_name="algorithmcategorymap", + name="url", + ), + migrations.AddField( + model_name="algorithm", + name="uri", + field=models.CharField( + blank=True, + help_text="A URI to the weights or model details. 
Could be a public web URL or object store path.", + max_length=255, + null=True, + ), + ), + migrations.AddField( + model_name="algorithmcategorymap", + name="labels_hash", + field=models.BigIntegerField( + help_text="A hash of the labels for faster comparison of label sets. Created on save.", null=True + ), + ), + migrations.AddField( + model_name="algorithmcategorymap", + name="uri", + field=models.CharField( + blank=True, + help_text="A URI to the category map file. Could be a public web URL or object store path.", + max_length=255, + null=True, + ), + ), + migrations.AlterField( + model_name="algorithmcategorymap", + name="data", + field=models.JSONField( + help_text="Complete metadata for each label, such as id, gbif_key, lookup value, source, etc." + ), + ), + ] diff --git a/ami/ml/migrations/0014_rename_model_keys.py b/ami/ml/migrations/0014_rename_model_keys.py new file mode 100644 index 000000000..1ddb104f3 --- /dev/null +++ b/ami/ml/migrations/0014_rename_model_keys.py @@ -0,0 +1,39 @@ +# Generated by Django 4.2.10 on 2024-12-06 21:15 + +import random +import string +import logging +from django.db import migrations + +logger = logging.getLogger(__name__) + + +def rename_algorithm_keys(apps, schema_editor): + """ + The current live ML backend API uses keys with underscores, but the Antenna database + uses keys with dashes. This migration renames all existing algorithm keys to use underscores + so that the keys match for new detections. 
+ """ + Algorithm = apps.get_model("ml", "Algorithm") + algorithms = Algorithm.objects.all() + for algorithm in algorithms: + new_key = algorithm.key.replace("-", "_") + if Algorithm.objects.filter(key=new_key).exclude(id=algorithm.pk).exists(): + # Add random 6 char suffix to avoid clashes + new_key += "_" + "".join(random.choices(string.ascii_lowercase, k=6)) + if algorithm.key != new_key: + logger.info(f"Renaming algorithm key {algorithm.key} to {new_key} for algorithm {algorithm}") + algorithm.key = new_key + algorithm.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0013_remove_algorithm_url_remove_algorithmcategorymap_url_and_more"), + ] + + # Rename all existing algorithm keys to use underscores instead of dashes. Ensure there are no clashes as we go. + + operations = [ + migrations.RunPython(rename_algorithm_keys, reverse_code=migrations.RunPython.noop), + ] diff --git a/ami/ml/migrations/0015_update_existing_intermediate_classifications.py b/ami/ml/migrations/0015_update_existing_intermediate_classifications.py new file mode 100644 index 000000000..a61e5cb25 --- /dev/null +++ b/ami/ml/migrations/0015_update_existing_intermediate_classifications.py @@ -0,0 +1,46 @@ +""" +We have had the terminal field in the Classification model for a while now, but we have not been using it. +Now we need to filter classifications based on this field, so we need to update the existing classifications +where a binary / intermediate classification was used. 
+""" + +import logging + +from django.db import migrations, models + +logger = logging.getLogger(__name__) + + +MOTH_NONMOTH_LABELS = [ + "moth", + "non-moth", + "nonmoth", +] + + +def update_classification_labels(apps, schema_editor): + Classification = apps.get_model("main", "Classification") + + # Create regex pattern from labels to make a case-insensitive match + pattern = r"^(" + "|".join(MOTH_NONMOTH_LABELS) + ")$" + + # Log number of updated classifications + logger.info(f"\nUpdating classifications with labels matching pattern: {pattern} (case insensitive)") + + # Update only matching classifications + updated = Classification.objects.filter(taxon__name__iregex=pattern).update(terminal=False) + + logger.info(f"\nUpdated {updated} moth/non-moth classifications to terminal=False") + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0014_rename_model_keys"), + ] + + operations = [ + migrations.RunPython( + update_classification_labels, + reverse_code=migrations.RunPython.noop, + ), + ] diff --git a/ami/ml/migrations/0016_merge_20250117_2101.py b/ami/ml/migrations/0016_merge_20250117_2101.py new file mode 100644 index 000000000..63bc4e283 --- /dev/null +++ b/ami/ml/migrations/0016_merge_20250117_2101.py @@ -0,0 +1,12 @@ +# Generated by Django 4.2.10 on 2025-01-17 21:01 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0007_add_processing_service"), + ("ml", "0015_update_existing_intermediate_classifications"), + ] + + operations = [] diff --git a/ami/ml/models/__init__.py b/ami/ml/models/__init__.py index a5e716372..a4f28d671 100644 --- a/ami/ml/models/__init__.py +++ b/ami/ml/models/__init__.py @@ -1,9 +1,10 @@ -from ami.ml.models.algorithm import Algorithm +from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap from ami.ml.models.pipeline import Pipeline from ami.ml.models.processing_service import ProcessingService __all__ = [ "Algorithm", + "AlgorithmCategoryMap", 
"Pipeline", "ProcessingService", ] diff --git a/ami/ml/models/algorithm.py b/ami/ml/models/algorithm.py index 9cc56d4bd..7a63a0861 100644 --- a/ami/ml/models/algorithm.py +++ b/ami/ml/models/algorithm.py @@ -8,22 +8,142 @@ import typing +from django.contrib.postgres.fields import ArrayField from django.db import models from django.utils.text import slugify from ami.base.models import BaseModel +@typing.final +class AlgorithmCategoryMap(BaseModel): + """ + A list of classification labels for a given algorithm version + """ + + data = models.JSONField( + help_text="Complete metadata for each label, such as id, gbif_key, lookup value, source, etc." + ) + labels = ArrayField( + models.CharField(max_length=255), + default=list, + help_text="A simple list of string labels in the correct index order used by the model.", + ) + labels_hash = models.BigIntegerField( + help_text="A hash of the labels for faster comparison of label sets. Created on save.", + null=True, + ) + version = models.CharField(max_length=255, blank=True, null=True) + description = models.TextField(blank=True, null=True) + uri = models.CharField( + max_length=255, + blank=True, + null=True, + help_text=("A URI to the category map file. 
" "Could be a public web URL or object store path."), + ) + + algorithms: models.QuerySet[Algorithm] + + def __str__(self): + return f"#{self.pk} with {len(self.labels)} classes ({self.version or 'unknown version'})" + + @classmethod + def make_labels_hash(cls, labels): + """ + Create a hash from the labels for faster comparison of unique label sets + """ + return hash("".join(labels)) + + def get_category(self, label, label_field="label"): + # Can use JSON containment operators + return self.data.index(next(category for category in self.data if category[label_field] == label)) + + def with_taxa(self, category_field="label", only_indexes: list[int] | None = None): + """ + Add Taxon objects to the category map, or None if no match + + :param category_field: The field in the category data to match against the Taxon name + :return: The category map with the taxon objects added + + @TODO need a top_n parameter to limit the number of taxa to fetch + @TODO consider creating missing taxa? + """ + + from ami.main.models import Taxon + + if only_indexes: + labels_data = [self.data[i] for i in only_indexes] + labels_label = [self.labels[i] for i in only_indexes] + else: + labels_data = self.data + labels_label = self.labels + + taxa = Taxon.objects.filter(models.Q(name__in=labels_label) | models.Q(search_names__overlap=labels_label)) + taxon_map = {taxon.name: taxon for taxon in taxa} + + for category in labels_data: + taxon = taxon_map.get(category[category_field]) + category["taxon"] = taxon + + return labels_data + + def save(self, *args, **kwargs): + if not self.labels_hash: + self.labels_hash = self.make_labels_hash(self.labels) + super().save(*args, **kwargs) + + @typing.final class Algorithm(BaseModel): """A machine learning algorithm""" name = models.CharField(max_length=255) key = models.SlugField(max_length=255, unique=True) + task_type = models.CharField( + max_length=255, + default="unknown", + null=True, + choices=[ + ("detection", "Detection"), + 
("segmentation", "Segmentation"), + ("classification", "Classification"), + ("embedding", "Embedding"), + ("tracking", "Tracking"), + ("tagging", "Tagging"), + ("regression", "Regression"), + ("captioning", "Captioning"), + ("generation", "Generation"), + ("translation", "Translation"), + ("summarization", "Summarization"), + ("question_answering", "Question Answering"), + ("depth_estimation", "Depth Estimation"), + ("pose_estimation", "Pose Estimation"), + ("size_estimation", "Size Estimation"), + ("other", "Other"), + ("unknown", "Unknown"), + ], + ) description = models.TextField(blank=True) - version = models.IntegerField(default=1) - version_name = models.CharField(max_length=255, blank=True) - url = models.URLField(blank=True) # URL to the model homepage, origin or docs (huggingface, wandb, etc.) + version = models.IntegerField( + default=1, + help_text="An internal, sortable and incrementable version number for the model.", + ) + version_name = models.CharField(max_length=255, blank=True, null=True) + uri = models.CharField( + max_length=255, + blank=True, + null=True, + help_text=("A URI to the weights or model details. 
Could be a public web URL or object store path."), + ) + + category_map = models.ForeignKey( + AlgorithmCategoryMap, + on_delete=models.CASCADE, + blank=True, + null=True, + related_name="algorithms", + default=None, + ) # api_base_url = models.URLField(blank=True) # api = models.CharField(max_length=255, blank=True) @@ -31,6 +151,9 @@ class Algorithm(BaseModel): pipelines: models.QuerySet[Pipeline] classifications: models.QuerySet[Classification] + def __str__(self): + return f'#{self.pk} "{self.name}" ({self.key}) v{self.version}' + class Meta: ordering = ["name", "version"] @@ -38,10 +161,9 @@ class Meta: ["name", "version"], ] - def __str__(self): - return f'#{self.pk} "{self.name}" ({self.key}) v{self.version}' - def save(self, *args, **kwargs): + if not self.version_name: + self.version_name = f"{self.version}" if not self.key: - self.key = slugify(self.name) + self.key = f"{slugify(self.name)}-{self.version}" super().save(*args, **kwargs) diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index c830f69bf..4a598d8e6 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -5,16 +5,19 @@ if TYPE_CHECKING: from ami.ml.models import ProcessingService +import collections +import dataclasses import logging +import time import typing +import uuid from urllib.parse import urljoin import requests -from django.db import models, transaction +from django.db import models from django.utils.text import slugify from django.utils.timezone import now from django_pydantic_field import SchemaField -from rich import print from ami.base.models import BaseModel from ami.base.schemas import ConfigurableStage, default_stages @@ -29,10 +32,20 @@ Taxon, TaxonRank, update_calculated_fields_for_events, + update_occurrence_determination, +) +from ami.ml.models.algorithm import Algorithm, AlgorithmCategoryMap +from ami.ml.schemas import ( + AlgorithmConfigResponse, + ClassificationResponse, + DetectionResponse, + PipelineRequest, + 
PipelineResultsResponse, + SourceImageRequest, + SourceImageResponse, ) -from ami.ml.models.algorithm import Algorithm -from ami.ml.schemas import PipelineRequest, PipelineResponse, SourceImageRequest, SourceImageResponse from ami.ml.tasks import celery_app, create_detection_images +from ami.utils.requests import create_session logger = logging.getLogger(__name__) @@ -50,20 +63,23 @@ def filter_processed_images( for image in images: existing_detections = image.detections.filter(detection_algorithm__in=pipeline_algorithms) if not existing_detections.exists(): - logger.debug(f"Image {image} has no existing detections from pipeline {pipeline}") + logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector") # If there are no existing detections from this pipeline, send the image yield image elif existing_detections.filter(classifications__isnull=True).exists(): # Check if there are detections with no classifications - logger.debug(f"Image {image} has existing detections with no classifications from pipeline {pipeline}") + logger.debug( + f"Image {image} needs processing: has existing detections with no classifications " + "from pipeline {pipeline}" + ) yield image else: # If there are existing detections with classifications, # Compare their classification algorithms to the current pipeline's algorithms - detections_needing_classification = existing_detections.exclude( - classifications__algorithm__in=pipeline_algorithms - ) - if detections_needing_classification.exists(): + pipeline_algorithm_ids = pipeline_algorithms.values_list("id", flat=True) + detection_algorithm_ids = existing_detections.values_list("classifications__algorithm_id", flat=True) + + if not set(pipeline_algorithm_ids).issubset(set(detection_algorithm_ids)): logger.debug( f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}" ) @@ -134,7 +150,7 @@ def process_images( endpoint_url: str, images: 
typing.Iterable[SourceImage], job_id: int | None = None, -) -> PipelineResponse: +) -> PipelineResultsResponse: """ Process images using ML pipeline API. @@ -158,7 +174,7 @@ def process_images( if not images: task_logger.info("No images to process") - return PipelineResponse( + return PipelineResultsResponse( pipeline=pipeline.slug, source_images=[], detections=[], @@ -181,18 +197,20 @@ def process_images( source_images=source_images, ) - resp = requests.post(endpoint_url, json=request_data.dict()) + session = create_session() + resp = session.post(endpoint_url, json=request_data.dict()) if not resp.ok: try: msg = resp.json()["detail"] - except Exception: + except (ValueError, KeyError): msg = str(resp.content) if job: job.logger.error(msg) else: logger.error(msg) + raise requests.HTTPError(msg) - results = PipelineResponse( + results = PipelineResultsResponse( pipeline=pipeline.slug, total_time=0, source_images=[ @@ -204,199 +222,633 @@ def process_images( return results results = resp.json() - results = PipelineResponse(**results) + results = PipelineResultsResponse(**results) if job: job.logger.debug(f"Results: {results}") detections = results.detections classifications = [classification for detection in detections for classification in detection.classifications] - if len(detections): - job.logger.info(f"Found {len(detections)} detections") - if len(classifications): - job.logger.info(f"Found {len(classifications)} classifications") + job.logger.info( + f"Pipeline results returned {len(results.source_images)} images, {len(detections)} detections, " + f"{len(classifications)} classifications" + ) return results -@celery_app.task(soft_time_limit=60 * 4, time_limit=60 * 5) -def save_results(results: PipelineResponse | None = None, results_json: str | None = None, job_id: int | None = None): +def get_or_create_algorithm_and_category_map( + algorithm_config: AlgorithmConfigResponse, + logger: logging.Logger = logger, +) -> Algorithm: """ - Save results from ML 
pipeline API. + Create algorithms and category maps from a ProcessingServiceInfoResponse or a PipelineConfigResponse. - @TODO break into task chunks. - @TODO rewrite this! + :param algorithm_configs: A dictionary of algorithms from the processing services' "/info" endpoint + :param logger: A logger instance from the parent function + + :return: A dictionary of algorithms used in the pipeline, keyed by the algorithm key + + @TODO this should be called when registering a pipeline, not when saving results. + But currently we don't have a way to register pipelines. """ - created_objects = [] - job = None + category_map = None + category_map_data = algorithm_config.category_map + if category_map_data: + labels_hash = AlgorithmCategoryMap.make_labels_hash(category_map_data.labels) + category_map, _created = AlgorithmCategoryMap.objects.get_or_create( + # @TODO this is creating a new category map every time + # Will create a new category map if the labels are different + labels_hash=labels_hash, + version=category_map_data.version, + defaults={ + "data": category_map_data.data, + "labels": category_map_data.labels, + "description": category_map_data.description, + "uri": category_map_data.uri, + }, + ) + if _created: + logger.info(f"Registered new category map {category_map}") + else: + logger.info(f"Assigned existing category map {category_map}") + else: + logger.warning( + f"No category map found for algorithm {algorithm_config.key} in response." + " Will attempt to create one from the classification results." 
+ ) - if results_json: - results = PipelineResponse.parse_raw(results_json) - assert results, "No results data passed to save_results task" + algo, _created = Algorithm.objects.get_or_create( + key=algorithm_config.key, + version=algorithm_config.version, + defaults={ + "name": algorithm_config.name, + "task_type": algorithm_config.task_type, + "version_name": algorithm_config.version_name, + "uri": algorithm_config.uri, + "category_map": category_map or None, + }, + ) + + # Update fields that may have changed in the processing service, with a warning + fields_to_update = { + "task_type": algorithm_config.task_type, + "uri": algorithm_config.uri, + "category_map": category_map, + } + fields_updated = [] + for field in fields_to_update: + new_value = fields_to_update[field] + if getattr(algo, field) != new_value: + logger.warning(f"Field '{field}' changed for algorithm {algo} from {getattr(algo, field)} to {new_value}") + setattr(algo, field, new_value) + fields_updated.append(field) + algo.save(update_fields=fields_updated) + + if not algo.category_map or len(algo.category_map.data) == 0: + # Update existing algorithm that is missing a category map + algo.category_map = category_map + algo.save() - pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline}) if _created: - logger.warning(f"Pipeline choice returned by the Processing Service was not recognized! 
{pipeline}") - created_objects.append(pipeline) - algorithms_used = set() + logger.info(f"Registered new algorithm {algo}") + else: + logger.info(f"Assigned algorithm {algo}") - if job_id: - from ami.jobs.models import Job + return algo - job = Job.objects.get(pk=job_id) - job.logger.info("Saving results...") - - # collection_name = f"Images processed by {results.pipeline} pipeline" - # if job_id: - # from ami.jobs.models import Job - - # job = Job.objects.get(pk=job_id) - # collection_name = f"Images processed by {results.pipeline} pipeline for job {job.name}" - - # collection = SourceImageCollection.objects.create(name=collection_name) - # source_image_ids = [source_image.id for source_image in results.source_images] - # source_images = SourceImage.objects.filter(pk__in=source_image_ids) - # collection.images.set(source_images) - source_images = set() - - for detection_resp in results.detections: - # @TODO use bulk create, or optimize this in some way - print(detection_resp) - assert detection_resp.algorithm, "No detection algorithm was specified in the returned results." - detection_algo, _created = Algorithm.objects.get_or_create( - name=detection_resp.algorithm, + +def get_or_create_detection( + source_image: SourceImage, + detection_resp: DetectionResponse, + algorithms_used: dict[str, Algorithm], + save: bool = True, + logger: logging.Logger = logger, +) -> tuple[Detection, bool]: + """ + Create a Detection object from a DetectionResponse, or update an existing one. 
+ + :param detection_resp: A DetectionResponse object + :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key + :param created_objects: A list to store created objects + + :return: A tuple of the Detection object and a boolean indicating whether it was created + """ + serialized_bbox = list(detection_resp.bbox.dict().values()) + detection_repr = f"Detection {detection_resp.source_image_id} {serialized_bbox}" + + assert detection_resp.algorithm, f"No detection algorithm was specified for detection {detection_repr}" + try: + detection_algo = algorithms_used[detection_resp.algorithm.key] + except KeyError: + raise ValueError( + f"Detection algorithm {detection_resp.algorithm.key} is not a known algorithm. " + "The processing service must declare it in the /info endpoint. " + f"Known algorithms: {list(algorithms_used.keys())}" ) - algorithms_used.add(detection_algo) - if _created: - created_objects.append(detection_algo) - # @TODO hmmmm what to do - source_image = SourceImage.objects.get(pk=detection_resp.source_image_id) - source_images.add(source_image) - existing_detection = Detection.objects.filter( + assert str(detection_resp.source_image_id) == str( + source_image.pk + ), f"Detection belongs to a different source image: {detection_repr}" + + existing_detection = Detection.objects.filter( + source_image=source_image, + detection_algorithm=detection_algo, + bbox=serialized_bbox, + ).first() + + # A detection may have a pre-existing crop image URL or not. + # If not, a new one will be created in a periodic background task. + if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"): + # Ensure that the crop image URL is not empty or only a slash. None is fine. 
+ crop_url = detection_resp.crop_image_url + else: + crop_url = None + + if existing_detection: + if not existing_detection.path: + existing_detection.path = crop_url + existing_detection.save() + logger.debug(f"Updated crop_url of existing detection {existing_detection}") + detection = existing_detection + + else: + new_detection = Detection( source_image=source_image, + bbox=serialized_bbox, + timestamp=source_image.timestamp, + path=crop_url, + detection_time=detection_resp.timestamp, detection_algorithm=detection_algo, - bbox=list(detection_resp.bbox.dict().values()), - ).first() - # Ensure that the crop image URL is not empty or only a slash. None is fine. - if detection_resp.crop_image_url and detection_resp.crop_image_url.strip("/"): - crop_url = detection_resp.crop_image_url + ) + if save: + new_detection.save() + logger.debug(f"Created new detection {new_detection}") else: - crop_url = None - if existing_detection: - if not existing_detection.path: - existing_detection.path = crop_url - existing_detection.save() - print("Updated existing detection", existing_detection) - detection = existing_detection + logger.debug(f"Initialized new detection {new_detection} (not saved)") + + detection = new_detection + + created = not existing_detection + return detection, created + + +def create_detections( + detections: list[DetectionResponse], + algorithms_used: dict[str, Algorithm], + logger: logging.Logger = logger, +) -> list[Detection]: + """ + Efficiently create multiple Detection objects from a list of DetectionResponse objects, grouped by source image. + Using bulk create. 
+ + :param detections: A list of DetectionResponse objects + :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key + :param created_objects: A list to store created objects + + :return: A list of Detection objects + """ + source_image_ids = {detection.source_image_id for detection in detections} + source_images = SourceImage.objects.filter(pk__in=source_image_ids) + source_image_map = {str(source_image.pk): source_image for source_image in source_images} + + existing_detections: list[Detection] = [] + new_detections: list[Detection] = [] + for detection_resp in detections: + source_image = source_image_map.get(detection_resp.source_image_id) + if not source_image: + logger.error(f"Source image {detection_resp.source_image_id} not found, skipping Detection creation") + continue + + detection, created = get_or_create_detection( + source_image=source_image, + detection_resp=detection_resp, + algorithms_used=algorithms_used, + save=False, + logger=logger, + ) + if created: + new_detections.append(detection) else: - new_detection = Detection.objects.create( - source_image=source_image, - bbox=list(detection_resp.bbox.dict().values()), - timestamp=source_image.timestamp, - path=crop_url, - detection_time=detection_resp.timestamp, - detection_algorithm=detection_algo, - ) - new_detection.save() - print("Created new detection", new_detection) - created_objects.append(new_detection) - detection = new_detection + existing_detections.append(detection) - for classification in detection_resp.classifications: - print(classification) + Detection.objects.bulk_create(new_detections) + # logger.info(f"Created {len(new_detections)} new detections for {len(source_image_ids)} source image(s)") + logger.info( + f"Created {len(new_detections)} new detections, updated {len(existing_detections)} existing detections, " + f"for {len(source_image_ids)} source image(s)" + ) - assert classification.algorithm, "No classification algorithm was 
specified in the returned results." - classification_algo, _created = Algorithm.objects.get_or_create( - name=classification.algorithm, - ) - algorithms_used.add(classification_algo) - if _created: - created_objects.append(classification_algo) + return existing_detections + new_detections - taxa_list, _created = TaxaList.objects.get_or_create( - name=f"Taxa returned by {classification_algo.name}", - ) - if _created: - created_objects.append(taxa_list) - taxon, _created = Taxon.objects.get_or_create( - name=classification.classification, - defaults={"name": classification.classification, "rank": TaxonRank.UNKNOWN}, - ) - if _created: - created_objects.append(taxon) +def create_category_map_for_classification( + classification_resp: ClassificationResponse, + logger: logging.Logger = logger, +) -> AlgorithmCategoryMap: + """ + Create a simple category map from a ClassificationResponse. + The complete category map should be created when registering the algorithm before processing images. + + :param classification: A ClassificationResponse object + + :return: The AlgorithmCategoryMap object + """ + labels = classification_resp.labels or list(map(str, range(len(classification_resp.scores)))) + category_map_data = [ + { + "label": label, + "index": i, + } + for i, label in enumerate(labels) + ] + logger.info(f"Creating placeholder category map with data: {category_map_data}") + category_map = AlgorithmCategoryMap.objects.create( + data=category_map_data, + version=classification_resp.timestamp.isoformat(), + description="Placeholder category map automatically created from classification data", + labels=labels, + ) + return category_map + + +def get_or_create_taxon_for_classification( + algorithm: Algorithm, + classification_resp: ClassificationResponse, + logger: logging.Logger = logger, +) -> Taxon: + """ + Create a Taxon object from a ClassificationResponse and add it to a TaxaList. 
+ + :param classification: A ClassificationResponse object + + :return: The Taxon object + """ + taxa_list, _created = TaxaList.objects.get_or_create( + name=f"Taxa returned by {algorithm.name}", + ) + if _created: + logger.info(f"Created new taxa list {taxa_list}") + else: + logger.debug(f"Using existing taxa list {taxa_list}") + + # Get top label from classification scores + assert algorithm.category_map, f"No category map found for algorithm {algorithm}" + label_data: dict = algorithm.category_map.data[classification_resp.scores.index(max(classification_resp.scores))] + taxon, _created = Taxon.objects.get_or_create( + name=classification_resp.classification, + defaults={ + "name": classification_resp.classification, + "rank": label_data.get("taxon_rank", TaxonRank.UNKNOWN), + }, + ) + if _created: + logger.info(f"Registered new taxon {taxon}") + + taxa_list.taxa.add(taxon) + return taxon + + +def create_classification( + detection: Detection, + classification_resp: ClassificationResponse, + algorithms_used: dict[str, Algorithm], + save: bool = True, + logger: logging.Logger = logger, +) -> tuple[Classification, bool]: + """ + Create a Classification object from a ClassificationResponse, or update an existing one. 
+ + :param detection: A Detection object + :param classification: A ClassificationResponse object + :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key + :param created_objects: A list to store created objects + + :return: A tuple of the Classification object and a boolean indicating whether it was created + """ + assert ( + classification_resp.algorithm + ), f"No classification algorithm was specified for classification {classification_resp}" + logger.debug(f"Processing classification {classification_resp}") + + try: + classification_algo = algorithms_used[classification_resp.algorithm.key] + except KeyError: + raise ValueError( + f"Classification algorithm {classification_resp.algorithm.key} is not a known algorithm. " + "The processing service must declare it in the /info endpoint. " + f"Known algorithms: {list(algorithms_used.keys())}" + ) + + if not classification_algo.category_map: + logger.warning( + f"Classification algorithm {classification_algo} " + "has no category map! " + "Creating one from data in the first classification if possible." 
+ ) + category_map = create_category_map_for_classification(classification_resp, logger=logger) + classification_algo.category_map = category_map + classification_algo.save() + classification_algo.refresh_from_db() + + taxon = get_or_create_taxon_for_classification( + algorithm=classification_algo, + classification_resp=classification_resp, + logger=logger, + ) + + existing_classification = Classification.objects.filter( + detection=detection, + taxon=taxon, + algorithm=classification_algo, + score=max(classification_resp.scores), + ).first() + + if existing_classification: + # @TODO remove this after all existing classifications have been updated (added 2024-12-20) + NEW_FIELDS = ["logits", "scores", "terminal", "category_map"] + logger.debug( + "Duplicate classification found: " + f"{existing_classification.taxon} from {existing_classification.algorithm}, " + f"not creating a new one, but updating new fields if they are None ({NEW_FIELDS})" + ) + fields_to_update = [] + for field in NEW_FIELDS: + # update new fields if they are None + if getattr(existing_classification, field) is None: + fields_to_update.append(field) + if fields_to_update: + logger.info(f"Updating fields {fields_to_update} for existing classification {existing_classification}") + for field in fields_to_update: + if field == "category_map": + # Use the foreign key from the classification algorithm + setattr(existing_classification, field, classification_algo.category_map) + else: + # Get the value from the classification response + setattr(existing_classification, field, getattr(classification_resp, field)) + existing_classification.save(update_fields=fields_to_update) + logger.info(f"Updated existing classification {existing_classification}") + + classification = existing_classification + + else: + new_classification = Classification( + detection=detection, + taxon=taxon, + algorithm=classification_algo, + score=max(classification_resp.scores), + timestamp=classification_resp.timestamp or now(), 
+ logits=classification_resp.logits, + scores=classification_resp.scores, + terminal=classification_resp.terminal, + category_map=classification_algo.category_map, + ) + classification = new_classification + + if save: + new_classification.save() + logger.debug(f"Created new classification {new_classification}") + else: + logger.debug(f"Initialized new classification {new_classification} (not saved)") - taxa_list.taxa.add(taxon) + return classification, not existing_classification - # @TODO this is asking for trouble - # shouldn't we be able to get the detection from the classification? - # also should filter by the correct detection algorithm - # or do we use the bbox as a unique identifier? - # then it doesn't matter what detection algorithm was used - new_classification, created = Classification.objects.get_or_create( +def create_classifications( + detections: list[Detection], + detection_responses: list[DetectionResponse], + algorithms_used: dict[str, Algorithm], + logger: logging.Logger = logger, + save: bool = True, +) -> list[Classification]: + """ + Efficiently create multiple Classification objects from a list of ClassificationResponse objects, + grouped by detection. 
+ + :param detection: A Detection object + :param classifications: A list of ClassificationResponse objects + :param algorithms_used: A dictionary of algorithms used in the pipeline, keyed by the algorithm key + + :return: A list of Classification objects + + @TODO bulk create all classifications for all detections in request + """ + existing_classifications: list[Classification] = [] + new_classifications: list[Classification] = [] + + for detection, detection_resp in zip(detections, detection_responses): + for classification_resp in detection_resp.classifications: + classification, created = create_classification( detection=detection, - taxon=taxon, - algorithm=classification_algo, - score=max(classification.scores), - defaults={"timestamp": classification.timestamp or now()}, + classification_resp=classification_resp, + algorithms_used=algorithms_used, + save=False, + logger=logger, ) - if created: - # Optionally add reference to job or pipeline here - created_objects.append(new_classification) + new_classifications.append(classification) else: - # Optionally handle the case where a duplicate is found - logger.warn("Duplicate classification found, not creating a new one.") + # @TODO consider adding logits, scores and terminal state to existing classifications (new fields) + existing_classifications.append(classification) + + Classification.objects.bulk_create(new_classifications) + logger.info( + f"Created {len(new_classifications)} new classifications, updated {len(existing_classifications)} existing " + f"classifications for {len(detections)} detections." + ) + + return existing_classifications + new_classifications - # Create a new occurrence for each detection (no tracking yet) - # @TODO remove when we implement tracking + +def create_and_update_occurrences_for_detections( + detections: list[Detection], + logger: logging.Logger = logger, +): + """ + Create an Occurrence object for each Detection, and update the occurrence determination. 
+ + Select the best terminal classification for the occurrence determination. + + :param detection: A Detection object + :param classifications: A list of Classification objects + + :return: The Occurrence object + """ + + # Group detections by source image id so we don't create duplicate occurrences + detections_by_source_image = collections.defaultdict(list) + for detection in detections: + detections_by_source_image[detection.source_image_id].append(detection) + + for source_image_id, detections in detections_by_source_image.items(): + logger.info(f"Determining occurrences for {len(detections)} detections for source image {source_image_id}") + + occurrences_to_create = [] + detections_to_update = [] + + for detection in detections: if not detection.occurrence: - occurrence = Occurrence.objects.create( - event=source_image.event, - deployment=source_image.deployment, - project=source_image.project, - determination=taxon, - determination_score=new_classification.score, + occurrence = Occurrence( + event=detection.source_image.event, + deployment=detection.source_image.deployment, + project=detection.source_image.project, ) + occurrences_to_create.append(occurrence) + logger.debug(f"Created new occurrence {occurrence} for detection {detection}") detection.occurrence = occurrence # type: ignore - detection.save() - detection.occurrence.save() + detections_to_update.append(detection) + + occurrences = Occurrence.objects.bulk_create(occurrences_to_create) + logger.info(f"Created {len(occurrences)} new occurrences") + Detection.objects.bulk_update(detections_to_update, ["occurrence"]) + logger.info(f"Updated {len(detections_to_update)} detections with occurrences") + + occurrences_to_update = [] + occurrences_to_leave = [] + for detection in detections: + assert detection.occurrence, f"No occurrence found for detection {detection}" + needs_update = update_occurrence_determination( + detection.occurrence, + current_determination=detection.occurrence.determination, + 
save=False, + ) + if needs_update: + occurrences_to_update.append(detection.occurrence) + else: + occurrences_to_leave.append(detection.occurrence) + + Occurrence.objects.bulk_update(occurrences_to_update, ["determination", "determination_score"]) + logger.info( + f"Updated the determination of {len(occurrences_to_update)} occurrences, " + f"left {len(occurrences_to_leave)} unchanged" + ) + + SourceImage.objects.get(pk=source_image_id).save() + + +@dataclasses.dataclass +class PipelineSaveResults: + pipeline: Pipeline + source_images: list[SourceImage] + detections: list[Detection] + classifications: list[Classification] + algorithms: dict[str, Algorithm] + total_time: float + + +@celery_app.task(soft_time_limit=60 * 4, time_limit=60 * 5) +def save_results( + results: PipelineResultsResponse | None = None, + results_json: str | None = None, + job_id: int | None = None, + return_created=False, +) -> PipelineSaveResults | None: + """ + Save results from ML pipeline API. + + @TODO Continue improving bulk create. Group everything / all loops by source image. + """ + job = None + + if results_json: + results = PipelineResultsResponse.parse_raw(results_json) + assert results, "No results data passed to save_results task" + + pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline}) + if _created: + logger.warning(f"Pipeline choice returned by the Processing Service was not recognized! 
{pipeline}") + algorithms_used = set() + + job_logger = logger + start_time = time.time() + + if job_id: + from ami.jobs.models import Job + + job = Job.objects.get(pk=job_id) + job_logger = job.logger + + # @TODO set this level back to INFO + # job_logger.setLevel(logging.DEBUG) + + if results_json: + results = PipelineResultsResponse.parse_raw(results_json) + assert results, "No results data passed to save_results task" + job_logger.info(f"Saving results from pipeline {results.pipeline}") + + results = PipelineResultsResponse.parse_obj(results.dict()) + assert results, "No results from pipeline to save" + source_images = SourceImage.objects.filter(pk__in=[int(img.id) for img in results.source_images]).distinct() + + pipeline, _created = Pipeline.objects.get_or_create(slug=results.pipeline, defaults={"name": results.pipeline}) + if _created: + job_logger.warning( + f"The pipeline returned by the ML backend was not recognized, created a placeholder: {pipeline}" + ) + + # Algorithms and category maps should be created in advance when registering the pipeline & processing service + # however they are also currently available in each pipeline results response as well. + # @TODO review if we should only use the algorithms from the pre-registered pipeline config instead of the results + algorithms_used = { + algo_key: get_or_create_algorithm_and_category_map(algo_config, logger=job_logger) + for algo_key, algo_config in results.algorithms.items() + } + + detections = create_detections( + detections=results.detections, + algorithms_used=algorithms_used, + logger=job_logger, + ) + + classifications = create_classifications( + detections=detections, + detection_responses=results.detections, + algorithms_used=algorithms_used, + logger=job_logger, + ) + + # Create a new occurrence for each detection (no tracking yet) + # @TODO remove when we implement tracking! 
+ create_and_update_occurrences_for_detections( + detections=detections, + logger=job_logger, + ) # Update precalculated counts on source images and events - with transaction.atomic(): - for source_image in source_images: - source_image.save() + source_images = list(source_images) + logger.info(f"Updating calculated fields for {len(source_images)} source images") + for source_image in source_images: + source_image.save() image_cropping_task = create_detection_images.delay( source_image_ids=[source_image.pk for source_image in source_images], ) - if job: - job.logger.info(f"Creating detection images in sub-task {image_cropping_task.id}") + job_logger.info(f"Creating detection images in sub-task {image_cropping_task.id}") - event_ids = [img.event_id for img in source_images] + event_ids = [img.event_id for img in source_images] # type: ignore update_calculated_fields_for_events(pks=event_ids) registered_algos = pipeline.algorithms.all() - for algo in algorithms_used: + for algo in algorithms_used.values(): # This is important for tracking what objects were processed by which algorithms # to avoid reprocessing, and for tracking provenance. if algo not in registered_algos: pipeline.algorithms.add(algo) - logger.warning(f"Added unregistered algorithm {algo} to pipeline {pipeline}") + job_logger.debug(f"Added algorithm {algo} to pipeline {pipeline}") - if job: - if len(created_objects): - job.logger.info(f"Created {len(created_objects)} objects") - try: - previously_created = int(job.progress.get_stage_param("results", "objects_created").value) - job.progress.update_stage( - "results", - objects_created=previously_created + len(created_objects), - ) - except ValueError: - pass - else: - job.update_progress() + total_time = time.time() - start_time + job_logger.info(f"Saved results from pipeline {pipeline} in {total_time:.2f} seconds") + + if return_created: + """ + By default, return None because celery tasks need special handling to return objects. 
+ """ + return PipelineSaveResults( + pipeline=pipeline, + source_images=source_images, + detections=detections, + classifications=classifications, + algorithms=algorithms_used, + total_time=total_time, + ) class PipelineStage(ConfigurableStage): @@ -501,22 +953,24 @@ def process_images(self, images: typing.Iterable[SourceImage], job_id: int | Non ) return process_images( - endpoint_url=urljoin(processing_service.endpoint_url, "/process_images"), + endpoint_url=urljoin(processing_service.endpoint_url, "/process"), pipeline=self, images=images, job_id=job_id, ) - def save_results(self, results: PipelineResponse, job_id: int | None = None): + def save_results(self, results: PipelineResultsResponse, job_id: int | None = None): return save_results(results=results, job_id=job_id) - def save_results_async(self, results: PipelineResponse, job_id: int | None = None): + def save_results_async(self, results: PipelineResultsResponse, job_id: int | None = None): # Returns an AsyncResult results_json = results.json() return save_results.delay(results_json=results_json, job_id=job_id) def save(self, *args, **kwargs): if not self.slug: - # @TODO slug may only need to be unique per project - self.slug = slugify(self.name) + # @TODO find a better way to generate unique identifiers + # consider hashing the pipeline config or using a UUID -- but both sides need to agree on the same UUID. 
+ unique_suffix = str(uuid.uuid4())[:8] + self.slug = f"{slugify(self.name)}-v{self.version}-{unique_suffix}" return super().save(*args, **kwargs) diff --git a/ami/ml/models/processing_service.py b/ami/ml/models/processing_service.py index fb34c271c..3d66143d2 100644 --- a/ami/ml/models/processing_service.py +++ b/ami/ml/models/processing_service.py @@ -8,9 +8,8 @@ from django.db import models from ami.base.models import BaseModel -from ami.ml.models.algorithm import Algorithm -from ami.ml.models.pipeline import Pipeline -from ami.ml.schemas import PipelineRegistrationResponse, ProcessingServiceStatusResponse +from ami.ml.models.pipeline import Pipeline, get_or_create_algorithm_and_category_map +from ami.ml.schemas import PipelineRegistrationResponse, ProcessingServiceInfoResponse, ProcessingServiceStatusResponse logger = logging.getLogger(__name__) @@ -38,8 +37,6 @@ class Meta: def create_pipelines(self): # Call the status endpoint and get the pipelines/algorithms resp = self.get_status() - if resp.error: - resp.raise_for_status() pipelines_to_add = resp.pipeline_configs pipelines = [] @@ -47,12 +44,19 @@ def create_pipelines(self): algorithms_created = [] for pipeline_data in pipelines_to_add: - pipeline, created = Pipeline.objects.get_or_create( - name=pipeline_data.name, - slug=pipeline_data.slug, - version=pipeline_data.version, - description=pipeline_data.description or "", - ) + pipeline = Pipeline.objects.filter( + models.Q(slug=pipeline_data.slug) | models.Q(name=pipeline_data.name, version=pipeline_data.version) + ).first() + created = False + if not pipeline: + pipeline = Pipeline.objects.create( + slug=pipeline_data.slug, + name=pipeline_data.name, + version=pipeline_data.version, + description=pipeline_data.description or "", + ) + created = True + pipeline.projects.add(*self.projects.all()) self.pipelines.add(pipeline) @@ -62,13 +66,13 @@ def create_pipelines(self): else: logger.info(f"Using existing pipeline {pipeline.name}.") + existing_algorithms = 
pipeline.algorithms.all() for algorithm_data in pipeline_data.algorithms: - algorithm, created = Algorithm.objects.get_or_create(name=algorithm_data.name, key=algorithm_data.key) - pipeline.algorithms.add(algorithm) - - if created: - logger.info(f"Successfully created algorithm {algorithm.name}.") - algorithms_created.append(algorithm.name) + algorithm = get_or_create_algorithm_and_category_map(algorithm_data, logger=logger) + if algorithm not in existing_algorithms: + logger.info(f"Registered new algorithm {algorithm.name} to pipeline {pipeline.name}.") + pipeline.algorithms.add(algorithm) + pipelines_created.append(algorithm.key) else: logger.info(f"Using existing algorithm {algorithm.name}.") @@ -94,28 +98,33 @@ def get_status(self): resp = requests.get(info_url) resp.raise_for_status() except requests.exceptions.RequestException as e: + latency = time.time() - start_time self.last_checked_live = False + self.last_checked_latency = latency self.save() error = f"Error connecting to {info_url}: {e}" logger.error(error) - first_response_time = time.time() - latency = first_response_time - start_time - return ProcessingServiceStatusResponse( + error=error, timestamp=timestamp, request_successful=False, + server_live=False, + pipelines_online=[], + pipeline_configs=[], endpoint_url=self.endpoint_url, - error=error, latency=latency, ) - pipeline_configs = resp.json() - server_live = requests.get(urljoin(self.endpoint_url, "livez")).json().get("status") - pipelines_online = requests.get(urljoin(self.endpoint_url, "readyz")).json().get("status") + info_data = ProcessingServiceInfoResponse.parse_obj(resp.json()) + pipeline_configs = info_data.pipelines + + # @TODO these are likely extra requests that could be avoided + # @TODO add schemas for these if we keep them + server_live: bool = requests.get(urljoin(self.endpoint_url, "livez")).json().get("status", False) + pipelines_online: list[str] = requests.get(urljoin(self.endpoint_url, "readyz")).json().get("status", []) - 
first_response_time = time.time() - latency = first_response_time - start_time + latency = time.time() - start_time self.last_checked_live = server_live self.last_checked_latency = latency self.save() diff --git a/ami/ml/schemas.py b/ami/ml/schemas.py index e700e8076..97e0ca480 100644 --- a/ami/ml/schemas.py +++ b/ami/ml/schemas.py @@ -1,8 +1,12 @@ import datetime +import logging import typing import pydantic +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + class BoundingBox(pydantic.BaseModel): x1: float @@ -11,25 +15,108 @@ class BoundingBox(pydantic.BaseModel): y2: float @classmethod - def from_coords(cls, coords: list[float]): + def from_coords(cls, coords: list[float]) -> "BoundingBox": return cls(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3]) + def to_string(self) -> str: + return f"{self.x1},{self.y1},{self.x2},{self.y2}" + + def to_path(self) -> str: + return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]]) + + def to_tuple(self): + return (self.x1, self.y1, self.x2, self.y2) + + +class AlgorithmReference(pydantic.BaseModel): + name: str + key: str + + +class AlgorithmCategoryMapResponse(pydantic.BaseModel): + data: list[dict] = pydantic.Field( + default_factory=dict, + description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.", + examples=[ + [ + {"label": "Moth", "index": 0, "gbif_key": 1234}, + {"label": "Not a moth", "index": 1, "gbif_key": 5678}, + ] + ], + ) + labels: list[str] = pydantic.Field( + default_factory=list, + description="A simple list of string labels, in the correct index order used by the model.", + examples=[["Moth", "Not a moth"]], + ) + version: str | None = pydantic.Field( + default=None, + description="The version of the category map. 
Can be a descriptive string or a version number.", + examples=["LepNet2021-with-2023-mods"], + ) + description: str | None = pydantic.Field( + default=None, + description="A description of the category map used to train. e.g. source, purpose and modifications.", + examples=["LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."], + ) + uri: str | None = pydantic.Field( + default=None, + description="A URI to the category map file, could be a public web URL or object store path.", + ) + + +class AlgorithmConfigResponse(pydantic.BaseModel): + name: str + key: str = pydantic.Field( + description=("A unique key for an algorithm to lookup the category map (class list) and other metadata."), + ) + description: str | None = None + task_type: str | None = pydantic.Field( + default=None, + description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.", + examples=["detection", "classification", "segmentation", "embedding"], + ) + version: int = pydantic.Field( + default=1, + description="A sortable version number for the model. Increment this number when the model is updated.", + ) + version_name: str | None = pydantic.Field( + default=None, + description="A complete version name e.g. '2021-01-01', 'LepNet2021'.", + ) + uri: str | None = pydantic.Field( + default=None, + description="A URI to the weight or model details, could be a public web URL or object store path.", + ) + category_map: AlgorithmCategoryMapResponse | None = None + + class Config: + extra = "ignore" + class ClassificationResponse(pydantic.BaseModel): classification: str - labels: list[str] = [] + labels: list[str] | None = pydantic.Field( + default=None, + description=( + "A list of all possible labels for the model, in the correct order. " + "Omitted if the model has too many labels to include for each classification in the response. " + "Use the category map from the algorithm to get the full list of labels and metadata." 
+ ), + ) scores: list[float] = [] + logits: list[float] | None = None inference_time: float | None = None - algorithm: str | None = None - timestamp: datetime.datetime + algorithm: AlgorithmReference terminal: bool = True + timestamp: datetime.datetime class DetectionResponse(pydantic.BaseModel): source_image_id: str bbox: BoundingBox inference_time: float | None = None - algorithm: str | None = None + algorithm: AlgorithmReference timestamp: datetime.datetime crop_image_url: str | None = None classifications: list[ClassificationResponse] = [] @@ -46,6 +133,9 @@ class SourceImageResponse(pydantic.BaseModel): id: str url: str + class Config: + extra = "ignore" + KnownPipelineChoices = typing.Literal[ "panama_moths_2023", @@ -59,10 +149,15 @@ class PipelineRequest(pydantic.BaseModel): source_images: list[SourceImageRequest] -class PipelineResponse(pydantic.BaseModel): +class PipelineResultsResponse(pydantic.BaseModel): # pipeline: PipelineChoice pipeline: str - total_time: float | None + algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field( + default_factory=dict, + description="A dictionary of all algorithms used in the pipeline, including their class list and other " + "metadata, keyed by the algorithm key.", + ) + total_time: float source_images: list[SourceImageResponse] detections: list[DetectionResponse] errors: list | str | None = None @@ -85,11 +180,6 @@ class PipelineStage(pydantic.BaseModel): description: str | None = None -class AlgorithmConfig(pydantic.BaseModel): - name: str - key: str - - class PipelineConfig(pydantic.BaseModel): """A configurable pipeline.""" @@ -97,17 +187,32 @@ class PipelineConfig(pydantic.BaseModel): slug: str version: int description: str | None = None - algorithms: list[AlgorithmConfig] = [] + algorithms: list[AlgorithmConfigResponse] = [] stages: list[PipelineStage] = [] +class ProcessingServiceInfoResponse(pydantic.BaseModel): + """ + Information about the processing service returned from the Processing Service 
backend. + """ + + name: str + description: str | None = None + pipelines: list[PipelineConfig] = [] + algorithms: list[AlgorithmConfigResponse] = [] + + class ProcessingServiceStatusResponse(pydantic.BaseModel): + """ + Status response returned by the Antenna API about the Processing Service. + """ + timestamp: datetime.datetime request_successful: bool pipeline_configs: list[PipelineConfig] = [] error: str | None = None server_live: bool | None = None - pipelines_online: list[str] | str = "pipelines unavailable" + pipelines_online: list[str] = [] endpoint_url: str latency: float diff --git a/ami/ml/serializers.py b/ami/ml/serializers.py index 24e08a064..1611b3ea4 100644 --- a/ami/ml/serializers.py +++ b/ami/ml/serializers.py @@ -4,11 +4,26 @@ from ami.main.api.serializers import DefaultSerializer from ami.main.models import Project -from .models.algorithm import Algorithm +from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline, PipelineStage from .models.processing_service import ProcessingService +class AlgorithmCategoryMapSerializer(DefaultSerializer): + class Meta: + model = AlgorithmCategoryMap + fields = [ + "id", + "labels", + "data", + "algorithms", + "version", + "uri", + "created_at", + "updated_at", + ] + + class AlgorithmSerializer(DefaultSerializer): class Meta: model = Algorithm @@ -18,9 +33,11 @@ class Meta: "name", "key", "description", - "url", + "uri", "version", "version_name", + "task_type", + "category_map", "created_at", "updated_at", ] diff --git a/ami/ml/tasks.py b/ami/ml/tasks.py index a38847c9d..972233251 100644 --- a/ami/ml/tasks.py +++ b/ami/ml/tasks.py @@ -1,4 +1,5 @@ import logging +import time from ami.ml.media import create_detection_images_from_source_image from ami.tasks import default_soft_time_limit, default_time_limit @@ -42,6 +43,8 @@ def process_source_images_async(pipeline_choice: str, endpoint_url: str, image_i def create_detection_images(source_image_ids: list[int]): from 
ami.main.models import SourceImage + start_time = time.time() + logger.debug(f"Creating detection images for {len(source_image_ids)} capture(s)") for source_image in SourceImage.objects.filter(pk__in=source_image_ids): @@ -51,6 +54,9 @@ def create_detection_images(source_image_ids: list[int]): except Exception as e: logger.error(f"Error creating detection images for SourceImage {source_image.pk}: {str(e)}") + total_time = time.time() - start_time + logger.info(f"Created detection images for {len(source_image_ids)} capture(s) in {total_time:.2f} seconds") + @celery_app.task(soft_time_limit=default_soft_time_limit, time_limit=default_time_limit) def remove_duplicate_classifications(project_id: int | None = None, dry_run: bool = False) -> int: diff --git a/ami/ml/tests.py b/ami/ml/tests.py index b2dc5fb70..2fbd5e723 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1,21 +1,24 @@ import datetime +import unittest from django.test import TestCase from rest_framework.test import APIRequestFactory, APITestCase -from rich import print from ami.base.serializers import reverse_with_params from ami.main.models import Classification, Detection, Project, SourceImage, SourceImageCollection from ami.ml.models import Algorithm, Pipeline, ProcessingService -from ami.ml.models.pipeline import collect_images, save_results +from ami.ml.models.pipeline import collect_images, get_or_create_algorithm_and_category_map, save_results from ami.ml.schemas import ( + AlgorithmConfigResponse, + AlgorithmReference, BoundingBox, ClassificationResponse, DetectionResponse, - PipelineResponse, + PipelineResultsResponse, SourceImageResponse, ) from ami.tests.fixtures.main import create_captures_from_files, create_processing_service, setup_test_project +from ami.tests.fixtures.ml import ALGORITHM_CHOICES from ami.users.models import User @@ -104,7 +107,8 @@ def setUp(self): self.test_images = [image for image, frame in self.captures] self.processing_service_instance = 
create_processing_service(self.project) self.processing_service = self.processing_service_instance - self.pipeline = self.processing_service_instance.pipelines.all().filter(slug="constant").first() + assert self.processing_service_instance.pipelines.exists() + self.pipeline = self.processing_service_instance.pipelines.all().get(slug="constant") def test_run_pipeline(self): # Send images to Processing Service to process and return detections @@ -112,6 +116,72 @@ def test_run_pipeline(self): pipeline_response = self.pipeline.process_images(self.test_images, job_id=None) assert pipeline_response.detections + def test_created_category_maps(self): + # Send images to ML backend to process and return detections + assert self.pipeline + pipeline_response = self.pipeline.process_images(self.test_images) + save_results(pipeline_response, return_created=True) + + source_images = SourceImage.objects.filter(pk__in=[image.id for image in pipeline_response.source_images]) + detections = Detection.objects.filter(source_image__in=source_images).select_related( + "detection_algorithm", + "detection_algorithm__category_map", + ) + assert detections.count() > 0 + for detection in detections: + # No detection algorithm should have category map at this time (but this may change!) 
+ assert detection.detection_algorithm + assert detection.detection_algorithm.category_map is None + + # Ensure that all classification algorithms have a category map + classification_taxa = set() + for classification in detection.classifications.all().select_related( + "algorithm", + "algorithm__category_map", + ): + assert classification.algorithm is not None + assert classification.category_map is not None + assert classification.algorithm.category_map == classification.category_map + + _, top_score = list(classification.predictions(sort=True))[0] + assert top_score == classification.score + + top_taxon, top_taxon_score = list(classification.predictions_with_taxa(sort=True))[0] + assert top_taxon == classification.taxon + assert top_taxon_score == classification.score + + classification_taxa.add(top_taxon) + + # Check the occurrence determination taxon + assert detection.occurrence + assert detection.occurrence.determination in classification_taxa + + def test_alignment_of_predictions_and_category_map(self): + # Ensure that the scores and labels are aligned + pipeline = self.processing_service_instance.pipelines.all().get(slug="random") + pipeline_response = pipeline.process_images(self.test_images) + results = save_results(pipeline_response, return_created=True) + assert results is not None, "Expecected results to be returned in a PipelineSaveResults object" + assert results.classifications, "Expected classifications to be returned in the results" + for classification in results.classifications: + assert classification.scores + taxa_with_scores = list(classification.predictions_with_taxa(sort=True)) + assert taxa_with_scores + assert classification.score == taxa_with_scores[0][1] + assert classification.taxon == taxa_with_scores[0][0] + + def test_top_n_alignment(self): + # Ensure that the top_n parameter works + pipeline = self.processing_service_instance.pipelines.all().get(slug="random") + pipeline_response = pipeline.process_images(self.test_images) + 
results = save_results(pipeline_response, return_created=True) + assert results is not None, "Expecected results to be returned in a PipelineSaveResults object" + assert results.classifications, "Expected classifications to be returned in the results" + for classification in results.classifications: + top_n = classification.top_n(n=3) + assert classification.score == top_n[0]["score"] + assert classification.taxon == top_n[0]["taxon"] + class TestPipeline(TestCase): def setUp(self): @@ -131,39 +201,72 @@ def setUp(self): self.pipeline = Pipeline.objects.create( name="Test Pipeline", ) + self.algorithms = { - "detector": Algorithm.objects.create(name="Test Object Detector"), - "binary_classifier": Algorithm.objects.create(name="Test Filter"), - "species_classifier": Algorithm.objects.create(name="Test Classifier"), + key: get_or_create_algorithm_and_category_map(val) for key, val in ALGORITHM_CHOICES.items() } - self.pipeline.algorithms.set(self.algorithms.values()) + self.pipeline.algorithms.set([algo for algo in self.algorithms.values()]) def test_create_pipeline(self): - self.assertEqual(self.pipeline.slug, "test-pipeline") - self.assertEqual(self.pipeline.algorithms.count(), 3) + assert self.pipeline.slug.startswith("test-pipeline") + self.assertEqual(self.pipeline.algorithms.count(), len(ALGORITHM_CHOICES)) for algorithm in self.pipeline.algorithms.all(): - self.assertIn(algorithm.key, ["test-object-detector", "test-filter", "test-classifier"]) + assert isinstance(algorithm, Algorithm) + self.assertIn(algorithm.key, [algo.key for algo in ALGORITHM_CHOICES.values()]) def test_collect_images(self): images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) assert len(images) == 2 - def fake_pipeline_results(self, source_images: list[SourceImage], pipeline: Pipeline): + def fake_pipeline_results( + self, + source_images: list[SourceImage], + pipeline: Pipeline, + alt_species_classifier: AlgorithmConfigResponse | None = None, + ): + # 
@TODO use the pipeline passed in to get the algorithms source_image_results = [SourceImageResponse(id=image.pk, url=image.path) for image in source_images] + detector = ALGORITHM_CHOICES["random-detector"] + binary_classifier = ALGORITHM_CHOICES["random-binary-classifier"] + assert binary_classifier.category_map + + if alt_species_classifier is None: + species_classifier = ALGORITHM_CHOICES["random-species-classifier"] + else: + species_classifier = alt_species_classifier + assert species_classifier.category_map + detection_results = [ DetectionResponse( source_image_id=image.pk, bbox=BoundingBox(x1=0.0, y1=0.0, x2=1.0, y2=1.0), inference_time=0.4, - algorithm=self.algorithms["detector"].name, + algorithm=AlgorithmReference( + name=detector.name, + key=detector.key, + ), timestamp=datetime.datetime.now(), classifications=[ ClassificationResponse( - classification="Test taxon", - labels=["Test taxon"], + classification=binary_classifier.category_map.labels[0], + labels=None, + scores=[0.9213], + algorithm=AlgorithmReference( + name=binary_classifier.name, + key=binary_classifier.key, + ), + timestamp=datetime.datetime.now(), + terminal=False, + ), + ClassificationResponse( + classification=species_classifier.category_map.labels[0], + labels=None, scores=[0.64333], - algorithm=self.algorithms["species_classifier"].name, + algorithm=AlgorithmReference( + name=species_classifier.name, + key=species_classifier.key, + ), timestamp=datetime.datetime.now(), terminal=True, ), @@ -171,22 +274,27 @@ def fake_pipeline_results(self, source_images: list[SourceImage], pipeline: Pipe ) for image in self.test_images ] - fake_results = PipelineResponse( - pipeline=self.pipeline.slug, - total_time=0.0, + fake_results = PipelineResultsResponse( + pipeline=pipeline.slug, + algorithms={ + detector.key: detector, + binary_classifier.key: binary_classifier, + species_classifier.key: species_classifier, + }, + total_time=0.01, source_images=source_image_results, 
detections=detection_results, ) return fake_results def test_save_results(self): - saved_objects = save_results(self.fake_pipeline_results(self.test_images, self.pipeline)) + results = self.fake_pipeline_results(self.test_images, self.pipeline) + save_results(results) for image in self.test_images: image.save() self.assertEqual(image.detections_count, 1) - print(saved_objects) # @TODO test the cached counts for detections, etc are updated on Events, Deployments, etc. def no_test_skip_existing_results(self): @@ -196,7 +304,7 @@ def no_test_skip_existing_results(self): total_images = len(images) self.assertEqual(total_images, self.image_collection.images.count()) - save_results(self.fake_pipeline_results(images, self.pipeline)) + save_results(self.fake_pipeline_results(images, self.pipeline), return_created=True) images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) @@ -221,17 +329,22 @@ def test_skip_existing_with_new_detector(self): images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) total_images = len(images) self.assertEqual(total_images, self.image_collection.images.count()) - save_results(self.fake_pipeline_results(images, self.pipeline)) + pipeline_response = self.fake_pipeline_results(images, self.pipeline) + save_results(pipeline_response) + # Find the fist algo used where task_type is classification + classifiers = [algo for algo in pipeline_response.algorithms.values() if algo.task_type == "classification"] + last_classifier = Algorithm.objects.get(key=classifiers[-1].key) self.pipeline.algorithms.set( [ Algorithm.objects.create(name="NEW Object Detector 2.0"), - self.algorithms["species_classifier"], # Same classifier + last_classifier, ] ) images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) remaining_images_to_process = len(images_again) self.assertEqual(remaining_images_to_process, total_images) + @unittest.skip("Not implemented yet") 
def test_skip_existing_with_new_classifier(self): """ @TODO add support for skipping the detection model if only the classifier has changed. @@ -239,10 +352,16 @@ def test_skip_existing_with_new_classifier(self): images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) total_images = len(images) self.assertEqual(total_images, self.image_collection.images.count()) - save_results(self.fake_pipeline_results(images, self.pipeline)) + pipeline_response = self.fake_pipeline_results(images, self.pipeline) + # Find the fist algo used where task_type is detection + first_detector_in_response = next( + algo for algo in pipeline_response.algorithms.values() if algo.task_type == "detection" + ) + first_detector = Algorithm.objects.get(key=first_detector_in_response.key) + save_results(pipeline_response) self.pipeline.algorithms.set( [ - self.algorithms["detector"], # Same object detector + first_detector, Algorithm.objects.create(name="NEW Classifier 2.0"), ] ) @@ -259,14 +378,21 @@ def _test_skip_existing_per_batch_during_processing(self): def test_unknown_algorithm_returned_by_processing_service(self): fake_results = self.fake_pipeline_results(self.test_images, self.pipeline) - new_detector_name = "Unknown Detector 5.1b-mobile" - new_classifier_name = "Unknown Classifier 3.0b-mega" + new_detector = AlgorithmConfigResponse( + name="Unknown Detector 5.1b-mobile", key="unknown-detector", task_type="detection" + ) + new_classifier = AlgorithmConfigResponse( + name="Unknown Classifier 3.0b-mega", key="unknown-classifier", task_type="classification" + ) + + fake_results.algorithms[new_detector.key] = new_detector + fake_results.algorithms[new_classifier.key] = new_classifier for detection in fake_results.detections: - detection.algorithm = new_detector_name + detection.algorithm = AlgorithmReference(name=new_detector.name, key=new_detector.key) for classification in detection.classifications: - classification.algorithm = new_classifier_name + 
classification.algorithm = AlgorithmReference(name=new_classifier.name, key=new_classifier.key) current_total_algorithm_count = Algorithm.objects.count() @@ -278,39 +404,50 @@ def test_unknown_algorithm_returned_by_processing_service(self): self.assertEqual(new_algorithm_count, current_total_algorithm_count + 2) # Ensure new algorithms were also added to the pipeline - self.assertTrue(self.pipeline.algorithms.filter(name=new_detector_name).exists()) - self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier_name).exists()) + self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name, key=new_detector.key).exists()) + self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name, key=new_classifier.key).exists()) - def no_test_reprocessing_after_unknown_algorithm_added(self): + @unittest.skip("Not implemented yet") + def test_reprocessing_after_unknown_algorithm_added(self): # @TODO fix issue with "None" algorithm on some detections images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) - saved_objects = save_results(self.fake_pipeline_results(images, self.pipeline)) + save_results(self.fake_pipeline_results(images, self.pipeline)) + + new_detector = AlgorithmConfigResponse( + name="Unknown Detector 5.1b-mobile", key="unknown-detector", task_type="detection" + ) + new_classifier = AlgorithmConfigResponse( + name="Unknown Classifier 3.0b-mega", key="unknown-classifier", task_type="classification" + ) - new_detector_name = "Unknown Detector 5.1b-mobile" - new_classifier_name = "Unknown Classifier 3.0b-mega" fake_results = self.fake_pipeline_results(images, self.pipeline) + # Change the algorithm names to unknown ones for detection in fake_results.detections: - detection.algorithm = new_detector_name + detection.algorithm = AlgorithmReference(name=new_detector.name, key=new_detector.key) for classification in detection.classifications: - classification.algorithm = new_classifier_name + 
classification.algorithm = AlgorithmReference(name=new_classifier.name, key=new_classifier.key) + + fake_results.algorithms[new_detector.key] = new_detector + fake_results.algorithms[new_classifier.key] = new_classifier # print("FAKE RESULTS") # print(fake_results) # print("END FAKE RESULTS") - saved_objects = save_results(fake_results) or [] - saved_detections = [obj for obj in saved_objects if isinstance(obj, Detection)] - saved_classifications = [obj for obj in saved_objects if isinstance(obj, Classification)] + saved_objects = save_results(fake_results, return_created=True) + assert saved_objects is not None + saved_detections = saved_objects.detections + saved_classifications = saved_objects.classifications for obj in saved_detections: assert obj.detection_algorithm # For type checker, not the test # Ensure the new detector was used for the detection - self.assertEqual(obj.detection_algorithm.name, new_detector_name) + self.assertEqual(obj.detection_algorithm.name, new_detector.name) # Ensure each detection has classification objects self.assertTrue(obj.classifications.exists()) @@ -323,21 +460,21 @@ def no_test_reprocessing_after_unknown_algorithm_added(self): assert obj.algorithm # For type checker, not the test # Ensure the new classifier was used for the classification - self.assertEqual(obj.algorithm.name, new_classifier_name) + self.assertEqual(obj.algorithm.name, new_classifier.name) # Ensure each classification has the correct detection object self.assertIn(obj.detection, saved_detections, "Wrong detection object for classification object.") # Ensure new algorithms were added to the pipeline - self.assertTrue(self.pipeline.algorithms.filter(name=new_detector_name).exists()) - self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier_name).exists()) + self.assertTrue(self.pipeline.algorithms.filter(name=new_detector.name).exists()) + self.assertTrue(self.pipeline.algorithms.filter(name=new_classifier.name).exists()) detection_algos_used = 
Detection.objects.all().values_list("detection_algorithm__name", flat=True).distinct() - self.assertTrue(new_detector_name in detection_algos_used) + self.assertTrue(new_detector.name in detection_algos_used) # Ensure None is not in the list self.assertFalse(None in detection_algos_used) classification_algos_used = Classification.objects.all().values_list("algorithm__name", flat=True) - self.assertTrue(new_classifier_name in classification_algos_used) + self.assertTrue(new_classifier.name in classification_algos_used) # Ensure None is not in the list self.assertFalse(None in classification_algos_used) @@ -345,3 +482,93 @@ def no_test_reprocessing_after_unknown_algorithm_added(self): images_again = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) remaining_images_to_process = len(images_again) self.assertEqual(remaining_images_to_process, 0) + + def test_yes_reprocess_if_new_terminal_algorithm_same_intermediate(self): + """ + Test two pipelines with the same detector and same moth/non-moth classifier, but a new species classifier. + + The first pipeline should process the images and save the results. + The second pipeline should reprocess the images. 
+ """ + + images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) + assert len(images), "No images to process" + + detector = Algorithm.objects.get(key="random-detector") + binary_classifier = Algorithm.objects.get(key="random-binary-classifier") + old_species_classifier = Algorithm.objects.get(key="random-species-classifier") + + Detection.objects.all().delete() + results = save_results(self.fake_pipeline_results(images, self.pipeline), return_created=True) + assert results is not None, "Expecected results to be returned in a PipelineSaveResults object" + + for raw_detection in results.detections: + self.assertEqual(raw_detection.detection_algorithm, detector) + + # Ensure all results have the binary classifier and the old species classifier + for saved_detection in Detection.objects.all(): + self.assertEqual(saved_detection.detection_algorithm, detector) + # Assert that the binary classifier was used + self.assertTrue( + saved_detection.classifications.filter(algorithm=binary_classifier).exists(), + "Binary classifier not used in first run", + ) + # Assert that the old species classifier was used + self.assertTrue( + saved_detection.classifications.filter(algorithm=old_species_classifier).exists(), + "Old species classifier not used in first run", + ) + + # Get another species classifier + new_species_classifier_key = "constant-species-classifier" + new_species_classifier = Algorithm.objects.get(key=new_species_classifier_key) + # new_species_classifier_response = ALGORITHM_CHOICES[new_species_classifier_key] + + # Create a new pipeline with the same detector and the new species classifier + new_pipeline = Pipeline.objects.create( + name="New Pipeline", + ) + + new_pipeline.algorithms.set( + [ + detector, + binary_classifier, + new_species_classifier, + ] + ) + + # Process the images with the new pipeline + images_again = list(collect_images(collection=self.image_collection, pipeline=new_pipeline)) + remaining_images_to_process = 
len(images_again) + self.assertEqual(remaining_images_to_process, len(images), "Images not re-processed with new pipeline") + + +class TestAlgorithmCategoryMaps(TestCase): + def setUp(self): + self.algorithm_responses = { + key: get_or_create_algorithm_and_category_map(val) for key, val in ALGORITHM_CHOICES.items() + } + self.algorithms = {key: Algorithm.objects.get(key=key) for key in ALGORITHM_CHOICES.keys()} + + def test_create_algorithms_and_category_map(self): + assert len(self.algorithms) > 0 + assert ( + Algorithm.objects.filter( + key__in=self.algorithms.keys(), + ) + .exclude(category_map=None) + .count() + ) > 0 + + def test_algorithm_category_maps(self): + for algorithm in Algorithm.objects.filter( + key__in=self.algorithms.keys(), + ).exclude(category_map=None): + assert algorithm.category_map # For type checker, not the test + assert algorithm.category_map.labels + assert algorithm.category_map.labels_hash + assert algorithm.category_map.data + + # Ensure the full labels in the data match the simple, ordered list of labels + sorted_data = sorted(algorithm.category_map.data, key=lambda x: x["index"]) + assert [category["label"] for category in sorted_data] == algorithm.category_map.labels diff --git a/ami/ml/views.py b/ami/ml/views.py index 0ffafa525..8f4851f83 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -12,10 +12,15 @@ from ami.main.models import SourceImage from ami.utils.requests import get_active_project, project_id_doc_param -from .models.algorithm import Algorithm +from .models.algorithm import Algorithm, AlgorithmCategoryMap from .models.pipeline import Pipeline from .models.processing_service import ProcessingService -from .serializers import AlgorithmSerializer, PipelineSerializer, ProcessingServiceSerializer +from .serializers import ( + AlgorithmCategoryMapSerializer, + AlgorithmSerializer, + PipelineSerializer, + ProcessingServiceSerializer, +) logger = logging.getLogger(__name__) @@ -37,6 +42,22 @@ class 
AlgorithmViewSet(DefaultViewSet): search_fields = ["name"] +class AlgorithmCategoryMapViewSet(DefaultViewSet): + """ + API endpoint that allows algorithm category maps to be viewed or edited. + """ + + queryset = AlgorithmCategoryMap.objects.all() + serializer_class = AlgorithmCategoryMapSerializer + filterset_fields = ["algorithms"] + ordering_fields = [ + "algorithms", + "created_at", + "updated_at", + "version", + ] + + class PipelineViewSet(DefaultViewSet): """ API endpoint that allows pipelines to be viewed or edited. diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py index 9f4a0613e..185b1095c 100644 --- a/ami/tests/fixtures/main.py +++ b/ami/tests/fixtures/main.py @@ -42,7 +42,8 @@ def create_processing_service(project): processing_service_to_add = { "name": "Test Processing Service", "projects": [{"name": project.name}], - "endpoint_url": "http://processing_service:2000", + # "endpoint_url": "http://processing_service:2000", + "endpoint_url": "http://ml_backend:2000", } processing_service, created = ProcessingService.objects.get_or_create( diff --git a/ami/tests/fixtures/ml.py b/ami/tests/fixtures/ml.py new file mode 100644 index 000000000..f4bd7cb54 --- /dev/null +++ b/ami/tests/fixtures/ml.py @@ -0,0 +1,81 @@ +from ami.ml.schemas import AlgorithmCategoryMapResponse, AlgorithmConfigResponse + +RANDOM_DETECTOR = AlgorithmConfigResponse( + name="Random Detector", + key="random-detector", + task_type="detection", + description="Return bounding boxes at random locations within the image bounds.", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-detector", + category_map=None, +) + +RANDOM_BINARY_CLASSIFIER = AlgorithmConfigResponse( + name="Random binary classifier", + key="random-binary-classifier", + task_type="classification", + description="Randomly return a classification of 'Moth' or 'Not a moth'", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-binary-classifier", + 
category_map=AlgorithmCategoryMapResponse( + data=[ + {"index": 0, "gbif_key": "1234", "label": "Moth", "source": "manual"}, + {"index": 1, "gbif_key": "4543", "label": "Not a moth", "source": "manual"}, + ], + labels=["Moth", "Not a moth"], + version="v1", + description="Class mapping for a simple binary classifier", + uri="https://huggingface.co/RolnickLab/random-binary-classifier/classes.txt", + ), +) + +RANDOM_SPECIES_CLASSIFIER = AlgorithmConfigResponse( + name="Random species classifier", + key="random-species-classifier", + task_type="classification", + description="A random species classifier", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-species-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + {"index": 0, "gbif_key": "1234", "label": "Vanessa atalanta", "source": "manual"}, + {"index": 1, "gbif_key": "4543", "label": "Vanessa cardui", "source": "manual"}, + {"index": 2, "gbif_key": "7890", "label": "Vanessa itea", "source": "manual"}, + ], + labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"], + version="v1", + description="", + uri="https://huggingface.co/RolnickLab/random-species-classifier/classes.txt", + ), +) + + +CONSTANT_SPECIES_CLASSIFIER = AlgorithmConfigResponse( + name="Constant species classifier", + key="constant-species-classifier", + task_type="classification", + description="A species classifier that always returns the same species", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/constant-species-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + {"index": 0, "gbif_key": "1234", "label": "Vanessa atalanta", "source": "manual"}, + {"index": 1, "gbif_key": "4543", "label": "Vanessa cardui", "source": "manual"}, + {"index": 2, "gbif_key": "7890", "label": "Vanessa itea", "source": "manual"}, + ], + labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"], + version="v1", + description="", + 
uri="https://huggigface.co/RolnickLab/constant-species-classifier/classes.txt", + ), +) +ALGORITHM_CHOICES = { + RANDOM_DETECTOR.key: RANDOM_DETECTOR, + RANDOM_BINARY_CLASSIFIER.key: RANDOM_BINARY_CLASSIFIER, + RANDOM_SPECIES_CLASSIFIER.key: RANDOM_SPECIES_CLASSIFIER, + CONSTANT_SPECIES_CLASSIFIER.key: CONSTANT_SPECIES_CLASSIFIER, +} diff --git a/ami/tests/fixtures/storage.py b/ami/tests/fixtures/storage.py index 52675919e..d53c782ff 100644 --- a/ami/tests/fixtures/storage.py +++ b/ami/tests/fixtures/storage.py @@ -42,7 +42,7 @@ def populate_bucket( config: s3.S3Config, subdir: str = "deployment_1", num_nights: int = 2, - images_per_day: int = 6, + images_per_day: int = 3, minutes_interval: int = 45, minutes_interval_variation: int = 10, skip_existing: bool = True, diff --git a/ami/utils/requests.py b/ami/utils/requests.py index 50edb0652..969c0d0aa 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -1,11 +1,53 @@ +import requests from django.forms import FloatField from drf_spectacular.utils import OpenApiParameter +from requests.adapters import HTTPAdapter from rest_framework.request import Request +from urllib3.util import Retry from ami.main.models import Project +def create_session( + retries: int = 3, + backoff_factor: int = 2, + status_forcelist: tuple[int, ...] = (500, 502, 503, 504), +) -> requests.Session: + """ + Create a requests Session with retry capabilities. 
+ + Args: + retries: Maximum number of retries + backoff_factor: Backoff factor for retries + status_forcelist: HTTP status codes to retry on + + Returns: + Session configured with retry behavior + """ + session = requests.Session() + retry = Retry( + total=retries, + read=retries, + connect=retries, + backoff_factor=backoff_factor, + status_forcelist=status_forcelist, + ) + adapter = HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + def get_active_classification_threshold(request: Request) -> float: + """ + Get the active classification threshold from request parameters. + + Args: + request: The incoming request object + + Returns: + The classification threshold value, defaulting to 0 if not specified + """ # Look for a query param to filter by score classification_threshold = request.query_params.get("classification_threshold") diff --git a/config/api_router.py b/config/api_router.py index 0d45dcb50..d2a776fbc 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -28,6 +28,7 @@ router.register(r"occurrences", views.OccurrenceViewSet) router.register(r"taxa", views.TaxonViewSet) router.register(r"ml/algorithms", ml_views.AlgorithmViewSet) +router.register(r"ml/labels", ml_views.AlgorithmCategoryMapViewSet) router.register(r"ml/pipelines", ml_views.PipelineViewSet) router.register(r"ml/processing_services", ml_views.ProcessingServiceViewSet) router.register(r"classifications", views.ClassificationViewSet) diff --git a/config/settings/base.py b/config/settings/base.py index acd2baf5f..ebb86c76d 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -90,7 +90,7 @@ "drf_spectacular", "django_filters", "anymail", - "cachalot", + # "cachalot", ] LOCAL_APPS = [ @@ -242,6 +242,22 @@ SENDGRID_SANDBOX_MODE_IN_DEBUG = False SENDGRID_ECHO_TO_STDOUT = True +# CACHES +# ------------------------------------------------------------------------------ +# 
https://docs.djangoproject.com/en/dev/ref/settings/#caches +CACHES = { + "default": { + "BACKEND": "django_redis.cache.RedisCache", + "LOCATION": env("REDIS_URL", default=None), + "OPTIONS": { + "CLIENT_CLASS": "django_redis.client.DefaultClient", + # Mimicking memcache behavior. + # https://github.com/jazzband/django-redis#memcached-exceptions-behavior + "IGNORE_EXCEPTIONS": True, + }, + } +} + # ADMIN # ------------------------------------------------------------------------------ # Django Admin URL. diff --git a/config/settings/local.py b/config/settings/local.py index d20889160..c2f58afa0 100644 --- a/config/settings/local.py +++ b/config/settings/local.py @@ -23,15 +23,6 @@ "django", ] + env.list("DJANGO_ALLOWED_HOSTS", default=[]) -# CACHES -# ------------------------------------------------------------------------------ -# https://docs.djangoproject.com/en/dev/ref/settings/#caches -CACHES = { - "default": { - "BACKEND": "django.core.cache.backends.locmem.LocMemCache", - "LOCATION": "", - } -} - # EMAIL # ------------------------------------------------------------------------------ @@ -60,6 +51,11 @@ INSTALLED_APPS = ["whitenoise.runserver_nostatic"] + INSTALLED_APPS # noqa: F405 +# Long queries can be a problem in development; this should stop them after 30s +database_options = DATABASES["default"].get("OPTIONS", {}) # noqa: F405 +database_options["options"] = "-c statement_timeout=30s" +DATABASES["default"]["OPTIONS"] = database_options # noqa: F405 + # django-debug-toolbar # ------------------------------------------------------------------------------ # https://django-debug-toolbar.readthedocs.io/en/latest/installation.html#prerequisites diff --git a/config/settings/production.py b/config/settings/production.py index 9f072b01c..e4abdf44e 100644 --- a/config/settings/production.py +++ b/config/settings/production.py @@ -22,20 +22,6 @@ # ------------------------------------------------------------------------------ DATABASES["default"]["CONN_MAX_AGE"] = 
env.int("CONN_MAX_AGE", default=60) # noqa: F405 -# CACHES -# ------------------------------------------------------------------------------ -CACHES = { - "default": { - "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": env("REDIS_URL"), - "OPTIONS": { - "CLIENT_CLASS": "django_redis.client.DefaultClient", - # Mimicing memcache behavior. - # https://github.com/jazzband/django-redis#memcached-exceptions-behavior - "IGNORE_EXCEPTIONS": True, - }, - } -} # SECURITY # ------------------------------------------------------------------------------ diff --git a/docker-compose.yml b/docker-compose.yml index 347c39248..478f85101 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,7 +23,6 @@ services: - postgres - redis - minio-init - - processing_service volumes: - .:/app:z env_file: @@ -140,10 +139,10 @@ services: - ./compose/local/minio/init.sh:/etc/minio/init.sh entrypoint: /etc/minio/init.sh - processing_service: + ml_backend: build: - context: ./processing_services/example + context: ./ml_backends/example volumes: - - ./processing_services/example/:/app:processing_service + - ./ml_backends/example/:/app ports: - "2005:2000" diff --git a/processing_services/example/api/algorithms.py b/processing_services/example/api/algorithms.py new file mode 100644 index 000000000..4636083aa --- /dev/null +++ b/processing_services/example/api/algorithms.py @@ -0,0 +1,119 @@ +from .schemas import AlgorithmCategoryMapResponse, AlgorithmConfigResponse + +RANDOM_DETECTOR = AlgorithmConfigResponse( + name="Random Detector", + key="random-detector", + task_type="detection", + description="Return bounding boxes at random locations within the image bounds.", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-detector", + category_map=None, +) + +CONSTANT_DETECTOR = AlgorithmConfigResponse( + name="Constant Detector", + key="constant-detector", + task_type="detection", + description="Return a fixed bounding box at a fixed location within the 
image bounds.", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/constant-detector", + category_map=None, +) + +RANDOM_BINARY_CLASSIFIER = AlgorithmConfigResponse( + name="Random binary classifier", + key="random-binary-classifier", + task_type="classification", + description="Randomly return a classification of 'Moth' or 'Not a moth'", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-binary-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + { + "index": 0, + "gbif_key": "1234", + "label": "Moth", + "source": "manual", + "taxon_rank": "SUPERFAMILY", + }, + { + "index": 1, + "gbif_key": "4543", + "label": "Not a moth", + "source": "manual", + "taxon_rank": "ORDER", + }, + ], + labels=["Moth", "Not a moth"], + version="v1", + description="A simple binary classifier", + uri="https://huggingface.co/RolnickLab/random-binary-classifier", + ), +) + +CONSTANT_CLASSIFIER = AlgorithmConfigResponse( + name="Constant classifier", + key="constant-classifier", + task_type="classification", + description="Always return a classification of 'Moth'", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/constant-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + { + "index": 0, + "gbif_key": "1234", + "label": "Moth", + "source": "manual", + "taxon_rank": "SUPERFAMILY", + } + ], + labels=["Moth"], + version="v1", + description="A classifier that always returns 'Moth'", + uri="https://huggingface.co/RolnickLab/constant-classifier", + ), +) + +RANDOM_SPECIES_CLASSIFIER = AlgorithmConfigResponse( + name="Random species classifier", + key="random-species-classifier", + task_type="classification", + description="A random species classifier", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-species-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + { + "index": 0, + "gbif_key": "1234", + "label": "Vanessa atalanta", + 
"source": "manual", + "taxon_rank": "SPECIES", + }, + { + "index": 1, + "gbif_key": "4543", + "label": "Vanessa cardui", + "source": "manual", + "taxon_rank": "SPECIES", + }, + { + "index": 2, + "gbif_key": "7890", + "label": "Vanessa itea", + "source": "manual", + "taxon_rank": "SPECIES", + }, + ], + labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"], + version="v1", + description="A simple species classifier", + uri="https://huggigface.co/RolnickLab/random-species-classifier", + ), +) diff --git a/processing_services/example/api/api.py b/processing_services/example/api/api.py index e7c0d27a8..e69de29bb 100644 --- a/processing_services/example/api/api.py +++ b/processing_services/example/api/api.py @@ -1,111 +0,0 @@ -""" -Fast API interface for processing images through the localization and classification pipelines. -""" - -import logging -import time - -import fastapi - -from .pipeline import ConstantPipeline, DummyPipeline -from .schemas import ( - AlgorithmConfig, - PipelineConfig, - PipelineRequest, - PipelineResponse, - SourceImage, - SourceImageResponse, -) - -logger = logging.getLogger(__name__) - -app = fastapi.FastAPI() - -pipeline1 = PipelineConfig( - name="ML Dummy Pipeline", - slug="dummy", - version=1, - algorithms=[ - AlgorithmConfig(name="Dummy Detector", key="1"), - AlgorithmConfig(name="Random Detector", key="2"), - AlgorithmConfig(name="Always Moth Classifier", key="3"), - ], -) - -pipeline2 = PipelineConfig( - name="ML Constant Pipeline", - slug="constant", - version=1, - algorithms=[ - AlgorithmConfig(name="Dummy Detector", key="1"), - AlgorithmConfig(name="Random Detector", key="2"), - AlgorithmConfig(name="Always Moth Classifier", key="3"), - ], -) - -pipelines = [pipeline1, pipeline2] - - -@app.get("/") -async def root(): - return fastapi.responses.RedirectResponse("/docs") - - -@app.get("/info", tags=["services"]) -async def info() -> list[PipelineConfig]: - return pipelines - - -# Check if the server is online 
-@app.get("/livez", tags=["health checks"]) -async def livez(): - return fastapi.responses.JSONResponse(status_code=200, content={"status": True}) - - -# Check if the pipelines are ready to process data -@app.get("/readyz", tags=["health checks"]) -async def readyz(): - if pipelines: - return fastapi.responses.JSONResponse( - status_code=200, content={"status": [pipeline.slug for pipeline in pipelines]} - ) - else: - return fastapi.responses.JSONResponse(status_code=503, content={"status": "pipelines unavailable"}) - - -@app.post("/process_images", tags=["services"]) -async def process(data: PipelineRequest) -> PipelineResponse: - pipeline_slug = data.pipeline - - source_image_results = [SourceImageResponse(**image.model_dump()) for image in data.source_images] - source_images = [SourceImage(**image.model_dump()) for image in data.source_images] - - start_time = time.time() - - if pipeline_slug == "constant": - pipeline = ConstantPipeline(source_images=source_images) # returns same detections - else: - pipeline = DummyPipeline(source_images=source_images) # returns random detections - - try: - results = pipeline.run() - except Exception as e: - logger.error(f"Error running pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - end_time = time.time() - seconds_elapsed = float(end_time - start_time) - - response = PipelineResponse( - pipeline=data.pipeline, - source_images=source_image_results, - detections=results, - total_time=seconds_elapsed, - ) - return response - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=2000) diff --git a/processing_services/example/api/pipeline.py b/processing_services/example/api/pipeline.py deleted file mode 100644 index a19076949..000000000 --- a/processing_services/example/api/pipeline.py +++ /dev/null @@ -1,134 +0,0 @@ -import datetime -import math -import random - -from .schemas import BoundingBox, Classification, Detection, SourceImage - - -def 
make_random_bbox(source_image_width: int, source_image_height: int): - # Make a random box. - # Ensure that the box is within the image bounds and the bottom right corner is greater than the top left corner. - x1 = random.randint(0, source_image_width) - x2 = random.randint(0, source_image_width) - y1 = random.randint(0, source_image_height) - y2 = random.randint(0, source_image_height) - - return BoundingBox( - x1=min(x1, x2), - y1=min(y1, y2), - x2=max(x1, x2), - y2=max(y1, y2), - ) - - -def generate_adaptive_grid_bounding_boxes(image_width: int, image_height: int, num_boxes: int) -> list[BoundingBox]: - # Estimate grid size based on num_boxes - grid_size: int = math.ceil(math.sqrt(num_boxes)) - - cell_width: float = image_width / grid_size - cell_height: float = image_height / grid_size - - boxes: list[BoundingBox] = [] - - for _ in range(num_boxes): - # Select a random cell - row: int = random.randint(0, grid_size - 1) - col: int = random.randint(0, grid_size - 1) - - # Calculate the cell's boundaries - cell_x1: float = col * cell_width - cell_y1: float = row * cell_height - - # Generate a random box within the cell - # Ensure the box is between 50% and 100% of the cell size - box_width: float = random.uniform(cell_width * 0.5, cell_width) - box_height: float = random.uniform(cell_height * 0.5, cell_height) - - x1: float = cell_x1 + random.uniform(0, cell_width - box_width) - y1: float = cell_y1 + random.uniform(0, cell_height - box_height) - x2: float = x1 + box_width - y2: float = y1 + box_height - - boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)) - - return boxes - - -def make_fake_detections(source_image: SourceImage, num_detections: int = 10): - source_image.open(raise_exception=True) - assert source_image.width is not None and source_image.height is not None - bboxes = generate_adaptive_grid_bounding_boxes(source_image.width, source_image.height, num_detections) - timestamp = datetime.datetime.now() - - return [ - Detection( - 
source_image_id=source_image.id, - bbox=bbox, - timestamp=timestamp, - algorithm="Random Detector", - classifications=[ - Classification( - classification="moth", - labels=["moth"], - scores=[random.random()], - timestamp=timestamp, - algorithm="Always Moth Classifier", - ) - ], - ) - for bbox in bboxes - ] - - -def make_constant_detections(source_image: SourceImage, num_detections: int = 10): - source_image.open(raise_exception=True) - assert source_image.width is not None and source_image.height is not None - - # Define a fixed bounding box size and position relative to image size - box_width, box_height = source_image.width // 4, source_image.height // 4 - start_x, start_y = source_image.width // 8, source_image.height // 8 - bboxes = [BoundingBox(x1=start_x, y1=start_y, x2=start_x + box_width, y2=start_y + box_height)] - timestamp = datetime.datetime.now() - - return [ - Detection( - source_image_id=source_image.id, - bbox=bbox, - timestamp=timestamp, - algorithm="Fixed Detector", - classifications=[ - Classification( - classification="moth", - labels=["moth"], - scores=[0.9], # Constant score for each detection - timestamp=timestamp, - algorithm="Always Moth Classifier", - ) - ], - ) - for bbox in bboxes - ] - - -class DummyPipeline: - source_images: list[SourceImage] - - def __init__(self, source_images: list[SourceImage]): - self.source_images = source_images - - def run(self) -> list[Detection]: - results = [make_fake_detections(source_image) for source_image in self.source_images] - # Flatten the list of lists - return [item for sublist in results for item in sublist] - - -class ConstantPipeline: - source_images: list[SourceImage] - - def __init__(self, source_images: list[SourceImage]): - self.source_images = source_images - - def run(self) -> list[Detection]: - results = [make_constant_detections(source_image) for source_image in self.source_images] - # Flatten the list of lists - return [item for sublist in results for item in sublist] diff --git 
a/processing_services/example/api/pipelines.py b/processing_services/example/api/pipelines.py new file mode 100644 index 000000000..0d955b417 --- /dev/null +++ b/processing_services/example/api/pipelines.py @@ -0,0 +1,216 @@ +import datetime +import math +import random + +from . import algorithms +from .schemas import ( + AlgorithmConfigResponse, + AlgorithmReference, + BoundingBox, + ClassificationResponse, + DetectionResponse, + PipelineConfigResponse, + SourceImage, +) + + +def make_random_bbox(source_image_width: int, source_image_height: int): + # Make a random box. + # Ensure that the box is within the image bounds and the bottom right corner is greater than the + # top left corner. + x1 = random.randint(0, source_image_width) + x2 = random.randint(0, source_image_width) + y1 = random.randint(0, source_image_height) + y2 = random.randint(0, source_image_height) + + return BoundingBox( + x1=min(x1, x2), + y1=min(y1, y2), + x2=max(x1, x2), + y2=max(y1, y2), + ) + + +def generate_adaptive_grid_bounding_boxes(image_width: int, image_height: int, num_boxes: int) -> list[BoundingBox]: + # Estimate grid size based on num_boxes + grid_size: int = math.ceil(math.sqrt(num_boxes)) + + cell_width: float = image_width / grid_size + cell_height: float = image_height / grid_size + + boxes: list[BoundingBox] = [] + + for _ in range(num_boxes): + # Select a random cell + row: int = random.randint(0, grid_size - 1) + col: int = random.randint(0, grid_size - 1) + + # Calculate the cell's boundaries + cell_x1: float = col * cell_width + cell_y1: float = row * cell_height + + # Generate a random box within the cell + # Ensure the box is between 50% and 100% of the cell size + box_width: float = random.uniform(cell_width * 0.5, cell_width) + box_height: float = random.uniform(cell_height * 0.5, cell_height) + + x1: float = cell_x1 + random.uniform(0, cell_width - box_width) + y1: float = cell_y1 + random.uniform(0, cell_height - box_height) + x2: float = x1 + box_width + y2: float 
= y1 + box_height + + boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2)) + + return boxes + + +def make_random_prediction( + algorithm: AlgorithmConfigResponse, + terminal: bool = True, + max_labels: int = 2, +) -> ClassificationResponse: + assert algorithm.category_map is not None + category_labels = algorithm.category_map.labels + logits = [random.random() for _ in category_labels] + softmax = [math.exp(logit) / sum([math.exp(logit) for logit in logits]) for logit in logits] + top_class = category_labels[softmax.index(max(softmax))] + return ClassificationResponse( + classification=top_class, + labels=category_labels if len(category_labels) <= max_labels else None, + scores=softmax, + logits=logits, + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference(name=algorithm.name, key=algorithm.key), + terminal=terminal, + ) + + +def make_random_detections(source_image: SourceImage, num_detections: int = 10): + source_image.open(raise_exception=True) + assert source_image.width is not None and source_image.height is not None + bboxes = generate_adaptive_grid_bounding_boxes(source_image.width, source_image.height, num_detections) + timestamp = datetime.datetime.now() + + return [ + DetectionResponse( + source_image_id=source_image.id, + bbox=bbox, + timestamp=timestamp, + algorithm=AlgorithmReference( + name=algorithms.RANDOM_DETECTOR.name, + key=algorithms.RANDOM_DETECTOR.key, + ), + classifications=[ + make_random_prediction( + algorithm=algorithms.RANDOM_BINARY_CLASSIFIER, + terminal=False, + ), + make_random_prediction( + algorithm=algorithms.RANDOM_SPECIES_CLASSIFIER, + terminal=True, + ), + ], + ) + for bbox in bboxes + ] + + +def make_constant_detections(source_image: SourceImage, num_detections: int = 10): + source_image.open(raise_exception=True) + assert source_image.width is not None and source_image.height is not None + + # Define a fixed bounding box size and position relative to image size + box_width, box_height = source_image.width // 4, 
source_image.height // 4 + start_x, start_y = source_image.width // 8, source_image.height // 8 + bboxes = [BoundingBox(x1=start_x, y1=start_y, x2=start_x + box_width, y2=start_y + box_height)] + timestamp = datetime.datetime.now() + + assert algorithms.CONSTANT_CLASSIFIER.category_map is not None + labels = algorithms.CONSTANT_CLASSIFIER.category_map.labels + + return [ + DetectionResponse( + source_image_id=source_image.id, + bbox=bbox, + timestamp=timestamp, + algorithm=AlgorithmReference(name=algorithms.CONSTANT_DETECTOR.name, key=algorithms.CONSTANT_DETECTOR.key), + classifications=[ + ClassificationResponse( + classification=labels[0], + labels=labels, + scores=[0.9], # Constant score for each detection + timestamp=timestamp, + algorithm=AlgorithmReference( + name=algorithms.CONSTANT_CLASSIFIER.name, key=algorithms.CONSTANT_CLASSIFIER.key + ), + ) + ], + ) + for bbox in bboxes + ] + + +class Pipeline: + source_images: list[SourceImage] + + def __init__(self, source_images: list[SourceImage]): + self.source_images = source_images + + def run(self) -> list[DetectionResponse]: + raise NotImplementedError("Subclasses must implement the run method") + + config = PipelineConfigResponse( + name="Base Pipeline", + slug="base", + description="A base class for all pipelines.", + version=1, + algorithms=[], + ) + + +class RandomPipeline(Pipeline): + """ + A pipeline that returns detections in random positions within the image bounds with random classifications. + """ + + def run(self) -> list[DetectionResponse]: + results = [make_random_detections(source_image) for source_image in self.source_images] + # Flatten the list of lists + return [item for sublist in results for item in sublist] + + config = PipelineConfigResponse( + name="Random Pipeline", + slug="random", + description=( + "A pipeline that returns detections in random positions within the image bounds " + "with random classifications." 
+ ), + version=1, + algorithms=[ + algorithms.RANDOM_DETECTOR, + algorithms.RANDOM_BINARY_CLASSIFIER, + algorithms.RANDOM_SPECIES_CLASSIFIER, + ], + ) + + +class ConstantPipeline(Pipeline): + """ + A pipeline that always returns a detection in the same position with a fixed classification. + """ + + def run(self) -> list[DetectionResponse]: + results = [make_constant_detections(source_image) for source_image in self.source_images] + # Flatten the list of lists + return [item for sublist in results for item in sublist] + + config = PipelineConfigResponse( + name="Constant Pipeline", + slug="constant", + description="A pipeline that always returns a detection in the same position with a fixed classification.", + version=1, + algorithms=[ + algorithms.CONSTANT_DETECTOR, + algorithms.CONSTANT_CLASSIFIER, + ], + ) diff --git a/processing_services/example/api/schemas.py b/processing_services/example/api/schemas.py index adb4a16ee..3396e1337 100644 --- a/processing_services/example/api/schemas.py +++ b/processing_services/example/api/schemas.py @@ -29,6 +29,9 @@ def to_string(self): def to_path(self): return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]]) + def to_tuple(self): + return (self.x1, self.y1, self.x2, self.y2) + class SourceImage(pydantic.BaseModel): model_config = pydantic.ConfigDict(extra="ignore", arbitrary_types_allowed=True) @@ -65,34 +68,52 @@ def open(self, raise_exception=False) -> PIL.Image.Image | None: return self._pil -class Classification(pydantic.BaseModel): +class AlgorithmReference(pydantic.BaseModel): + name: str + key: str + + +class ClassificationResponse(pydantic.BaseModel): classification: str - labels: list[str] = [] - scores: list[float] = [] + labels: list[str] | None = pydantic.Field( + default=None, + description=( + "A list of all possible labels for the model, in the correct order. " + "Omitted if the model has too many labels to include for each classification in the response. 
" + "Use the category map from the algorithm to get the full list of labels and metadata." + ), + ) + scores: list[float] = pydantic.Field( + default_factory=list, + description="The calibrated probabilities for each class label, most commonly the softmax output.", + ) + logits: list[float] = pydantic.Field( + default_factory=list, + description="The raw logits output by the model, before any calibration or normalization.", + ) inference_time: float | None = None - algorithm: str + algorithm: AlgorithmReference terminal: bool = True timestamp: datetime.datetime -class Detection(pydantic.BaseModel): +class DetectionResponse(pydantic.BaseModel): source_image_id: str bbox: BoundingBox inference_time: float | None = None - algorithm: str + algorithm: AlgorithmReference timestamp: datetime.datetime crop_image_url: str | None = None - classifications: list[Classification] = [] + classifications: list[ClassificationResponse] = [] class SourceImageRequest(pydantic.BaseModel): - # @TODO bring over new SourceImage & b64 validation from the lepsAI repo + model_config = pydantic.ConfigDict(extra="ignore") + id: str url: str # b64: str | None = None - - class Config: - extra = "ignore" + # @TODO bring over new SourceImage & b64 validation from the lepsAI repo class SourceImageResponse(pydantic.BaseModel): @@ -102,7 +123,68 @@ class SourceImageResponse(pydantic.BaseModel): url: str -PipelineChoice = typing.Literal["dummy", "constant"] +class AlgorithmCategoryMapResponse(pydantic.BaseModel): + data: list[dict] = pydantic.Field( + default_factory=dict, + description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.", + examples=[ + [ + {"label": "Moth", "index": 0, "gbif_key": 1234}, + {"label": "Not a moth", "index": 1, "gbif_key": 5678}, + ] + ], + ) + labels: list[str] = pydantic.Field( + default_factory=list, + description="A simple list of string labels, in the correct index order used by the model.", + examples=[["Moth", "Not a moth"]], + ) + 
version: str | None = pydantic.Field( + default=None, + description="The version of the category map. Can be a descriptive string or a version number.", + examples=["LepNet2021-with-2023-mods"], + ) + description: str | None = pydantic.Field( + default=None, + description="A description of the category map used to train. e.g. source, purpose and modifications.", + examples=["LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."], + ) + uri: str | None = pydantic.Field( + default=None, + description="A URI to the category map file, could be a public web URL or object store path.", + ) + + +class AlgorithmConfigResponse(pydantic.BaseModel): + name: str + key: str = pydantic.Field( + description=("A unique key for an algorithm to lookup the category map (class list) and other metadata."), + ) + description: str | None = None + task_type: str | None = pydantic.Field( + default=None, + description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.", + examples=["detection", "classification", "segmentation", "embedding"], + ) + version: int = pydantic.Field( + default=1, + description="A sortable version number for the model. Increment this number when the model is updated.", + ) + version_name: str | None = pydantic.Field( + default=None, + description="A complete version name e.g. 
'2021-01-01', 'LepNet2021'.", + ) + uri: str | None = pydantic.Field( + default=None, + description="A URI to the weights or model details, could be a public web URL or object store path.", + ) + category_map: AlgorithmCategoryMapResponse | None = None + + class Config: + extra = "ignore" + + +PipelineChoice = typing.Literal["random", "constant"] class PipelineRequest(pydantic.BaseModel): @@ -117,18 +199,23 @@ class Config: "source_images": [ { "id": "123", - "url": "https://example.com/image.jpg", + "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg", } ], } } -class PipelineResponse(pydantic.BaseModel): +class PipelineResultsResponse(pydantic.BaseModel): pipeline: PipelineChoice + algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field( + default_factory=dict, + description="A dictionary of all algorithms used in the pipeline, including their class list and other " + "metadata, keyed by the algorithm key.", + ) total_time: float source_images: list[SourceImageResponse] - detections: list[Detection] + detections: list[DetectionResponse] class PipelineStageParam(pydantic.BaseModel): @@ -148,17 +235,34 @@ class PipelineStage(pydantic.BaseModel): description: str | None = None -class AlgorithmConfig(pydantic.BaseModel): - name: str - key: str - - -class PipelineConfig(pydantic.BaseModel): - """A configurable pipeline.""" +class PipelineConfigResponse(pydantic.BaseModel): + """Details about a pipeline, its algorithms and category maps.""" name: str slug: str version: int description: str | None = None - algorithms: list[AlgorithmConfig] = [] + algorithms: list[AlgorithmConfigResponse] = [] stages: list[PipelineStage] = [] + + +class ProcessingServiceInfoResponse(pydantic.BaseModel): + """Information about the processing service.""" + + name: str = pydantic.Field(examples=["Mila Research Lab - Moth AI Services"]) + description: str | None = pydantic.Field( + default=None, + examples=["Algorithms developed by the Mila Research 
Lab for analysis of moth images."], + ) + pipelines: list[PipelineConfigResponse] = pydantic.Field( + default_factory=list, + examples=[ + [ + PipelineConfigResponse(name="Random Pipeline", slug="random", version=1, algorithms=[]), + ] + ], + ) + # algorithms: list[AlgorithmConfigResponse] = pydantic.Field( + # default=list, + # examples=[RANDOM_BINARY_CLASSIFIER], + # ) diff --git a/processing_services/example/api/test.py b/processing_services/example/api/test.py index 717367314..187298094 100644 --- a/processing_services/example/api/test.py +++ b/processing_services/example/api/test.py @@ -3,7 +3,7 @@ from fastapi.testclient import TestClient from .api import app -from .pipeline import DummyPipeline +from .pipelines import RandomPipeline from .schemas import PipelineRequest, SourceImage, SourceImageRequest @@ -13,7 +13,7 @@ def test_dummy_pipeline(self): SourceImage(id="1", url="https://example.com/image1.jpg"), SourceImage(id="2", url="https://example.com/image2.jpg"), ] - pipeline = DummyPipeline(source_images=source_images) + pipeline = RandomPipeline(source_images=source_images) detections = pipeline.run() self.assertEqual(len(detections), 20) @@ -45,7 +45,7 @@ def test_process(self): ] source_image_requests = [SourceImageRequest(**image.dict()) for image in source_images] request = PipelineRequest(pipeline="random", source_images=source_image_requests) - response = self.client.post("/pipeline/process", json=request.dict()) + response = self.client.post("/process", json=request.dict()) self.assertEqual(response.status_code, 200) data = response.json() diff --git a/processing_services/example/api/utils.py b/processing_services/example/api/utils.py index c70120215..119723ae5 100644 --- a/processing_services/example/api/utils.py +++ b/processing_services/example/api/utils.py @@ -5,15 +5,23 @@ import pathlib import re import tempfile -import urllib.request from urllib.parse import urlparse import PIL.Image +import PIL.ImageFile +import requests logger = 
logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True + +# This is polite and required by some hosts +# see: https://foundation.wikimedia.org/wiki/Policy:User-Agent_policy +USER_AGENT = "AntennaInsectDataPlatform/1.0 (https://insectai.org)" + + def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path: """ Fetch a file from a URL or local path. If the path is a URL, download the file. @@ -44,17 +52,20 @@ def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path: else: logger.info(f"Downloading {path_or_url} to {local_filepath}") - try: - resulting_filepath, _headers = urllib.request.urlretrieve(url=path_or_url, filename=local_filepath) - except Exception as e: - raise Exception(f"Could not retrieve {path_or_url}: {e}") + headers = {"User-Agent": USER_AGENT} + response = requests.get(path_or_url, stream=True, headers=headers) + response.raise_for_status() # Raise an exception for HTTP errors + + with open(local_filepath, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) - resulting_filepath = pathlib.Path(resulting_filepath) + resulting_filepath = pathlib.Path(local_filepath).resolve() logger.info(f"Downloaded to {resulting_filepath}") return resulting_filepath -def open_image(fp: str | bytes | pathlib.Path, raise_exception: bool = True) -> PIL.Image.Image | None: +def open_image(fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True) -> PIL.Image.Image | None: """ Wrapper from PIL.Image.open that handles errors and converts to RGB. 
""" diff --git a/ui/src/data-services/models/algorithm.ts b/ui/src/data-services/models/algorithm.ts index 1ab7c0f55..d27f2e02c 100644 --- a/ui/src/data-services/models/algorithm.ts +++ b/ui/src/data-services/models/algorithm.ts @@ -27,8 +27,12 @@ export class Algorithm { return this._algorithm.name } - get url(): string { - return this._algorithm.url + get key(): string { + return this._algorithm.key + } + + get uri(): string { + return this._algorithm.uri } get updatedAt(): string | undefined { diff --git a/ui/src/data-services/models/job-details.ts b/ui/src/data-services/models/job-details.ts index 33180d082..e993cf91e 100644 --- a/ui/src/data-services/models/job-details.ts +++ b/ui/src/data-services/models/job-details.ts @@ -16,11 +16,11 @@ export class JobDetails extends Job { } get errors(): string[] { - return this._job.progress.errors ?? [] + return this._job.logs.stderr ?? [] } get logs(): string[] { - return this._job.progress.logs ?? [] + return this._job.logs.stdout ?? [] } get stages(): { diff --git a/ui/src/data-services/models/occurrence-details.ts b/ui/src/data-services/models/occurrence-details.ts index 7376f1971..516b6ab5b 100644 --- a/ui/src/data-services/models/occurrence-details.ts +++ b/ui/src/data-services/models/occurrence-details.ts @@ -14,6 +14,8 @@ export interface Identification { taxon: Taxon comment?: string algorithm?: Algorithm + score?: number + terminal?: boolean userPermissions: UserPermission[] createdAt: string } @@ -30,6 +32,7 @@ export interface HumanIdentification extends Identification { export interface MachinePrediction extends Identification { algorithm: Algorithm score: number + terminal: boolean } export class OccurrenceDetails extends Occurrence { @@ -83,6 +86,7 @@ export class OccurrenceDetails extends Occurrence { overridden, taxon, score: p.score, + terminal: p.terminal, algorithm: p.algorithm, userPermissions: p.user_permissions, createdAt: p.created_at, diff --git a/ui/src/data-services/models/pipeline.ts 
b/ui/src/data-services/models/pipeline.ts index aa5f62701..537948b82 100644 --- a/ui/src/data-services/models/pipeline.ts +++ b/ui/src/data-services/models/pipeline.ts @@ -96,6 +96,11 @@ export class Pipeline { get processingServicesOnlineLastChecked(): string | undefined { const processingServices = this._pipeline.processing_services + + if (!processingServices.length) { + return undefined + } + const last_checked_times = [] for (const processingService of processingServices) { last_checked_times.push( diff --git a/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx b/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx index 19f8b9111..fde4ff9e7 100644 --- a/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx +++ b/ui/src/design-system/components/identification/identification-summary/identification-summary.tsx @@ -49,6 +49,22 @@ export const IdentificationSummary = ({ {identification.algorithm && ( )} + + {identification.score && ( +
+ {`${identification.score.toPrecision(4)}`} +
+ )} + {identification.terminal !== undefined && ( +
+ {identification.terminal + ? translate(STRING.TERMINAL_CLASSIFICATION) + : translate(STRING.INTERMEDIATE_CLASSIFICATION)} +
+ )} +
+ {formattedTime} +
) } diff --git a/ui/src/pages/occurrence-details/identification-card/identification-card.tsx b/ui/src/pages/occurrence-details/identification-card/identification-card.tsx index 607bfd110..36c0d3521 100644 --- a/ui/src/pages/occurrence-details/identification-card/identification-card.tsx +++ b/ui/src/pages/occurrence-details/identification-card/identification-card.tsx @@ -40,7 +40,7 @@ export const IdentificationCard = ({ const canDelete = identification.userPermissions.includes( UserPermission.Update ) - const showAgree = !byCurrentUser && canAgree && !identification.overridden + const showAgree = !byCurrentUser && canAgree const showDelete = byCurrentUser && canDelete if (deleteIdOpen) { diff --git a/ui/src/pages/overview/pipelines/pipelines-columns.tsx b/ui/src/pages/overview/pipelines/pipelines-columns.tsx index b4a5cc978..c1a064594 100644 --- a/ui/src/pages/overview/pipelines/pipelines-columns.tsx +++ b/ui/src/pages/overview/pipelines/pipelines-columns.tsx @@ -43,7 +43,7 @@ export const columns: (projectId: string) => TableColumn[] = () => [ }, { id: 'processing-services-online-last-checked', - name: 'Processing Services Online Last Checked', + name: 'Status Last Checked', sortField: 'processing_services_online_last_checked', renderCell: (item: Pipeline) => ( diff --git a/ui/src/pages/pipeline-details/pipeline-algorithms.tsx b/ui/src/pages/pipeline-details/pipeline-algorithms.tsx index cc479dcfe..8aeb96726 100644 --- a/ui/src/pages/pipeline-details/pipeline-algorithms.tsx +++ b/ui/src/pages/pipeline-details/pipeline-algorithms.tsx @@ -39,6 +39,11 @@ export const columns: TableColumn[] = [ name: translate(STRING.FIELD_LABEL_UPDATED_AT), renderCell: (item: Algorithm) => , }, + { + id: 'key', + name: translate(STRING.FIELD_LABEL_SLUG), + renderCell: (item: Algorithm) => , + }, ] export const PipelineAlgorithms = ({ pipeline }: { pipeline: Pipeline }) => ( diff --git a/ui/src/pages/pipeline-details/pipeline-details-dialog.tsx 
b/ui/src/pages/pipeline-details/pipeline-details-dialog.tsx index 4295f9c8f..bd5d5940f 100644 --- a/ui/src/pages/pipeline-details/pipeline-details-dialog.tsx +++ b/ui/src/pages/pipeline-details/pipeline-details-dialog.tsx @@ -64,8 +64,8 @@ const PipelineDetailsContent = ({