diff --git a/trapdata/db/base.py b/trapdata/db/base.py
index 58a3c8a6..29f450fe 100644
--- a/trapdata/db/base.py
+++ b/trapdata/db/base.py
@@ -3,10 +3,13 @@ import time
 from typing import Generator
 
+import alembic
+import alembic.command
 import sqlalchemy as sa
 import sqlalchemy.exc
-from alembic import command as alembic
 from alembic.config import Config
+from alembic.runtime.migration import MigrationContext
+from alembic.script import ScriptDirectory
 from rich import print
 from sqlalchemy import orm
 
@@ -74,19 +76,47 @@ def create_db(db_path: DatabaseURL) -> None:
     Base.metadata.create_all(db, checkfirst=True)
 
     alembic_cfg = get_alembic_config(db_path)
-    alembic.stamp(alembic_cfg, "head")
+    alembic.command.stamp(alembic_cfg, "head")
 
 
 def migrate(db_path: DatabaseURL) -> None:
     """
-    Run database migrations.
-
-    # @TODO See this post for a more complete implementation
-    # https://pawamoy.github.io/posts/testing-fastapi-ormar-alembic-apps/
+    Run database migrations with better error handling and verification.
     """
     logger.debug("Running any database migrations if necessary")
     alembic_cfg = get_alembic_config(db_path)
-    alembic.upgrade(alembic_cfg, "head")
+
+    try:
+        # Check the current state first. Read the revision directly from the
+        # database: alembic.command.current() only prints it and returns None.
+        engine = sa.create_engine(str(db_path))
+        with engine.connect() as connection:
+            current_head = MigrationContext.configure(connection).get_current_revision()
+        script_dir = ScriptDirectory.from_config(alembic_cfg)
+        target_head = script_dir.get_current_head()
+
+        if current_head != target_head:
+            logger.info(f"Upgrading from {current_head} to {target_head}")
+            alembic.command.upgrade(alembic_cfg, "head")
+            logger.info("Migration completed successfully")
+        else:
+            logger.debug("Database already at target revision")
+
+        # Verify that the migration actually worked (requires alembic >= 1.9)
+        logger.debug("Verifying database schema consistency")
+        alembic.command.check(alembic_cfg)
+        logger.debug("Database schema verification passed")
+
+    except Exception as e:
+        logger.error(f"Migration failed: {e}")
+        # Check whether we are left in an inconsistent state
+        try:
+            alembic.command.check(alembic_cfg)
+            logger.warning("Migration failed but database schema appears consistent")
+        except Exception as check_error:
+            logger.error(f"Database is in inconsistent state: {check_error}")
+            logger.error("Manual intervention may be required to fix migration state")
+        raise
 
 
 def get_db(db_path, create=False, update=False):
diff --git a/trapdata/db/migrations/versions/1754607024_68f8b8fe793a_new_column_for_saving_logits_to_.py b/trapdata/db/migrations/versions/1754607024_68f8b8fe793a_new_column_for_saving_logits_to_.py
new file mode 100644
index 00000000..966df90c
--- /dev/null
+++ b/trapdata/db/migrations/versions/1754607024_68f8b8fe793a_new_column_for_saving_logits_to_.py
@@ -0,0 +1,27 @@
+"""New column for saving logits to detections
+
+Revision ID: 68f8b8fe793a
+Revises: 1544478c3031
+Create Date: 2025-08-07 15:50:24.447765
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "68f8b8fe793a"
+down_revision = "1544478c3031"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("detections", sa.Column("logits", sa.JSON(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("detections", "logits")
+    # ### end Alembic commands ###
diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py
index 1432cfcc..0f216858 100644
--- a/trapdata/db/models/detections.py
+++ b/trapdata/db/models/detections.py
@@ -23,7 +23,8 @@ class DetectionListItem(BaseModel):
     area_pixels: Optional[float]
     last_detected: Optional[datetime.datetime]
     label: Optional[str]
-    score: Optional[int]
+    score: Optional[float]
+    # logits: Optional[list[float]]
     model_name: Optional[str]
     in_queue: bool
     notes: Optional[str]
@@ -41,8 +42,9 @@ class DetectionDetail(DetectionListItem):
     sequence_cost: Optional[float]
     source_image_path: Optional[pathlib.Path]
     timestamp: Optional[str]
-    bbox_center: Optional[tuple[int, int]]
+    bbox_center: Optional[tuple[float, float]]
     area_pixels: Optional[int]
+    logits: Optional[list[float]]
 
 
 class DetectedObject(db.Base):
@@ -75,6 +77,7 @@ class DetectedObject(db.Base):
     sequence_frame = sa.Column(sa.Integer)
     sequence_previous_id = sa.Column(sa.Integer)
     sequence_previous_cost = sa.Column(sa.Float)
+    logits = sa.Column(sa.JSON)
     cnn_features = sa.Column(sa.JSON)
 
     # @TODO add updated & created timestamps to all db models
@@ -288,6 +291,7 @@ def report_data(self) -> DetectionDetail:
             last_detected=self.last_detected,
             notes=self.notes,
             in_queue=self.in_queue,
+            logits=self.logits,
         )
 
     def report_data_simple(self):
diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py
index c3e643bb..0952d8d5 100644
--- a/trapdata/ml/models/classification.py
+++ b/trapdata/ml/models/classification.py
@@ -184,7 +184,7 @@ def get_transforms(self):
             ]
         )
 
-    def post_process_batch(self, output):
+    def post_process_batch(self, output: torch.Tensor) -> list[tuple[str, float, list]]:
         predictions = torch.nn.functional.softmax(output, dim=1)
         predictions = predictions.cpu().numpy()
 
@@ -192,9 +192,10 @@ def post_process_batch(self, output):
         labels = [self.category_map[cat] for cat in categories]
         scores = predictions.max(axis=1).astype(float)
 
-        result = list(zip(labels, scores))
-        logger.debug(f"Post-processing result batch: {result}")
-        return result
+        logits = output.cpu().detach().numpy().tolist()
+        result_per_image = list(zip(labels, scores, logits))
+        logger.debug(f"Post-processing result batch: {result_per_image}")
+        return result_per_image
 
 
 class Resnet50ClassifierLowRes(Resnet50Classifier):
@@ -249,7 +250,13 @@ def get_dataset(self):
         )
         return dataset
 
-    def save_results(self, object_ids, batch_output, *args, **kwargs):
+    def save_results(
+        self,
+        object_ids,
+        batch_output: list[tuple[str, float, list]],
+        *args,
+        **kwargs,
+    ):
         # Here we are saving the moth/non-moth labels
         classified_objects_data = [
             {
@@ -258,7 +265,7 @@
                 "in_queue": True if label == self.positive_binary_label else False,
                 "model_name": self.name,
             }
-            for label, score in batch_output
+            for label, score, _logits in batch_output
         ]
 
         save_classified_objects(self.db_path, object_ids, classified_objects_data)
@@ -302,16 +309,24 @@ def get_dataset(self):
         )
         return dataset
 
-    def save_results(self, object_ids, batch_output, *args, **kwargs):
+    def save_results(
+        self,
+        object_ids,
+        batch_output: list[tuple[str, float, list]],
+        *args,
+        **kwargs,
+    ):
         # Here we are saving the specific taxon labels
         classified_objects_data = [
             {
                 "specific_label": label,
-                "specific_label_score": score,
+                "specific_label_score": top_score,
+                "logits": logits,
                 "model_name": self.name,
-                "in_queue": True,  # Put back in queue for the feature extractor & tracking
+                # Put back in queue for the feature extractor & tracking
+                "in_queue": True,
             }
-            for label, score in batch_output
+            for label, top_score, logits in batch_output
         ]
 
         save_classified_objects(self.db_path, object_ids, classified_objects_data)
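
A minimal sketch (not part of the patch) of how the new migrate() path and the logits column could be exercised end to end. It assumes a plain sqlite:/// URL string is accepted wherever DatabaseURL is expected, and uses generic SQLAlchemy sessions rather than any trapdata-specific session helper:

# Sketch: run migrate() against a scratch SQLite database and round-trip a
# value through the new JSON "logits" column. The sqlite:/// URL standing in
# for DatabaseURL and the "id" primary key name are assumptions.
import pathlib
import tempfile

import sqlalchemy as sa
from sqlalchemy import orm

from trapdata.db.base import create_db, migrate
from trapdata.db.models.detections import DetectedObject

with tempfile.TemporaryDirectory() as tmp:
    db_url = f"sqlite:///{pathlib.Path(tmp) / 'scratch.db'}"

    create_db(db_url)  # creates all tables, then stamps the alembic revision to "head"
    migrate(db_url)    # should log "Database already at target revision" and pass check()

    engine = sa.create_engine(db_url)
    with orm.Session(engine) as session:
        # The new JSON column stores the raw classifier outputs per detection
        detection = DetectedObject(logits=[0.1, 2.3, -1.2])
        session.add(detection)
        session.commit()
        assert session.get(DetectedObject, detection.id).logits == [0.1, 2.3, -1.2]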