From d00c6b768b3ffb8cd540e3f55c46bf115c4c5a89 Mon Sep 17 00:00:00 2001 From: Manuel Konrad <84141230+manuelkonrad@users.noreply.github.com> Date: Thu, 17 Oct 2024 21:17:29 +0200 Subject: [PATCH] added first draft of dataframe datamodule Signed-off-by: Manuel Konrad <84141230+manuelkonrad@users.noreply.github.com> --- CHANGELOG.md | 2 + configs/data/dataframe.yaml | 77 ++++ .../guides/reference/data/image/dataframe.md | 7 + .../guides/reference/data/image/index.md | 7 + src/anomalib/data/__init__.py | 3 +- src/anomalib/data/image/__init__.py | 4 +- src/anomalib/data/image/dataframe.py | 361 ++++++++++++++++++ tests/conftest.py | 6 +- tests/unit/data/image/test_dataframe.py | 104 +++++ 9 files changed, 566 insertions(+), 5 deletions(-) create mode 100644 configs/data/dataframe.yaml create mode 100644 docs/source/markdown/guides/reference/data/image/dataframe.md create mode 100644 src/anomalib/data/image/dataframe.py create mode 100644 tests/unit/data/image/test_dataframe.py diff --git a/CHANGELOG.md b/CHANGELOG.md index dedec2f441..5ec4df8297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- 🚀 Add Dataframe datamodule by @manuelkonrad in https://github.com/openvinotoolkit/anomalib/pull/2403 + ### Changed ### Deprecated diff --git a/configs/data/dataframe.yaml b/configs/data/dataframe.yaml new file mode 100644 index 0000000000..b5d02212af --- /dev/null +++ b/configs/data/dataframe.yaml @@ -0,0 +1,77 @@ +class_path: anomalib.data.Dataframe +init_args: + name: bottle + root: "datasets/MVTec/bottle" + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + task: segmentation + transform: null + train_transform: null + eval_transform: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test + val_split_ratio: 0.5 + seed: null + samples: + - image_path: train/good/000.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/001.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/002.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/003.png + label_index: 0 + mask_path: "" + split: train + - image_path: train/good/004.png + label_index: 0 + mask_path: "" + split: train + - image_path: test/bad/000.png + label_index: 1 + mask_path: ground_truth/bad/000_mask.png + split: test + - image_path: test/bad/002.png + label_index: 1 + mask_path: ground_truth/bad/002_mask.png + split: test + - image_path: test/bad/004.png + label_index: 1 + mask_path: ground_truth/bad/004_mask.png + split: test + - image_path: test/good/000.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/good/001.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/good/003.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/bad/001.png + label_index: 1 + mask_path: ground_truth/bad/001_mask.png + split: test + - image_path: test/bad/003.png + label_index: 1 + mask_path: ground_truth/bad/003_mask.png + split: test + - image_path: test/good/002.png + label_index: 0 + mask_path: "" + split: test + - image_path: test/good/004.png + label_index: 0 + mask_path: "" + split: test diff --git a/docs/source/markdown/guides/reference/data/image/dataframe.md b/docs/source/markdown/guides/reference/data/image/dataframe.md new file mode 100644 index 0000000000..e263ef1f5d --- /dev/null +++ b/docs/source/markdown/guides/reference/data/image/dataframe.md @@ -0,0 +1,7 @@ +# 
Dataframe Data
+
+```{eval-rst}
+.. automodule:: anomalib.data.image.dataframe
+   :members:
+   :show-inheritance:
+```
diff --git a/docs/source/markdown/guides/reference/data/image/index.md b/docs/source/markdown/guides/reference/data/image/index.md
index 2525d0d914..f97dbe1d1e 100644
--- a/docs/source/markdown/guides/reference/data/image/index.md
+++ b/docs/source/markdown/guides/reference/data/image/index.md
@@ -16,6 +16,13 @@ Learn more about BTech dataset.
 Learn more about custom folder dataset.
 :::
 
+:::{grid-item-card} Dataframe
+:link: ./dataframe
+:link-type: doc
+
+Learn more about custom dataframe dataset.
+:::
+
 :::{grid-item-card} Kolektor
 :link: ./kolektor
 :link-type: doc
diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py
index 0ad469ac69..f9d42825cd 100644
--- a/src/anomalib/data/__init__.py
+++ b/src/anomalib/data/__init__.py
@@ -14,7 +14,7 @@
 
 from .base import AnomalibDataModule, AnomalibDataset
 from .depth import DepthDataFormat, Folder3D, MVTec3D
-from .image import BTech, Datumaro, Folder, ImageDataFormat, Kolektor, MVTec, Visa
+from .image import BTech, Dataframe, Datumaro, Folder, ImageDataFormat, Kolektor, MVTec, Visa
 from .predict import PredictDataset
 from .utils import LabelName
 from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
@@ -70,6 +70,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule
     "VideoDataFormat",
     "get_datamodule",
     "BTech",
+    "Dataframe",
     "Datumaro",
     "Folder",
     "Folder3D",
diff --git a/src/anomalib/data/image/__init__.py b/src/anomalib/data/image/__init__.py
index 147db09418..17b052e198 100644
--- a/src/anomalib/data/image/__init__.py
+++ b/src/anomalib/data/image/__init__.py
@@ -9,6 +9,7 @@
 from enum import Enum
 
 from .btech import BTech
+from .dataframe import Dataframe
 from .datumaro import Datumaro
 from .folder import Folder
 from .kolektor import Kolektor
@@ -20,6 +21,7 @@ class ImageDataFormat(str, Enum):
     """Supported Image Dataset Types."""
 
     BTECH = "btech"
+    DATAFRAME = "dataframe"
     DATUMARO = "datumaro"
     FOLDER = "folder"
     FOLDER_3D = "folder_3d"
@@ -29,4 +31,4 @@ class ImageDataFormat(str, Enum):
     VISA = "visa"
 
 
-__all__ = ["BTech", "Datumaro", "Folder", "Kolektor", "MVTec", "Visa"]
+__all__ = ["BTech", "Dataframe", "Datumaro", "Folder", "Kolektor", "MVTec", "Visa"]
diff --git a/src/anomalib/data/image/dataframe.py b/src/anomalib/data/image/dataframe.py
new file mode 100644
index 0000000000..8f4b3c6a25
--- /dev/null
+++ b/src/anomalib/data/image/dataframe.py
@@ -0,0 +1,361 @@
+"""Custom Dataframe Dataset.
+
+This script creates a custom dataset from a pandas DataFrame.
+"""
+
+# Copyright (C) 2022-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from pathlib import Path
+from typing import IO
+
+import pandas as pd
+from torchvision.transforms.v2 import Transform
+
+from anomalib import TaskType
+from anomalib.data.base import AnomalibDataModule, AnomalibDataset
+from anomalib.data.errors import MisMatchError
+from anomalib.data.utils import (
+    DirType,
+    LabelName,
+    Split,
+    TestSplitMode,
+    ValSplitMode,
+)
+
+
+def make_dataframe_dataset(
+    samples: dict | list | pd.DataFrame,
+    root: str | Path | None = None,
+    split: str | Split | None = None,
+) -> pd.DataFrame:
+    """Make Dataframe Dataset.
+
+    Args:
+        samples (dict | list | pd.DataFrame): Pandas DataFrame, or a compatible list or dict, containing the
+            dataset information.
+        root (str | Path | None): Path to the root directory of the dataset.
+            Defaults to ``None``.
+        split (str | Split | None, optional): Dataset split (i.e., Split.FULL, Split.TRAIN or Split.TEST).
+            Defaults to ``None``.
+
+    Returns:
+        pd.DataFrame: An output dataframe containing samples for the requested split (i.e., train or test).
+
+    Examples:
+        Assume that we would like to use this ``make_dataframe_dataset`` to create a dataset from a
+        pandas DataFrame ``input_df``. We could then create the dataset as follows,
+
+        .. code-block:: python
+
+            samples_df = make_dataframe_dataset(
+                samples=input_df,
+                split="train",
+            )
+            samples_df.head()
+
+        .. code-block:: bash
+
+                      image_path           label  label_index mask_path        split
+            0  ./toy/good/00.jpg  DirType.NORMAL            0            Split.TRAIN
+            1  ./toy/good/01.jpg  DirType.NORMAL            0            Split.TRAIN
+            2  ./toy/good/02.jpg  DirType.NORMAL            0            Split.TRAIN
+            3  ./toy/good/03.jpg  DirType.NORMAL            0            Split.TRAIN
+            4  ./toy/good/04.jpg  DirType.NORMAL            0            Split.TRAIN
+    """
+    # Convert to a pandas DataFrame if a dictionary or list is given
+    if isinstance(samples, dict | list):
+        samples = pd.DataFrame(samples)
+
+    samples = samples.sort_values(by="image_path", ignore_index=True)
+
+    # Create the label column for folder datamodule compatibility
+    samples.label_index = samples.label_index.astype("Int64")
+    if "label" not in samples.columns:
+        samples.loc[
+            (samples.label_index == LabelName.NORMAL) & (samples.split == Split.TRAIN),
+            "label",
+        ] = DirType.NORMAL
+        samples.loc[
+            (samples.label_index == LabelName.NORMAL) & (samples.split == Split.TEST),
+            "label",
+        ] = DirType.NORMAL_TEST
+        samples.loc[
+            (samples.label_index == LabelName.ABNORMAL),
+            "label",
+        ] = DirType.ABNORMAL
+
+    # Check if anomalous samples are in the training set
+    if len(samples[(samples.label_index == LabelName.ABNORMAL) & (samples.split == Split.TRAIN)]) != 0:
+        msg = "Training set must not contain anomalous samples."
+        raise MisMatchError(msg)
+
+    # Add the mask_path column if it does not exist
+    if "mask_path" not in samples.columns:
+        samples["mask_path"] = ""
+    samples.loc[samples["mask_path"].isna(), "mask_path"] = ""
+
+    # Prepend the root to the paths
+    if root:
+        samples["image_path"] = samples["image_path"].map(lambda x: Path(root, x))
+        samples.loc[
+            samples["mask_path"] != "",
+            "mask_path",
+        ] = samples.loc[samples["mask_path"] != "", "mask_path"].map(lambda x: Path(root, x))
+    samples = samples.astype({"image_path": "str", "mask_path": "str", "label": "str"})
+
+    # Get the dataframe for the split.
+    if split:
+        samples = samples[samples.split == split]
+        samples = samples.reset_index(drop=True)
+
+    return samples
+
+
+class DataframeDataset(AnomalibDataset):
+    """Dataframe dataset.
+
+    This class is used to create a dataset from a pandas DataFrame. The class utilizes the Torch Dataset class.
+
+    Args:
+        name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
+        task (TaskType): Task type (``classification``, ``detection`` or ``segmentation``).
+        samples (dict | list | pd.DataFrame): Pandas DataFrame, or a compatible list or dict, containing the
+            dataset information.
+        transform (Transform, optional): Transforms that should be applied to the input images.
+            Defaults to ``None``.
+        root (str | Path | None): Root folder of the dataset.
+            Defaults to ``None``.
+        split (str | Split | None): Fixed subset split taken from the ``split`` column of the samples.
+            Choose from [Split.FULL, Split.TRAIN, Split.TEST].
+            Defaults to ``None``.
+
+    Raises:
+        MisMatchError: When the training subset contains anomalous samples.
+
+    Examples:
+        Assume that we would like to use this ``DataframeDataset`` to create a dataset from a pandas DataFrame for
+        a classification task. We could first create the transforms,
+
+        >>> from torchvision.transforms.v2 import Resize
+        >>> transform = Resize((256, 256))
+
+        We could then create the dataset as follows,
+
+        .. code-block:: python
+
+            dataframe_dataset_classification_train = DataframeDataset(
+                name="my_dataset",
+                samples=input_df,
+                split="train",
+                transform=transform,
+                task=TaskType.CLASSIFICATION,
+            )
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        task: TaskType,
+        samples: dict | list | pd.DataFrame,
+        transform: Transform | None = None,
+        root: str | Path | None = None,
+        split: str | Split | None = None,
+    ) -> None:
+        super().__init__(task, transform)
+
+        self._name = name
+        self.root = root
+        self.split = split
+        self.samples = make_dataframe_dataset(
+            samples=samples,
+            root=self.root,
+            split=self.split,
+        )
+
+    @property
+    def name(self) -> str:
+        """Name of the dataset.
+
+        Dataframe dataset overrides the name property to provide a custom name.
+        """
+        return self._name
+
+
+class Dataframe(AnomalibDataModule):
+    """Dataframe DataModule.
+
+    Args:
+        name (str): Name of the dataset. This is used to name the datamodule, especially when logging/saving.
+        samples (dict | list | pd.DataFrame): Pandas DataFrame, or a compatible list or dict, containing the
+            dataset information.
+        root (str | Path | None): Path to the root folder containing the dataset images. Prepended to the
+            relative paths in ``samples``.
+            Defaults to ``None``.
+        normal_split_ratio (float, optional): Ratio to split normal training images and add to the
+            test set in case the test set doesn't contain any normal images.
+            Defaults to ``0.2``.
+        train_batch_size (int, optional): Training batch size.
+            Defaults to ``32``.
+        eval_batch_size (int, optional): Validation, test and predict batch size.
+            Defaults to ``32``.
+        num_workers (int, optional): Number of workers.
+            Defaults to ``8``.
+        task (TaskType, optional): Task type. Could be ``classification``, ``detection`` or ``segmentation``.
+            Defaults to ``segmentation``.
+        image_size (tuple[int, int], optional): Size to which input images should be resized.
+            Defaults to ``None``.
+        transform (Transform, optional): Transforms that should be applied to the input images.
+            Defaults to ``None``.
+        train_transform (Transform, optional): Transforms that should be applied to the input images during training.
+            Defaults to ``None``.
+        eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
+            Defaults to ``None``.
+        test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
+            Defaults to ``TestSplitMode.FROM_DIR``.
+        test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
+            Defaults to ``0.2``.
+        val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
+            Defaults to ``ValSplitMode.FROM_TEST``.
+        val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
+            Defaults to ``0.5``.
+        seed (int | None, optional): Seed used during random subset splitting.
+            Defaults to ``None``.
+
+    Examples:
+        The following code demonstrates how to use the ``Dataframe`` datamodule. Assume that the pandas
+        DataFrame ``input_df`` is structured as follows:
+
+        .. code-block:: bash
+
+                      image_path  label_index mask_path        split
+            0  ./toy/good/00.jpg            0            Split.TRAIN
+            1  ./toy/good/01.jpg            0            Split.TRAIN
+            2  ./toy/good/02.jpg            0            Split.TRAIN
+            3  ./toy/good/03.jpg            0            Split.TRAIN
+            4  ./toy/good/04.jpg            0            Split.TRAIN
+
+        .. code-block:: python
+
+            dataframe_datamodule = Dataframe(
+                "my_dataset",
+                samples=input_df,
+                root=dataset_root,
+                task=TaskType.SEGMENTATION,
+                image_size=(256, 256),
+            )
+            dataframe_datamodule.setup()
+
+        To access the training images,
+
+        .. code-block:: python
+
+            >>> i, data = next(enumerate(dataframe_datamodule.train_dataloader()))
+            >>> print(data.keys(), data["image"].shape)
+
+        To access the test images,
+
+        .. code-block:: python
+
+            >>> i, data = next(enumerate(dataframe_datamodule.test_dataloader()))
+            >>> print(data.keys(), data["image"].shape)
+    """
+
+    def __init__(
+        self,
+        name: str,
+        samples: dict | list | pd.DataFrame,
+        root: str | Path | None = None,
+        normal_split_ratio: float = 0.2,
+        train_batch_size: int = 32,
+        eval_batch_size: int = 32,
+        num_workers: int = 8,
+        task: TaskType | str = TaskType.SEGMENTATION,
+        image_size: tuple[int, int] | None = None,
+        transform: Transform | None = None,
+        train_transform: Transform | None = None,
+        eval_transform: Transform | None = None,
+        test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
+        test_split_ratio: float = 0.2,
+        val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
+        val_split_ratio: float = 0.5,
+        seed: int | None = None,
+    ) -> None:
+        self._name = name
+        self.root = root
+        self._unprocessed_samples = samples
+        self.task = TaskType(task)
+        test_split_mode = TestSplitMode(test_split_mode)
+        val_split_mode = ValSplitMode(val_split_mode)
+        super().__init__(
+            train_batch_size=train_batch_size,
+            eval_batch_size=eval_batch_size,
+            num_workers=num_workers,
+            test_split_mode=test_split_mode,
+            test_split_ratio=test_split_ratio,
+            val_split_mode=val_split_mode,
+            val_split_ratio=val_split_ratio,
+            image_size=image_size,
+            transform=transform,
+            train_transform=train_transform,
+            eval_transform=eval_transform,
+            seed=seed,
+        )
+
+        self.normal_split_ratio = normal_split_ratio
+
+    def _setup(self, _stage: str | None = None) -> None:
+        self.train_data = DataframeDataset(
+            name=self.name,
+            task=self.task,
+            samples=self._unprocessed_samples,
+            transform=self.train_transform,
+            split=Split.TRAIN,
+            root=self.root,
+        )
+
+        self.test_data = DataframeDataset(
+            name=self.name,
+            task=self.task,
+            samples=self._unprocessed_samples,
+            transform=self.eval_transform,
+            split=Split.TEST,
+            root=self.root,
+        )
+
+    @property
+    def name(self) -> str:
+        """Name of the datamodule.
+
+        Dataframe datamodule overrides the name property to provide a custom name.
+        """
+        return self._name
+
+    @classmethod
+    def from_file(
+        cls: type["Dataframe"],
+        name: str,
+        file_path: str | Path | IO[str] | IO[bytes],
+        file_format: str = "csv",
+        pd_kwargs: dict | None = None,
+        **kwargs,
+    ) -> "Dataframe":
+        """Make a Dataframe Datamodule from a tabular file.
+
+        Args:
+            name (str): Name of the dataset. This is used to name the datamodule,
+                especially when logging/saving.
+            file_path (str | Path | file-like): Path or file-like object to the tabular
+                file containing the dataset information.
+            file_format (str): File format supported by a ``pd.read_*`` method, such
+                as ``csv``, ``parquet`` or ``json``.
+                Defaults to ``csv``.
+            pd_kwargs (dict | None): Keyword argument dictionary for the ``pd.read_*`` method.
+                Defaults to ``None``.
+            kwargs (dict): Additional keyword arguments for the Dataframe Datamodule class.
+
+        Returns:
+            Dataframe: Dataframe Datamodule
+        """
+        pd_kwargs = pd_kwargs or {}
+        samples = getattr(pd, f"read_{file_format}")(file_path, **pd_kwargs)
+        return cls(name, samples, **kwargs)
diff --git a/tests/conftest.py b/tests/conftest.py
index a9db6c1d3d..ddc03382ff 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -56,9 +56,9 @@ def dataset_path(project_path: Path) -> Path:
 
     # 1. Create the dummy image datasets.
     for data_format in list(ImageDataFormat):
-        # Do not generate a dummy dataset for folder datasets.
-        # We could use one of these datasets to test the folders datasets.
-        if not data_format.value.startswith("folder"):
+        # Do not generate a dummy dataset for folder or dataframe datasets.
+        # We could use one of these datasets to test the folder and dataframe datasets.
+        if not data_format.value.startswith(("folder", "dataframe")):
             dataset_generator = DummyImageDatasetGenerator(data_format=data_format, root=_dataset_path)
             dataset_generator.generate_dataset()
 
diff --git a/tests/unit/data/image/test_dataframe.py b/tests/unit/data/image/test_dataframe.py
new file mode 100644
index 0000000000..cd7e3f8c69
--- /dev/null
+++ b/tests/unit/data/image/test_dataframe.py
@@ -0,0 +1,104 @@
+"""Unit Tests - Dataframe Datamodule."""
+
+# Copyright (C) 2023-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import tempfile
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from anomalib import TaskType
+from anomalib.data import Dataframe, Folder
+from tests.unit.data.base.image import _TestAnomalibImageDatamodule
+
+
+class TestDataframe(_TestAnomalibImageDatamodule):
+    """Dataframe Datamodule Unit Tests.
+
+    All of the Dataframe datamodule tests are placed in the ``TestDataframe`` class.
+    """
+
+    @staticmethod
+    def get_samples_dataframe(dataset_path: Path, task_type: TaskType) -> pd.DataFrame:
+        """Create the samples pd.DataFrame."""
+        # Make sure to use a mask directory for segmentation. The Dataframe datamodule
+        # expects mask paths to be relative to the root.
+        mask_dir = None if task_type == TaskType.CLASSIFICATION else "ground_truth/bad"
+
+        # Create a folder datamodule to get the samples dataframe
+        _folder_datamodule = Folder(
+            name="dummy",
+            root=dataset_path / "mvtec" / "dummy",
+            normal_dir="train/good",
+            abnormal_dir="test/bad",
+            normal_test_dir="test/good",
+            mask_dir=mask_dir,
+            train_batch_size=4,
+            eval_batch_size=4,
+            num_workers=0,
+            task=task_type,
+        )
+        _folder_datamodule.setup()
+        _samples = pd.concat([
+            _folder_datamodule.train_data.samples,
+            _folder_datamodule.test_data.samples,
+            _folder_datamodule.val_data.samples,
+        ])
+
+        # Drop the label column as it is inferred from the other columns
+        return _samples.drop(["label"], axis="columns")
+
+    @pytest.fixture()
+    @staticmethod
+    def datamodule(dataset_path: Path, task_type: TaskType) -> Dataframe:
+        """Create and return a Dataframe datamodule."""
+        # Create and prepare the dataset
+        _samples = TestDataframe.get_samples_dataframe(dataset_path, task_type)
+        _datamodule = Dataframe(
+            name="dummy",
+            samples=_samples,
+            train_batch_size=4,
+            eval_batch_size=4,
+            num_workers=0,
+            task=task_type,
+        )
+        _datamodule.setup()
+
+        return _datamodule
+
+    @pytest.fixture()
+    @staticmethod
+    def fxt_data_config_path() -> str:
+        """Return the path to the test data config."""
+        return "configs/data/dataframe.yaml"
+
+
+class TestDataframeFromFile(TestDataframe):
+    """Dataframe Datamodule Unit Tests for the alternative constructor.
+
+    Tests for the Datamodule creation from a file.
+    """
+
+    @pytest.fixture()
+    @staticmethod
+    def datamodule(dataset_path: Path, task_type: TaskType) -> Dataframe:
+        """Create and return a Dataframe datamodule."""
+        # Create and prepare the dataset
+        _samples = TestDataframeFromFile.get_samples_dataframe(dataset_path, task_type)
+        with tempfile.TemporaryFile() as samples_file:
+            # Write without the index column so the round trip preserves the schema
+            _samples.to_csv(samples_file, index=False)
+            samples_file.seek(0)
+
+            _datamodule = Dataframe.from_file(
+                name="dummy",
+                file_path=samples_file,
+                train_batch_size=4,
+                eval_batch_size=4,
+                num_workers=0,
+                task=task_type,
+            )
+            _datamodule.setup()
+
+        return _datamodule
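
For reference, a minimal usage sketch of the datamodule added by this patch. The dataset name, root, and image/mask paths below are illustrative placeholders (they mirror the bundled `configs/data/dataframe.yaml`), and the sketch assumes an environment with this patch applied:

```python
# Sketch: building a Dataframe datamodule directly from an in-memory samples table.
from anomalib import TaskType
from anomalib.data import Dataframe

# Paths are relative to ``root``; ``split`` values match the Split enum values.
samples = [
    {"image_path": "train/good/000.png", "label_index": 0, "mask_path": "", "split": "train"},
    {"image_path": "test/bad/000.png", "label_index": 1, "mask_path": "ground_truth/bad/000_mask.png", "split": "test"},
    {"image_path": "test/good/000.png", "label_index": 0, "mask_path": "", "split": "test"},
]

datamodule = Dataframe(
    name="bottle",
    samples=samples,              # dict, list, or pd.DataFrame
    root="datasets/MVTec/bottle", # prepended to the relative paths above
    task=TaskType.SEGMENTATION,
)
datamodule.setup()
```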
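The patch also adds a `from_file` constructor that resolves `file_format` to the matching `pd.read_*` reader. A sketch under the same assumptions (`samples.csv` is a placeholder file whose columns follow the samples schema above):

```python
# Sketch: building the same datamodule from a tabular file.
from anomalib.data import Dataframe

datamodule = Dataframe.from_file(
    name="bottle",
    file_path="samples.csv",       # any source accepted by the pd.read_* reader
    file_format="csv",             # resolved to pandas' read_csv
    pd_kwargs={"sep": ","},        # forwarded to the pd.read_* call
    root="datasets/MVTec/bottle",  # remaining kwargs go to the Dataframe constructor
)
datamodule.setup()
```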