Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ recursive-exclude examples *
recursive-include docs *

# Include json schemas
recursive-include ethology/io/annotations/json_schemas/schemas *.json
recursive-include ethology/io/annotations/json_schemas/schemas *.md
recursive-include ethology/validators/json_schemas/schemas *.json
recursive-include ethology/validators/json_schemas/schemas *.md
8 changes: 3 additions & 5 deletions docs/source/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}
:members:
:show-inheritance:
:inherited-members:
:exclude-members: Config

{% if objname != 'ValidDataset' %}:members:{% endif %}
{% if objname != 'ValidDataset' %}:inherited-members:{% endif %}
{% if objname == 'ValidBboxAnnotationsDataFrame' %}:exclude-members: Config{% endif %}

{% block methods %}
{% set ns = namespace(has_public_methods=false) %}
Expand Down
6 changes: 5 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
# Automatically generate stub pages for API
autosummary_generate = True
autosummary_generate_overwrite = False
autodoc_default_flags = ["members", "inherited-members"]
autodoc_default_options = {"show-inheritance": True} # applies to all classes

# Prefix section labels with the document name
autosectionlabel_prefix_document = True
Expand Down Expand Up @@ -182,6 +182,10 @@
"pandera": ("https://pandera.readthedocs.io/en/stable/", None),
"movement": ("https://movement.neuroinformatics.dev/latest/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"jsonschema": (
"https://python-jsonschema.readthedocs.io/en/stable/",
None,
),
}


Expand Down
26 changes: 14 additions & 12 deletions ethology/io/annotations/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
import xarray as xr
from pandera.typing.pandas import DataFrame

from ethology.io.annotations.validate import (
ValidBboxesDataFrame,
ValidBboxesDataset,
from ethology.validators.annotations import (
ValidBboxAnnotationsDataFrame,
ValidBboxAnnotationsDataset,
ValidCOCO,
ValidVIA,
_check_output,
)
from ethology.validators.utils import _check_output


@_check_output(ValidBboxesDataset)
@_check_output(ValidBboxAnnotationsDataset)
def from_files(
file_paths: Path | str | list[Path | str],
format: Literal["VIA", "COCO"],
Expand Down Expand Up @@ -138,7 +138,7 @@ def from_files(


def _get_map_attributes_from_df(
df: DataFrame[ValidBboxesDataFrame],
df: DataFrame[ValidBboxAnnotationsDataFrame],
) -> tuple[dict, dict]:
"""Get the map attributes from the dataframe.

Expand Down Expand Up @@ -179,7 +179,7 @@ def _get_map_attributes_from_df(
@pa.check_types
def _df_from_multiple_files(
list_filepaths: list[Path | str], format: Literal["VIA", "COCO"]
) -> DataFrame[ValidBboxesDataFrame]:
) -> DataFrame[ValidBboxAnnotationsDataFrame]:
"""Read annotations from multiple files as a valid intermediate dataframe.

Parameters
Expand Down Expand Up @@ -242,7 +242,7 @@ def _df_from_multiple_files(
@pa.check_types
def _df_from_single_file(
file_path: Path | str, format: Literal["VIA", "COCO"]
) -> DataFrame[ValidBboxesDataFrame]:
) -> DataFrame[ValidBboxAnnotationsDataFrame]:
"""Read annotations from a single file as a valid intermediate dataframe.

Parameters
Expand Down Expand Up @@ -374,7 +374,7 @@ def _df_rows_from_valid_VIA_file(file_path: Path) -> list[dict]:

else:
supercategory, category, category_id = (
ValidBboxesDataFrame.get_empty_values()[key]
ValidBboxAnnotationsDataFrame.get_empty_values()[key]
for key in ["supercategory", "category", "category_id"]
)

Expand Down Expand Up @@ -428,7 +428,7 @@ def _get_image_shape_attr_as_integer(
ValidBboxesDataFrame.get_empty_values().

"""
default_value = ValidBboxesDataFrame.get_empty_values()[
default_value = ValidBboxAnnotationsDataFrame.get_empty_values()[
f"image_{attr_name}"
]
try:
Expand Down Expand Up @@ -557,7 +557,9 @@ def _df_rows_from_valid_COCO_file(file_path: Path) -> list[dict]:


@pa.check_types
def _df_to_xarray_ds(df: DataFrame[ValidBboxesDataFrame]) -> xr.Dataset:
def _df_to_xarray_ds(
df: DataFrame[ValidBboxAnnotationsDataFrame],
) -> xr.Dataset:
"""Convert a bounding box annotations dataframe to an xarray dataset.

Parameters
Expand Down Expand Up @@ -585,7 +587,7 @@ def _df_to_xarray_ds(df: DataFrame[ValidBboxesDataFrame]) -> xr.Dataset:

"""
# Drop columns if all values in that column are empty
default_values = ValidBboxesDataFrame.get_empty_values()
default_values = ValidBboxAnnotationsDataFrame.get_empty_values()
list_empty_cols = [
col for col in default_values if all(df[col] == default_values[col])
]
Expand Down
27 changes: 14 additions & 13 deletions ethology/io/annotations/save_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@
import xarray as xr
from pandera.typing.pandas import DataFrame

from ethology.io.annotations.validate import (
ValidBboxesDataFrameCOCO,
ValidBboxesDataset,
from ethology.validators.annotations import (
ValidBboxAnnotationsCOCO,
ValidBboxAnnotationsDataset,
ValidCOCO,
_check_input,
_check_output,
)
from ethology.validators.utils import _check_input, _check_output


@_check_input(validator=ValidBboxesDataset)
@_check_output(validator=ValidCOCO) # check output is ethology importable
@_check_input(validator=ValidBboxAnnotationsDataset)
@_check_output(validator=ValidCOCO) # check output is ethology-importable
def to_COCO_file(dataset: xr.Dataset, output_filepath: str | Path):
"""Save an ``ethology`` bounding box annotations dataset to a COCO file.

Expand Down Expand Up @@ -56,11 +55,11 @@ def to_COCO_file(dataset: xr.Dataset, output_filepath: str | Path):
return output_filepath


@_check_input(validator=ValidBboxesDataset)
@_check_input(validator=ValidBboxAnnotationsDataset)
@pa.check_types
def _to_COCO_exportable_df(
ds: xr.Dataset,
) -> DataFrame[ValidBboxesDataFrameCOCO]:
) -> DataFrame[ValidBboxAnnotationsCOCO]:
"""Convert dataset of bounding boxes annotations to a COCO-exportable df.

The returned dataframe is validated using ValidBBoxesDataFrameCOCO.
Expand Down Expand Up @@ -98,7 +97,7 @@ def _to_COCO_exportable_df(
return df[cols_to_select]


@_check_input(validator=ValidBboxesDataset)
@_check_input(validator=ValidBboxAnnotationsDataset)
def _get_raw_df_from_ds(ds: xr.Dataset) -> pd.DataFrame:
"""Get preliminary dataframe from a dataset of bounding boxes annotations.

Expand Down Expand Up @@ -164,7 +163,7 @@ def _get_raw_df_from_ds(ds: xr.Dataset) -> pd.DataFrame:
@pa.check_types
def _add_COCO_data_to_df(
df: pd.DataFrame, ds_attrs: dict
) -> DataFrame[ValidBboxesDataFrameCOCO]:
) -> DataFrame[ValidBboxAnnotationsCOCO]:
"""Add COCO-required data to preliminary dataframe.

The input dataframe is obtained from a dataset of bounding boxes
Expand Down Expand Up @@ -266,7 +265,9 @@ def _add_COCO_data_to_df(


@pa.check_types
def _create_COCO_dict(df: DataFrame[ValidBboxesDataFrameCOCO]) -> dict:
def _create_COCO_dict(
df: DataFrame[ValidBboxAnnotationsCOCO],
) -> dict:
"""Extract COCO dictionary from a COCO-exportable dataframe.

Parameters
Expand All @@ -282,7 +283,7 @@ def _create_COCO_dict(df: DataFrame[ValidBboxesDataFrameCOCO]) -> dict:
"""
COCO_dict: dict[str, Any] = {}
map_columns_to_COCO_fields = (
ValidBboxesDataFrameCOCO.map_df_columns_to_COCO_fields()
ValidBboxAnnotationsCOCO.map_df_columns_to_COCO_fields()
)
for sections in ["images", "categories", "annotations"]:
# Extract and rename required columns for this section
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
"""Validators for annotation files and datasets."""

import json
from collections.abc import Callable
from functools import wraps
from pathlib import Path

import pandas as pd
import pandera.pandas as pa
import xarray as xr
from attrs import define, field
from pandera.typing import Index

from ethology.io.annotations.json_schemas.utils import (
from ethology.validators.json_schemas.utils import (
_check_file_is_json,
_check_file_matches_schema,
_check_required_keys_in_dict,
_get_default_schema,
)
from ethology.validators.utils import ValidDataset


@define
Expand Down Expand Up @@ -227,25 +225,39 @@ def _file_contains_unique_image_IDs(self, attribute, value):


@define
class ValidBboxesDataset:
class ValidBboxAnnotationsDataset(ValidDataset):
"""Class for valid ``ethology`` bounding box annotations datasets.

It checks that the input dataset has:
This class validates that the input dataset:

- is an xarray Dataset,
- has ``image_id``, ``space``, ``id`` as dimensions,
- has ``position`` and ``shape`` as data variables,
- both data variables span at least the dimensions ``image_id``,
``space`` and ``id``.

- ``image_id``, ``space``, ``id`` as dimensions
- ``position`` and ``shape`` as data variables

Attributes
----------
dataset : xarray.Dataset
The xarray dataset to validate.
required_dims : set[str]
The set of required dimension names: ``image_id``, ``space`` and
``id``.
required_data_vars : dict[str, set]
A dictionary mapping data variable names to their required minimum
dimensions:

- ``position`` maps to ``image_id``, ``space`` and ``id``,
- ``shape`` maps to ``image_id``, ``space`` and ``id``.

Raises
------
TypeError
If the input is not an xarray Dataset.
ValueError
If the dataset is missing required data variables or dimensions.
If the dataset is missing required data variables or dimensions,
or if any required dimensions are missing for any data variable.

Notes
-----
Expand All @@ -254,46 +266,21 @@ class ValidBboxesDataset:

"""

dataset: xr.Dataset = field()

# Minimum requirements for annotations datasets holding bboxes
# Minimum requirements for a bbox dataset holding detections
required_dims: set = field(
default={"image_id", "space", "id"},
init=False,
)
required_data_vars: set = field(
default={"position", "shape"},
required_data_vars: dict = field(
default={
"position": {"image_id", "space", "id"},
"shape": {"image_id", "space", "id"},
},
init=False,
)

@dataset.validator
def _check_dataset_type(self, attribute, value):
"""Ensure the input is an xarray Dataset."""
if not isinstance(value, xr.Dataset):
raise TypeError(
f"Expected an xarray Dataset, but got {type(value)}."
)

@dataset.validator
def _check_required_data_variables(self, attribute, value):
"""Ensure the dataset has all required data variables."""
missing_vars = self.required_data_vars - set(value.data_vars)
if missing_vars:
raise ValueError(
f"Missing required data variables: {sorted(missing_vars)}"
)

@dataset.validator
def _check_required_dimensions(self, attribute, value):
"""Ensure the dataset has all required dimensions."""
missing_dims = self.required_dims - set(value.dims)
if missing_dims:
raise ValueError(
f"Missing required dimensions: {sorted(missing_dims)}"
)


class ValidBboxesDataFrame(pa.DataFrameModel):
class ValidBboxAnnotationsDataFrame(pa.DataFrameModel):
"""Class for valid bounding boxes intermediate dataframes.

We use this dataframe internally as an intermediate step in the process of
Expand Down Expand Up @@ -422,7 +409,7 @@ def get_empty_values() -> dict:
}


class ValidBboxesDataFrameCOCO(pa.DataFrameModel):
class ValidBboxAnnotationsCOCO(pa.DataFrameModel):
"""Class for COCO-exportable bounding box annotations dataframes.

The validation checks the required columns exist and their types are
Expand Down Expand Up @@ -573,38 +560,3 @@ def check_idx_and_annotation_id(cls, df: pd.DataFrame) -> bool:

"""
return all(df.index == df["annotation_id"])


def _check_output(validator: type):
"""Return a decorator that validates the output of a function."""

def decorator(function: Callable) -> Callable:
@wraps(function) # to preserve function metadata
def wrapper(*args, **kwargs):
result = function(*args, **kwargs)
validator(result)
return result

return wrapper

return decorator


def _check_input(validator: type, input_index: int = 0):
"""Return a decorator that validates a specific input of a function.

By default, the first input is validated. If the input index is
larger than the number of inputs, no validation is performed.
"""

def decorator(function: Callable) -> Callable:
@wraps(function)
def wrapper(*args, **kwargs):
if len(args) > input_index:
validator(args[input_index])
result = function(*args, **kwargs)
return result

return wrapper

return decorator
Loading