From cffc5751fb66d74c9c630f7e83e67dd4d44a33f6 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Fri, 10 Oct 2025 17:26:08 +0100 Subject: [PATCH] Draft --- ethology/__init__.py | 6 +- ethology/detectors/datasets.py | 273 ++++++++++++++++++++ examples/01_annotations_as_torch_dataset.py | 245 ++++++++++++++++++ 3 files changed, 523 insertions(+), 1 deletion(-) create mode 100644 ethology/detectors/datasets.py create mode 100644 examples/01_annotations_as_torch_dataset.py diff --git a/ethology/__init__.py b/ethology/__init__.py index 3e47301a..1129ea2d 100644 --- a/ethology/__init__.py +++ b/ethology/__init__.py @@ -1,10 +1,14 @@ from importlib.metadata import PackageNotFoundError, version - import xarray as xr +from pathlib import Path # Set xarray attributes collapsed by default xr.set_options(display_expand_attrs=False) +# Set cache directory for ethology package +ETHOLOGY_CACHE_DIR = Path.home() / ".ethology" +ETHOLOGY_CACHE_DIR.mkdir(parents=True, exist_ok=True) + try: __version__ = version("ethology") except PackageNotFoundError: diff --git a/ethology/detectors/datasets.py b/ethology/detectors/datasets.py new file mode 100644 index 00000000..db180af6 --- /dev/null +++ b/ethology/detectors/datasets.py @@ -0,0 +1,273 @@ +"""Utilities for creating and manipulating datasets for detection.""" + +from datetime import datetime +from pathlib import Path +from typing import Any + +import torch +import torchvision.transforms.v2 as transforms +import xarray as xr +from loguru import logger +from torch.utils.data import random_split +from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2 + +from ethology import ETHOLOGY_CACHE_DIR +from ethology.io.annotations import save_bboxes +from ethology.io.annotations.validate import ValidCOCO, _check_input + + +def annotations_dataset_to_torch_dataset( + ds: xr.Dataset, + images_directory: Path | str | None = None, + transforms: transforms.Compose | None = None, + out_filepath: Path | str | None = None, + kwargs: dict[str, Any] | None = None, +) -> CocoDetection: + """Convert an bounding boxes annotations dataset to a torch dataset. + + Parameters + ---------- + ds : xr.Dataset + The dataset to convert. + images_directory : Path | str | None, optional + The path to the images directory. + transforms : torchvision.transforms.v2.Compose | None, optional + The transforms to apply to the dataset. + out_filepath : Path | str | None, optional + The path to the output COCO file. + kwargs : dict[str, Any] | None, optional + Additional keyword arguments to pass to the torch dataset constructor. + + Returns + ------- + CocoDetection + The converted torch dataset. + + """ + # Export xarray dataset to COCO file + timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") + if out_filepath is None: + out_filepath = ETHOLOGY_CACHE_DIR / f"tmp_out_{timestamp}.json" + else: + suffix = Path(out_filepath).suffix + path_without_suffix = Path(out_filepath).with_suffix("") + out_filepath = Path(f"{path_without_suffix}_{timestamp}.{suffix}") + + out_file = save_bboxes.to_COCO_file(ds, out_filepath) + logger.info(f"Exported temporary COCO file to {out_file}") + + # Get images directory + # if not provided, check the dataset attributes + if images_directory is None: + images_directory = ds.attrs.get("images_directories", None) + if isinstance(images_directory, list) and len(images_directory) > 0: + images_directory = images_directory[0] + logger.warning( + f"Using first images directory only: {images_directory}" + ) # TODO: loop thru them? + elif images_directory is None: + raise KeyError( + "`images_directories` is not set. " + "Please provide `images_directory` as an input or " + "add it to the dataset attributes." + ) + + # Create torch dataset + return CocoDetection( + root=images_directory, + annFile=out_file, + transforms=transforms, + **kwargs if kwargs is not None else {}, + ) + + +def torch_dataset_to_annotations_dataset( + torch_dataset: torch.utils.data.Dataset, +) -> xr.Dataset: + """Convert a torch dataset to an annotations dataset.""" + pass + + +@_check_input(validator=ValidCOCO) +def torch_dataset_from_COCO_file( + annotations_file: str | Path, + images_directory: str | Path, + kwargs: dict[str, Any] | None = None, +) -> CocoDetection: + """Create a COCO dataset for object detection. + + Note: transforms are applied to the full dataset. If the dataset + is later split, all splits will have the same transforms. + + Parameters + ---------- + annotations_file : str | Path + The path to the input COCO file. + images_directory : str | Path + The path to the images directory. + kwargs : dict[str, Any] | None, optional + Additional keyword arguments to pass to the torch dataset constructor. + + Returns + ------- + torch.utils.data.Dataset + The converted torch dataset. + + """ + dataset_coco = CocoDetection( + root=str(images_directory), + annFile=str(annotations_file), + **kwargs if kwargs is not None else {}, + ) + + # wrap dataset for transforms v2 + dataset_transformed = wrap_dataset_for_transforms_v2(dataset_coco) + + return dataset_transformed + + +def split_torch_dataset( + dataset: torch.utils.data.Dataset, + train_val_test_fractions: list[float], + seed: int = 42, +) -> tuple[ + torch.utils.data.Dataset, + torch.utils.data.Dataset, + torch.utils.data.Dataset, +]: + """Split a torchdataset into train, validation, and test sets. + + Note that transforms are already applied to the input dataset. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The torch dataset to split. + train_val_test_fractions : list[float] + The fractions of the dataset to allocate to the train, validation, + and test sets. + seed : int, optional + The seed to use for the random number generator. Default is 42. + + Returns + ------- + tuple[torch.utils.data.Dataset] + The train, validation, and test sets. + + """ + # Check that the fractions sum to 1 + if sum(train_val_test_fractions) != 1: + raise ValueError("The split fractions must sum to 1.") + + # Log transforms applied to the dataset + logger.info( + f"Dataset transforms (propagated to all splits): {dataset.transforms}" + ) + + # Create random number generator for reproducibility if seed is provided + rng_split = None + if seed is not None: + rng_split = torch.Generator().manual_seed(seed) + + # Split dataset + train_dataset, test_dataset, val_dataset = random_split( + dataset, + train_val_test_fractions, + generator=rng_split, + ) + + # Print number of samples in each split + logger.info(f"Seed: {seed}") + logger.info(f"Number of training samples: {len(train_dataset)}") + logger.info(f"Number of validation samples: {len(val_dataset)}") + logger.info(f"Number of test samples: {len(test_dataset)}") + + return train_dataset, test_dataset, val_dataset + + +# def torch_dataset_to_annotations_dataset( +# torch_dataset: torch.utils.data.Dataset, +# ) -> xr.Dataset: +# """Convert a torch dataset to an annotations dataset.""" +# # Read list of rows +# list_rows = [annot for _img, annot in torch_dataset] + +# # --------- +# # Read list of rows as a dataframe +# df = pd.DataFrame(list_rows) + +# # Sort annotations by image_filename +# df = df.sort_values(by=["image_filename"]) + +# # Drop duplicates and reindex +# # The resulting axis is labeled 0,1,…,n-1. +# df = df.drop_duplicates( +# subset=[col for col in df.columns if col != "annotation_id"], +# ignore_index=True, +# inplace=False, +# ) + +# # Cast bbox coordinates and shape as floats +# for col in ["x_min", "y_min", "width", "height"]: +# df[col] = df[col].astype(np.float64) + +# # Set the index name to "annotation_id" +# df = df.set_index("annotation_id") +# # --------- + +# # Get maps to set as dataset attributes +# map_image_id_to_filename, map_category_to_str = ( +# load_bboxes._get_map_attributes_from_df(df) +# ) + +# # Convert dataframe to xarray dataset +# ds = load_bboxes._df_to_xarray_ds(df) + +# # Add attributes to the xarray dataset +# ds.attrs = { +# # "annotation_files": file_paths, +# "annotation_format": 'torch-dataset', +# "map_category_to_str": map_category_to_str, +# "map_image_id_to_filename": map_image_id_to_filename, +# } +# # ----------- + +# # Add image dir as metadata +# root = _find_nested_root(torch_dataset) +# if root: +# ds.attrs["images_directories"] = root + + +# return ds + + +# def _find_nested_root( +# dataset: torch.utils.data.Dataset +# ) -> str | Path | None: +# """Find root of a possibly nested dataset. + +# Parameters +# ---------- +# dataset : torch.utils.data.Dataset +# The dataset to check. It may be the result of multiple +# splits, and therefore be nested. + +# Returns +# ------- +# str or Path or None +# The nested root value for the dataset, or None if not found + +# """ +# current = dataset + +# # Check current level +# if hasattr(current, "root"): +# return current + +# # Check through dataset levels +# while hasattr(current, "dataset"): +# current = current.dataset +# if hasattr(current, "root"): +# return current.root + +# return None diff --git a/examples/01_annotations_as_torch_dataset.py b/examples/01_annotations_as_torch_dataset.py new file mode 100644 index 00000000..248ef556 --- /dev/null +++ b/examples/01_annotations_as_torch_dataset.py @@ -0,0 +1,245 @@ +"""Convert ``ethology`` annotations to PyTorch dataset +======================================================== + +Load bounding box annotations as an ``ethology`` dataset, select a subset of +categories and convert to a +`torch COCO dataset `_. +""" + + +# %% +# Imports +# ------- + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pooch +import torch +import torchvision.transforms.v2.functional as F +from torchvision.utils import draw_bounding_boxes + +from ethology.detectors.datasets import annotations_dataset_to_torch_dataset +from ethology.io.annotations import load_bboxes + +# For interactive plots: install ipympl with `pip install ipympl` and uncomment +# the following line in your notebook +# %matplotlib widget + + +# %% +# Download dataset +# ------------------ +# +# For this example, we will use the dataset from the +# `UAS Imagery of Migratory Waterfowl at New Mexico Wildlife Refuges `_. +# This dataset is part of the `Drones For Ducks project +# `_ +# that aims to develop an efficient method to count and identify species of +# migratory waterfowl at wildlife refuges across New Mexico. +# +# The dataset is made up of a set of drone images and corresponding +# bounding box annotations. Annotations are provided by both expert +# annotators and volunteers. +# +# Since the dataset is not very large, we can download it as a zip file +# directly from the URL provided in the dataset webpage. +# We use the `pooch `_ library +# to download it to the ``.ethology`` cache directory. + + +# Source of the dataset +data_source = { + "url": "https://storage.googleapis.com/public-datasets-lila/uas-imagery-of-migratory-waterfowl/uas-imagery-of-migratory-waterfowl.20240220.zip", + "hash": "c5b8dfc5a87ef625770ac8f22335dc9eb8a67688b610490a029dae81815a9896", +} + +# Define cache directory +ethology_cache = Path.home() / ".ethology" +ethology_cache.mkdir(exist_ok=True) + +# Download the dataset to the cache directory +extracted_files = pooch.retrieve( + url=data_source["url"], + known_hash=data_source["hash"], + fname="waterfowl_dataset.zip", + path=ethology_cache, + processor=pooch.Unzip(extract_dir=ethology_cache), +) + + +# %% +# For this example, we will focus on the annotations labelled by the experts. + +data_dir = ethology_cache / "uas-imagery-of-migratory-waterfowl" +experts_dir = data_dir / "experts" + +annotations_file = experts_dir / "20230331_dronesforducks_expert_refined.json" +images_dir = experts_dir / "images" + + +# %% +# Load annotations as `ethology` dataset +# -------------------------------------- + +ds = load_bboxes.from_files( + annotations_file, images_dirs=images_dir, format="COCO" +) + +print(ds) +print(ds.sizes) + +# %% +# Transform image filenames +# ------------------------- + +# Image filenames in input file are .JPG but files are .jpg +# Change image filenames dict to .jpg +map_image_id_to_filename = { + k: v.replace(".JPG", ".jpg") + for k, v in ds.map_image_id_to_filename.items() +} +ds.attrs["map_image_id_to_filename"] = map_image_id_to_filename + + +# %% +# Count annotations per category +# ------------------------------- + +# TODO: use Counter and .most_common() +list_category_counts = [ + (ky, val, (ds.category == ky).sum().item()) + for ky, val in ds.map_category_to_str.items() +] + +# Sort by decreasing count +list_category_counts.sort(key=lambda x: x[2], reverse=True) + + +# %% +# Select a subset dataset +# ----------------------- +# Make a new dataset with only the bottom/top 3 categories + +# Compute the categories to keep +n_categories = 2 +categories_to_keep = [x[0] for x in list_category_counts[:n_categories]] +print(f"Categories to keep: {categories_to_keep}") + +# Compute categories mask array +# True where categories are in the set to keep, False otherwise +categories_mask = ds.category.isin(categories_to_keep) # dim: image_id, id + +ds_subset = ds.where(categories_mask, drop=True) + +# inspect +print(f"ds_subset unique categories: {np.unique(ds_subset.category.values)}") +print(f"ds_subset.sizes: {ds_subset.sizes}") # note reduced dimensions +print(f"ds_subset.image_shape shape: {ds_subset.image_shape.shape}") + +# %% +# Note that due to the underlying broadcasting in the ``where`` operation, +# the image_shape array now has ``image_id``, ``space``, and ``id`` dimensions +# and the ``category`` array is now a float. For clarity we go back to the +# ``ethology`` +# convention and make the ``category`` array and integer one, with -1 for empty +# values, and the ``image_shape`` array an integer one with only ``image_id`` +# and ``space`` dimensions. + +ds_subset["category"] = ds_subset.category.fillna(-1).astype(int) +ds_subset["image_shape"] = ds.image_shape +# this assignment takes only the (image_id, space) coordinates from +# ds.image_shape that are also present in ds_subset + + +print(ds_subset) +print("---------") +print(f"ds_subset.image_shape shape: {ds_subset.image_shape.shape}") + + +# %% +# Convert dataset to torch dataset +# ----------------------------------- + +dataset_torch = annotations_dataset_to_torch_dataset(ds_subset) + + +# %% +# Sample from the torch dataset and convert bbox format +# ----------------------------------------------------- + +# get one image and its annotations +sample_idx = 2 +img, annot = dataset_torch[sample_idx] + +# annot is a list of bboxes dictionaries +# coords are in XYWH format +bboxes_tensor_xywh = torch.as_tensor([ann["bbox"] for ann in annot]) +print(f"Bbox in XYWH format: {bboxes_tensor_xywh[0, :]}") + +# convert bbox format from XYWH to XYXY +bboxes_tensor_xyxy = F.convert_bounding_box_format( + bboxes_tensor_xywh, + old_format="XYWH", + new_format="XYXY", +) +print(f"Bbox in XYXY format: {bboxes_tensor_xyxy[0, :]}") + + +# %% +# Visualize selected sample using torchvision ``draw_bounding_boxes`` +# -------------------------------------------------------------------- +# From https://docs.pytorch.org/vision/0.21/auto_examples/others/plot_visualization_utils.html + + +# map category ID to color +cmap = plt.cm.tab10 +map_category_id_to_color_ints = { + i: tuple((np.array(cmap(i)[:3]) * 255).astype(int)) + for i in np.unique(ds_subset.category.values) + if i != -1 +} +map_category_id_to_color_floats = { + i: tuple(np.array(cmap(i)[:3])) + for i in np.unique(ds_subset.category.values) + if i != -1 +} + +# list of categories per annotation in image +list_category_ids_in_image = [ann["category_id"] for ann in annot] + +# create image with boxes +img_with_boxes = draw_bounding_boxes( + F.pil_to_tensor(img), + bboxes_tensor_xyxy, + colors=[ + map_category_id_to_color_ints[x] for x in list_category_ids_in_image + ], + width=15, +) + +# plot +fig, ax = plt.subplots() +ax.imshow(img_with_boxes.permute(1, 2, 0)) +ax.set_xlabel("x (pixels)") +ax.set_ylabel("y (pixels)") + + +# add legend +legend_elements = [ + plt.Line2D([0], [0], color=c) + for c in map_category_id_to_color_floats.values() +] +legend_labels = [ + ds_subset.map_category_to_str[x] + for x in np.unique(ds_subset.category.values) + if x != -1 +] +plt.legend( + legend_elements, + legend_labels, + bbox_to_anchor=(1.05, 1), + loc="upper left", +) +plt.tight_layout()