Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
a25f7af
First draft for single detector WIP
sfmig Dec 10, 2025
1b7ca08
Add formatting
sfmig Dec 11, 2025
1621fac
Make formatting a static method
sfmig Dec 11, 2025
ab98740
Coerce types if possible in df validator for COCO export
sfmig Dec 11, 2025
b849e5c
Add example notebook for detector inference and proofreading
sfmig Dec 11, 2025
8f85888
Add note about change in validators
sfmig Dec 11, 2025
b2caf83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 11, 2025
edbc369
Add convenience wrapper to lightning module to format predictions
sfmig Dec 11, 2025
a51028a
Add inference module
sfmig Dec 12, 2025
3261727
Simplify get model state dict
sfmig Dec 12, 2025
39112a0
Expand to other torchvision detectors (WIP)
sfmig Dec 12, 2025
de5647c
Format notebook as proto example
sfmig Dec 12, 2025
edc789d
Remove detector config for now
sfmig Dec 12, 2025
d1659b1
Fix example
sfmig Dec 12, 2025
588b526
More doc fixes
sfmig Dec 12, 2025
14c3be9
Docstrings review
sfmig Dec 16, 2025
42ec3cf
Rename detector
sfmig Dec 16, 2025
27a0811
Add docstrings
sfmig Dec 16, 2025
c487132
Add some detector utils tests
sfmig Dec 16, 2025
e3cdd72
Add error messages to padding utils and expand tests
sfmig Dec 16, 2025
ea50248
Add test for formatting predictions from detector
sfmig Dec 16, 2025
3282601
Enhance centroid and shape conversion functions with detailed docstri…
sfmig Dec 18, 2025
96f9e15
Add tests for configure model pretrained
sfmig Dec 18, 2025
17a7235
Add basic config validation. Move model assigning of defaults to _con…
sfmig Dec 18, 2025
2de25f0
Add tests for config validation, loading from pretrained with num_cla…
sfmig Dec 18, 2025
93188da
Refactor _get_n_classes_in_detector function to a standalone function
sfmig Dec 18, 2025
68e2a5f
Refactor _get_n_classes_in_detector
sfmig Dec 18, 2025
4fe03bf
Fix
sfmig Dec 18, 2025
4a00cf7
Modify corner <-> centroid arrays conversions to attend to second dim…
sfmig Dec 18, 2025
ad4ef33
Add remaining tests
sfmig Dec 19, 2025
08d035d
Expand test_configure_model_from_checkpoint to support number of clas…
sfmig Dec 19, 2025
24868f5
Add tests for inference dataset
sfmig Dec 19, 2025
cc2c59a
Improve docstrings
sfmig Dec 19, 2025
b3637f8
Small edits to example (WIP)
sfmig Dec 19, 2025
4010c1c
Make default transforms actual default for InferenceImageDataset
sfmig Dec 22, 2025
39ca320
Convert torch.tensor predictions to numpy array in format_predictions…
sfmig Dec 22, 2025
d7acb46
Change docs links
sfmig Dec 22, 2025
bc02eec
Review dataset docstring
sfmig Dec 22, 2025
31cc451
Fix lunks
sfmig Dec 22, 2025
6e05b4b
Add examples in docstring
sfmig Dec 22, 2025
85eef8a
Docstring draft
sfmig Dec 22, 2025
511add7
More edits to docstring
sfmig Dec 22, 2025
7581762
More edits
sfmig Dec 22, 2025
1cd2b36
Add methods table in API docs for own-methods only (not inherited)
sfmig Dec 22, 2025
43fd40a
Review model docstrings, add links and Raises
sfmig Dec 23, 2025
718d409
Review dataset docstrings
sfmig Dec 23, 2025
ec8ff54
Expand config tests to ensure valid configs work
sfmig Dec 23, 2025
9c1e5d6
Expand list of model_kwargs to check for misspellings (torchvision ig…
sfmig Dec 23, 2025
a7ce3d2
Add types for clarity in utils. Edit docstring
sfmig Dec 23, 2025
0147f7d
Remove lightning docs fixes that are no longer required
sfmig Dec 23, 2025
23e4dc9
Verify proto example runs with latest implementation
sfmig Dec 23, 2025
0bec87f
Be more explicit about the loaded weights
sfmig Dec 23, 2025
f26453b
Example with public dataset
sfmig Dec 23, 2025
63eef64
Fix section title. Set accelerator to cpu
sfmig Dec 23, 2025
fb77b29
Add collate fn for detector
sfmig Dec 23, 2025
9dd24c1
Remove African wildlife dataset
sfmig Dec 23, 2025
de20699
Add a minimal test for detector collate_fn
sfmig Dec 23, 2025
d70d95a
Clarification in minimal test
sfmig Dec 23, 2025
208432d
Add return type
sfmig Dec 23, 2025
3b9495b
Merge branch 'main' into smg/faster-rcnn-pl
sfmig Jan 19, 2026
c3736eb
Merge branch 'main' into smg/faster-rcnn-pl
sfmig Feb 2, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/source/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
{% if objname != 'ValidDataset' %}:members:{% endif %}
{% if objname != 'ValidDataset' %}:inherited-members:{% endif %}
{% if objname == 'ValidBboxAnnotationsDataFrame' %}:exclude-members: Config{% endif %}
:members:
:show-inheritance:

{% block methods %}
{% set ns = namespace(has_public_methods=false) %}
Expand All @@ -24,11 +23,13 @@
.. autosummary::
{% for item in methods %}
{% if not item.startswith('_') %}
~{{ name }}.{{ item }}
{{ item|is_own_method(name, module) }}
{% endif %}
{%- endfor %}
{% endif %}
{% endblock %}

.. rubric:: {{ _('Details') }}

.. minigallery:: {{ module }}.{{ objname }}
:add-heading: Examples using ``{{ objname }}``
49 changes: 48 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Sphinx configuration for ethology documentation."""

import inspect
import os
import sys
from importlib import import_module
from importlib.metadata import version as get_version

from jinja2.filters import FILTERS
from sphinx_gallery import sorting

# Used when building API docs, put the dependencies
Expand Down Expand Up @@ -68,6 +71,7 @@
# Automatically add anchors to markdown headings
myst_heading_anchors = 4

# -------- Autosummary
# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

Expand All @@ -76,6 +80,47 @@
autosummary_generate_overwrite = False
autodoc_default_options = {"show-inheritance": True} # applies to all classes


def is_own_method(method_name, obj, module_name):
"""Check if a method is defined in the class itself (not inherited).

Returns the method reference string if it's defined in the class,
empty string otherwise.
"""
module = import_module(module_name)
if hasattr(module, "__all__") and obj not in module.__all__:
return ""

cls = getattr(module, obj)
if not inspect.isclass(cls):
return ""

# Check if method is defined in this class (not inherited)
if hasattr(cls, method_name):
# Check if it's in the class's __dict__ (defined in this class)
if method_name in cls.__dict__:
return f"~{obj}.{method_name}"
# Or check using inspect to see if it's defined in this class
try:
method = getattr(cls, method_name)
if inspect.ismethod(method) or inspect.isfunction(method):
# Check if the method's defining class is this class
if hasattr(method, "__qualname__"):
qualname_parts = method.__qualname__.split(".")
if len(qualname_parts) >= 2 and qualname_parts[-2] == obj:
return f"~{obj}.{method_name}"
# Fallback: check if it's in __dict__
if method_name in cls.__dict__:
return f"~{obj}.{method_name}"
except (AttributeError, TypeError):
pass

return ""


FILTERS["is_own_method"] = is_own_method

# -------------
# Prefix section labels with the document name
autosectionlabel_prefix_document = True

Expand Down Expand Up @@ -197,6 +242,9 @@
"https://python-jsonschema.readthedocs.io/en/stable/",
None,
),
"torch": ("https://pytorch.org/docs/stable/", None),
"torchvision": ("https://pytorch.org/vision/stable/", None),
"lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
}


Expand All @@ -221,7 +269,6 @@


# sphinx-gallery configuration

sphinx_gallery_conf = {
"examples_dirs": ["../../examples"],
"within_subsection_order": sorting.ExplicitOrder(
Expand Down
143 changes: 143 additions & 0 deletions ethology/datasets/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Datasets and related utilities for inference without ground-truth."""

from collections.abc import Callable
from pathlib import Path

import torch
import torchvision.transforms.v2 as transforms
from PIL import Image
from torch.utils.data import Dataset


class InferenceImageDataset(Dataset):
"""A simple dataset for images with no ground-truth annotations.

The image files are sorted alphabetically. The annotations dictionary
returned by ``__getitem__`` is always empty to maintain a consistent
interface with training datasets.

Parameters
----------
images_dir
Path to the root directory containing the images.
file_pattern
Glob pattern to match image filenames (e.g., "*.png", "*.jpg").
transforms
Transforms to apply to the images. If `None` (default), the
transforms from :func:`get_default_inference_transforms` are used.

Attributes
----------
images_dir : pathlib.Path
Path to the root directory containing the images.
transforms : torchvision.transforms.v2.Compose
Transforms to apply to the images.
image_files : list[pathlib.Path]
List of paths to each of the image files, sorted
alphabetically.

See Also
--------
get_default_inference_transforms : Returns default transforms for
inference.

Examples
--------
Create a dataset from 100 ``.png`` files in the ``/path/to/images``
directory:

>>> from ethology.datasets.inference import InferenceImageDataset
>>> dataset = InferenceImageDataset(
... images_dir="/path/to/images",
... file_pattern="*.png",
... )
>>> len(dataset)
100

"""

def __init__(
self,
images_dir: Path | str,
file_pattern: str,
transforms: transforms.Compose | None = None,
):
"""Initialise dataset."""
self.images_dir = Path(images_dir)
self.transforms = (
transforms
if transforms is not None
else get_default_inference_transforms()
)
self.image_files = sorted(self.images_dir.glob(file_pattern))

def __len__(self) -> int:
"""Return the number of images in the dataset."""
return len(self.image_files)

def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]:
"""Return the image and an empty annotations dictionary.

Parameters
----------
idx : int
Index of the image to retrieve.

Returns
-------
tuple[torch.Tensor, dict]
A tuple containing the image as a tensor and an empty
annotations dictionary.

"""
# Open requested image
img_path = Path(self.images_dir) / self.image_files[idx]
image = Image.open(img_path).convert("RGB")

# If transforms are specified, apply to the image
if self.transforms:
image = self.transforms(image)
return image, {}


def get_default_inference_transforms() -> transforms.Compose:
"""Return the default transforms for inference.

Transforms the input image to a tensor and scales the pixel
values from ``[0, 255]`` (uint8) to ``[0, 1]`` (float32).

Returns
-------
torchvision.transforms.v2.Compose
The default transforms for inference.

"""
return transforms.Compose(
[
transforms.ToImage(),
transforms.ToDtype(torch.float32, scale=True),
]
)


def get_detector_collate_fn() -> Callable:
"""Return collate function for a detector.

It supports images and annotations of different sizes.
A collate function takes a list of samples from a dataset
and batches them for the model to process it. Torch detectors
expect a list/tuple of images and annotations, since they can
be of different sizes across the dataset.

Returns
-------
Callable
A collate function for detector models.

"""

def collate_fn(dataset_samples):
"""Return a tuple of tuples: (images_tuple, annots_tuple)."""
return tuple(zip(*dataset_samples, strict=True))

return collate_fn
Loading
Loading