1 change: 1 addition & 0 deletions README.md
@@ -81,6 +81,7 @@
* [x] Export Trained Models to ONNX
* [x] ONNX Model with Integrated Preprocessing
* **🛠️ Utilities & Demos**
* [x] Extensive metric logging for debugging
* [x] Cross-Platform Support (Linux, macOS, Windows)
* [x] Pixi Environment Management Integration
* [x] Interactive Gradio Demo Script
270 changes: 246 additions & 24 deletions src/deimkit/engine/solver/det_engine.py
@@ -9,7 +9,7 @@

import sys
import math
from typing import Iterable
from typing import Iterable, Dict, Optional, List, Tuple
from tqdm import tqdm

import torch
@@ -21,6 +21,9 @@
from ..data import CocoEvaluator
from ..misc import MetricLogger, SmoothedValue, dist_utils

import matplotlib.pyplot as plt
import numpy as np


def train_one_epoch(self_lr_scheduler, lr_scheduler, model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
@@ -132,9 +135,170 @@ def train_one_epoch(self_lr_scheduler, lr_scheduler, model: torch.nn.Module, cri
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def create_pr_curve_plot(
precisions: np.ndarray,
recalls: np.ndarray,
labels: List[str],
colors: List[str],
title: str,
figsize: Tuple[int, int] = (10, 10)
) -> plt.Figure:
"""
Create a precision-recall curve plot.

Args:
precisions: Array of precision values for each curve, shape (num_curves, num_points)
recalls: Array of recall values, shape (num_points,)
labels: List of labels for each curve
colors: List of colors for each curve
title: Plot title
figsize: Figure size in inches

Returns:
matplotlib.figure.Figure: The created figure
"""
fig, ax = plt.subplots(figsize=figsize)

for precision, label, color in zip(precisions, labels, colors):
ax.plot(recalls, precision, color=color, label=label, linewidth=2)

ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_title(title, fontsize=14)
ax.grid(True)
ax.legend(loc='lower left')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])

return fig

def get_precision_recall_data(
eval_result,
iou_thresh: Optional[float] = None,
area_idx: Optional[int] = None
) -> Tuple[np.ndarray, np.ndarray]:
"""
Extract precision-recall data from COCO evaluation results.

Args:
eval_result: COCO evaluation result object
iou_thresh: IoU threshold value. If None, average over all IoU thresholds
area_idx: Area index (0: all, 1: small, 2: medium, 3: large). If None, use all areas

Returns:
Tuple of (precision array, recall array)
"""
recalls = eval_result.params.recThrs
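    # COCOeval stores precision as a 5-D array [T, R, K, A, M]:
    # T = IoU thresholds, R = recall thresholds, K = categories, A = area ranges, M = maxDets.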

if iou_thresh is not None:
iou_index = eval_result.params.iouThrs == iou_thresh
        # Average over categories (axis 2) so the curve has the same length as recThrs
        precision = eval_result.eval['precision'][iou_index, :, :, 0, -1].mean(axis=2).squeeze()
else:
# Average over IoU thresholds and categories
precision = eval_result.eval['precision'][:, :, :, 0, -1].mean(axis=2).mean(axis=0)

if area_idx is not None:
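        # For a specific area range, average over all IoU thresholds and categories.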
precision = eval_result.eval['precision'][:, :, :, area_idx, -1].mean(axis=2).mean(axis=0)

# Ensure precision has the same shape as recalls
if precision.shape != recalls.shape:
precision = np.full_like(recalls, np.nan)

return precision, recalls

def log_pr_curves(
coco_evaluator: CocoEvaluator,
writer: SummaryWriter,
global_step: int,
iou_types: List[str]
) -> None:
"""
Log precision-recall curves to TensorBoard.
"""
if writer is None or not dist_utils.is_main_process():
return

iou_thresholds = [0.5, 0.75]
area_labels = ['all', 'small', 'medium', 'large']

for iou_type in iou_types:
eval_result = coco_evaluator.coco_eval[iou_type]
recalls = eval_result.params.recThrs
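        # recThrs is the fixed recall grid used by COCOeval (101 points from 0 to 1 by default).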

# IoU threshold based curves
precisions = []
labels = []
colors = ['b', 'r', 'g'] # Colors for IoU=0.5, 0.75, and mean

# Get PR curves for specific IoU thresholds
for iou_thresh in iou_thresholds:
precision, _ = get_precision_recall_data(eval_result, iou_thresh=iou_thresh)
precisions.append(precision)
labels.append(f'IoU={iou_thresh:.2f}')

# Add mean PR curve (IoU=0.50:0.95)
precision, _ = get_precision_recall_data(eval_result)
precisions.append(precision)
labels.append('IoU=0.50:0.95')

# Stack precisions into a 2D array
precisions = np.stack(precisions)

# Create and log IoU threshold based plot
fig = create_pr_curve_plot(
precisions,
recalls,
labels,
colors,
f'Precision-Recall Curves ({iou_type})'
)
writer.add_figure(f'metrics-PR/{iou_type}/precision_recall_curve', fig, global_step)
plt.close(fig)

# Area based curves
precisions = []
colors = ['g', 'b', 'r', 'c']

# Get PR curves for different areas
for area_idx in range(4):
precision, _ = get_precision_recall_data(eval_result, area_idx=area_idx)
precisions.append(precision)

# Stack precisions into a 2D array
precisions = np.stack(precisions)

# Create and log area based plot
fig = create_pr_curve_plot(
precisions,
recalls,
[f'area={label}' for label in area_labels],
colors,
f'Precision-Recall Curves by Area ({iou_type})'
)
writer.add_figure(f'metrics-PR/{iou_type}/precision_recall_curve_by_area', fig, global_step)
plt.close(fig)

def calculate_f1_score(precision: float, recall: float) -> Optional[float]:
"""
Calculate F1 score from precision and recall values.

Args:
precision: Precision value (AP)
recall: Recall value (AR)

    Returns:
        Optional[float]: F1 score if valid, None if precision or recall is non-positive
    """
    if precision <= 0 or recall <= 0:
        return None

return 2 * (precision * recall) / (precision + recall)

@torch.no_grad()
def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessor, data_loader, coco_evaluator: CocoEvaluator,
device, writer: SummaryWriter = None, global_step: int = None):
def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessor, data_loader,
coco_evaluator: CocoEvaluator, device, writer: Optional[SummaryWriter] = None,
global_step: Optional[int] = None):
model.eval()
criterion.eval()
coco_evaluator.cleanup()
@@ -176,33 +340,91 @@ def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessor,
if coco_evaluator is not None:
coco_evaluator.accumulate()
coco_evaluator.summarize()

# Log PR curves
log_pr_curves(coco_evaluator, writer, global_step, iou_types)

stats = {}
if coco_evaluator is not None:
if 'bbox' in iou_types:
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
if 'segm' in iou_types:
stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()

# Log mAP metrics to TensorBoard
if writer is not None and dist_utils.is_main_process() and global_step is not None:
if 'bbox' in iou_types:
# COCO metrics: AP@IoU=0.5:0.95, AP@IoU=0.5, AP@IoU=0.75, etc.
bbox_stats = coco_evaluator.coco_eval['bbox'].stats
writer.add_scalar('mAP/bbox_AP50-95', bbox_stats[0], global_step)
writer.add_scalar('mAP/bbox_AP50', bbox_stats[1], global_step)
writer.add_scalar('mAP/bbox_AP75', bbox_stats[2], global_step)
writer.add_scalar('mAP/bbox_AP_small', bbox_stats[3], global_step)
writer.add_scalar('mAP/bbox_AP_medium', bbox_stats[4], global_step)
writer.add_scalar('mAP/bbox_AP_large', bbox_stats[5], global_step)
bbox_stats = coco_evaluator.coco_eval['bbox'].stats
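            # COCOeval.summarize() yields 12 values: indices 0-5 are AP variants, 6-11 are AR variants.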
# Add top-level metrics for quick overview
if writer is not None and dist_utils.is_main_process() and global_step is not None:
# Primary metrics at top level
writer.add_scalar('top-level-metrics/mAP_50_95', bbox_stats[0], global_step)

# Top-level recall metrics
writer.add_scalar('top-level-metrics/mAR_50_95', bbox_stats[8], global_step)

# Calculate and log F1 scores at top level
f1_50_95 = calculate_f1_score(bbox_stats[0], bbox_stats[8])

if f1_50_95 is not None:
writer.add_scalar('top-level-metrics/F1_50_95', f1_50_95, global_step)

# Continue with existing detailed metrics logging
if writer is not None and dist_utils.is_main_process() and global_step is not None:
# Average Precision metrics (indices 0-5)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_all_maxDets_100', bbox_stats[0], global_step)
writer.add_scalar('metrics-AP/IoU_0.50_area_all_maxDets_100', bbox_stats[1], global_step)
writer.add_scalar('metrics-AP/IoU_0.75_area_all_maxDets_100', bbox_stats[2], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_small_maxDets_100', bbox_stats[3], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_medium_maxDets_100', bbox_stats[4], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_large_maxDets_100', bbox_stats[5], global_step)
# Average Recall metrics (indices 6-11)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_1', bbox_stats[6], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_10', bbox_stats[7], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_100', bbox_stats[8], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_small_maxDets_100', bbox_stats[9], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_medium_maxDets_100', bbox_stats[10], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_large_maxDets_100', bbox_stats[11], global_step)

# Calculate and log F1 scores only when valid
# For IoU 0.50:0.95
f1_50_95 = calculate_f1_score(bbox_stats[0], bbox_stats[8])
if f1_50_95 is not None:
writer.add_scalar('metrics-F1/IoU_0.50-0.95_area_all_maxDets_100', f1_50_95, global_step)

# For IoU 0.50
f1_50 = calculate_f1_score(bbox_stats[1], bbox_stats[8])
if f1_50 is not None:
writer.add_scalar('metrics-F1/IoU_0.50_area_all_maxDets_100', f1_50, global_step)

# For IoU 0.75
f1_75 = calculate_f1_score(bbox_stats[2], bbox_stats[8])
if f1_75 is not None:
writer.add_scalar('metrics-F1/IoU_0.75_area_all_maxDets_100', f1_75, global_step)

# Small
f1_small = calculate_f1_score(bbox_stats[3], bbox_stats[9])
if f1_small is not None:
writer.add_scalar('metrics-F1/IoU_0.50-0.95_area_small_maxDets_100', f1_small, global_step)

# Medium
f1_medium = calculate_f1_score(bbox_stats[4], bbox_stats[10])
if f1_medium is not None:
writer.add_scalar('metrics-F1/IoU_0.50-0.95_area_medium_maxDets_100', f1_medium, global_step)

# Large
f1_large = calculate_f1_score(bbox_stats[5], bbox_stats[11])
if f1_large is not None:
writer.add_scalar('metrics-F1/IoU_0.50-0.95_area_large_maxDets_100', f1_large, global_step)

if 'segm' in iou_types:
segm_stats = coco_evaluator.coco_eval['segm'].stats
writer.add_scalar('mAP/segm_AP50-95', segm_stats[0], global_step)
writer.add_scalar('mAP/segm_AP50', segm_stats[1], global_step)
writer.add_scalar('mAP/segm_AP75', segm_stats[2], global_step)
writer.add_scalar('mAP/segm_AP_small', segm_stats[3], global_step)
writer.add_scalar('mAP/segm_AP_medium', segm_stats[4], global_step)
writer.add_scalar('mAP/segm_AP_large', segm_stats[5], global_step)
# Average Precision metrics (indices 0-5)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_all_maxDets_100', segm_stats[0], global_step)
writer.add_scalar('metrics-AP/IoU_0.50_area_all_maxDets_100', segm_stats[1], global_step)
writer.add_scalar('metrics-AP/IoU_0.75_area_all_maxDets_100', segm_stats[2], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_small_maxDets_100', segm_stats[3], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_medium_maxDets_100', segm_stats[4], global_step)
writer.add_scalar('metrics-AP/IoU_0.50-0.95_area_large_maxDets_100', segm_stats[5], global_step)
# Average Recall metrics (indices 6-11)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_1', segm_stats[6], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_10', segm_stats[7], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_all_maxDets_100', segm_stats[8], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_small_maxDets_100', segm_stats[9], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_medium_maxDets_100', segm_stats[10], global_step)
writer.add_scalar('metrics-AR/IoU_0.50-0.95_area_large_maxDets_100', segm_stats[11], global_step)

return stats, coco_evaluator
40 changes: 16 additions & 24 deletions src/deimkit/trainer.py
@@ -501,7 +501,6 @@ def fit(
self._save_checkpoint(
epoch, eval_stats, self.output_dir / "best.pth"
)
# Add a prominent message for new best model
logger.info(
f"🏆 NEW BEST MODEL! Epoch {epoch} / mAP: {best_stats[k]}"
)
@@ -521,26 +520,10 @@
self._save_checkpoint(
epoch, eval_stats, self.output_dir / "best.pth"
)
# Add a prominent message for new best model
logger.info(
f"🏆 NEW BEST MODEL! Epoch {epoch} / mAP: {best_stats[k]}"
)

# Get mAP value safely from eval_stats
# The first value in coco_eval_bbox is the AP@IoU=0.5:0.95 (primary metric)
coco_map = (
eval_stats.get("coco_eval_bbox", [0.0])[0]
if isinstance(eval_stats.get("coco_eval_bbox", [0.0]), list)
else 0.0
)

# Log mAP to tensorboard directly here as well
if writer is not None:
writer.add_scalar("metrics/mAP", coco_map, global_step)

# Also log best mAP so far
writer.add_scalar("metrics/best_mAP", top1, global_step)

logger.info(f"✅ Current best stats: {best_stats}")

# Save final checkpoint if not save_best_only
@@ -612,7 +595,9 @@ def _save_checkpoint(
torch.save(state, checkpoint_path)
logger.info(f"Checkpoint saved to {checkpoint_path}")

def load_checkpoint(self, checkpoint_path: Union[str, Path], strict: bool = False) -> None:
def load_checkpoint(
self, checkpoint_path: Union[str, Path], strict: bool = False
) -> None:
"""
Load a model checkpoint, handling potential image size differences.

@@ -624,9 +609,11 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], strict: bool = Fals
logger.info(f"Loading checkpoint from {checkpoint_path}")

# Load checkpoint
state = (torch.hub.load_state_dict_from_url(str(checkpoint_path), map_location="cpu")
if str(checkpoint_path).startswith("http")
else torch.load(checkpoint_path, map_location="cpu"))
state = (
torch.hub.load_state_dict_from_url(str(checkpoint_path), map_location="cpu")
if str(checkpoint_path).startswith("http")
else torch.load(checkpoint_path, map_location="cpu")
)

# Setup if not already done
if self.model is None:
Expand All @@ -637,12 +624,17 @@ def load_state_dict_with_mismatch(model, state_dict):
try:
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
logger.warning(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")
logger.warning(
f"Missing keys: {missing}\nUnexpected keys: {unexpected}"
)
except RuntimeError as e:
logger.warning(f"Shape mismatch, loading compatible parameters only: {e}")
logger.warning(
f"Shape mismatch, loading compatible parameters only: {e}"
)
current_state = model.state_dict()
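                # Keep only parameters whose names and shapes match the current model.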
matched_state = {
k: v for k, v in state_dict.items()
k: v
for k, v in state_dict.items()
if k in current_state and current_state[k].shape == v.shape
}
model.load_state_dict(matched_state, strict=False)