diff --git a/examples/configs/data/realiad.yaml b/examples/configs/data/realiad.yaml new file mode 100644 index 0000000000..adc7e94d39 --- /dev/null +++ b/examples/configs/data/realiad.yaml @@ -0,0 +1,14 @@ +class_path: anomalib.data.RealIAD +init_args: + root: ./datasets/Real-IAD + category: audiojack + resolution: 256 + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + test_split_mode: none + val_split_mode: same_as_test + train_augmentations: null + val_augmentations: null + test_augmentations: null + augmentations: null diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index 42d8e15902..447cf0b30e 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -49,7 +49,19 @@ # Datamodules from .datamodules.base import AnomalibDataModule from .datamodules.depth import DepthDataFormat, Folder3D, MVTec3D -from .datamodules.image import VAD, BTech, Datumaro, Folder, ImageDataFormat, Kolektor, MVTec, MVTecAD, MVTecLOCO, Visa +from .datamodules.image import ( + VAD, + BTech, + Datumaro, + Folder, + ImageDataFormat, + Kolektor, + MVTec, + MVTecAD, + MVTecLOCO, + RealIAD, + Visa, +) from .datamodules.video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat # Datasets @@ -146,40 +158,46 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule "NumpyVideoItem", "VideoBatch", "VideoItem", - # Depth + # Data Formats + "DataFormat", "DepthDataFormat", + "ImageDataFormat", + "VideoDataFormat", + # Depth Data Modules "Folder3D", - "Folder3DDataset", "MVTec3D", - "MVTec3DDataset", - # Image + # Image Data Modules "BTech", - "BTechDataset", "Datumaro", - "DatumaroDataset", "Folder", - "FolderDataset", - "ImageDataFormat", "Kolektor", - "KolektorDataset", - "MVTec", # Include MVTec for backward compatibility + "MVTec", "MVTecAD", - "MVTecADDataset", "MVTecLOCO", - "MVTecLOCODataset", + "RealIAD", "VAD", - "VADDataset", "Visa", - "VisaDataset", - # Video + # Video Data Modules "Avenue", - 
"AvenueDataset", "ShanghaiTech", - "ShanghaiTechDataset", "UCSDped", + # Datasets + "Folder3DDataset", + "MVTec3DDataset", + "BTechDataset", + "DatumaroDataset", + "FolderDataset", + "KolektorDataset", + "MVTecADDataset", + "MVTecLOCODataset", + "VADDataset", + "VisaDataset", + "AvenueDataset", + "ShanghaiTechDataset", "UCSDpedDataset", - "VideoDataFormat", - # Predict "PredictDataset", + # Functions "get_datamodule", + # Exceptions + "UnknownDatamoduleError", ] diff --git a/src/anomalib/data/datamodules/image/__init__.py b/src/anomalib/data/datamodules/image/__init__.py index fb0d6a07e1..c37321cdc6 100644 --- a/src/anomalib/data/datamodules/image/__init__.py +++ b/src/anomalib/data/datamodules/image/__init__.py @@ -33,6 +33,7 @@ from .kolektor import Kolektor from .mvtec_loco import MVTecLOCO from .mvtecad import MVTec, MVTecAD +from .realiad import RealIAD from .vad import VAD from .visa import Visa @@ -50,6 +51,7 @@ class ImageDataFormat(str, Enum): - ``MVTEC_AD``: MVTec AD Dataset - ``MVTEC_3D``: MVTec 3D AD Dataset - ``MVTEC_LOCO``: MVTec LOCO Dataset + - ``REALIAD``: Real-IAD Dataset - ``VAD``: Valeo Anomaly Detection Dataset - ``VISA``: Visual Anomaly Dataset """ @@ -62,6 +64,7 @@ class ImageDataFormat(str, Enum): MVTEC_AD = "mvtecad" MVTEC_3D = "mvtec_3d" MVTEC_LOCO = "mvtec_loco" + REAL_IAD = "realiad" VAD = "vad" VISA = "visa" @@ -71,9 +74,10 @@ class ImageDataFormat(str, Enum): "Datumaro", "Folder", "Kolektor", - "MVTecAD", "MVTec", # Include both for backward compatibility + "MVTecAD", "MVTecLOCO", + "RealIAD", "VAD", "Visa", ] diff --git a/src/anomalib/data/datamodules/image/realiad.py b/src/anomalib/data/datamodules/image/realiad.py new file mode 100644 index 0000000000..d3c1d8e0a5 --- /dev/null +++ b/src/anomalib/data/datamodules/image/realiad.py @@ -0,0 +1,312 @@ +"""Real-IAD Data Module. + +This module provides a PyTorch Lightning DataModule for the Real-IAD dataset. 
class RealIAD(AnomalibDataModule):
    """Real-IAD Datamodule.

    Builds the train and test ``RealIADDataset`` objects for a single category
    at a single image resolution. Which images belong to which split is read
    from a JSON metadata file rather than inferred from the directory layout.

    Args:
        root (Path | str): Path to root directory containing the dataset.
            Defaults to ``"./datasets/Real-IAD"``.
        category (str): Category of the Real-IAD dataset (e.g. ``"audiojack"`` or
            ``"button_battery"``). Must be one of ``CATEGORIES``.
            Defaults to ``"audiojack"``.
        resolution (str | int): Image resolution to use (``"256"``, ``"512"``,
            ``"1024"``, ``"raw"`` or the integer equivalents). Integers are
            converted to strings, so both ``"256"`` and ``256`` are valid.
            Defaults to ``256``.
        json_path (str | Path): Path to JSON metadata file, relative to ``root``.
            May contain a ``{category}`` placeholder. Common paths are:

            - ``"realiad_jsons/realiad_jsons/{category}.json"`` - Base metadata (multi-view)
            - ``"realiad_jsons/realiad_jsons_sv/{category}.json"`` - Single-view metadata
            - ``"realiad_jsons/realiad_jsons_fuiad_0.4/{category}.json"`` - FUIAD v0.4 metadata
        train_batch_size (int, optional): Training batch size.
            Defaults to ``32``.
        eval_batch_size (int, optional): Test batch size.
            Defaults to ``32``.
        num_workers (int, optional): Number of workers.
            Defaults to ``8``.
        train_augmentations (Transform | None): Augmentations to apply to the training images.
            Defaults to ``None``.
        val_augmentations (Transform | None): Augmentations to apply to the validation images.
            Defaults to ``None``.
        test_augmentations (Transform | None): Augmentations to apply to the test images.
            Defaults to ``None``.
        augmentations (Transform | None): General augmentations to apply if stage-specific
            augmentations are not provided. Defaults to ``None``.
        test_split_mode (TestSplitMode | str): Method to create test set.
            Defaults to ``TestSplitMode.NONE``.
        val_split_mode (ValSplitMode | str): Method to create validation set.
            Defaults to ``ValSplitMode.SAME_AS_TEST``.
        seed (int | None, optional): Seed for reproducibility.
            Defaults to ``None``.

    Raises:
        ValueError: If ``category`` is not in ``CATEGORIES`` or ``resolution``
            (after string conversion) is not in ``RESOLUTIONS``.

    Example:
        Create Real-IAD datamodule with default settings::

            >>> datamodule = RealIAD()
            >>> datamodule.setup()
            >>> i, data = next(enumerate(datamodule.train_dataloader()))
            >>> data["image"].shape
            torch.Size([32, 3, 256, 256])

        Change the category and resolution::

            >>> # Using string resolution
            >>> datamodule = RealIAD(
            ...     category="button_battery",
            ...     resolution="512"
            ... )

            >>> # Using integer resolution
            >>> datamodule = RealIAD(
            ...     category="button_battery",
            ...     resolution=1024
            ... )

        Use different JSON metadata files::

            >>> # Base metadata (multi-view)
            >>> datamodule = RealIAD(
            ...     json_path="realiad_jsons/realiad_jsons/{category}.json"
            ... )

            >>> # Single-view metadata
            >>> datamodule = RealIAD(
            ...     json_path="realiad_jsons/realiad_jsons_sv/{category}.json"
            ... )

            >>> # FUIAD v0.4 metadata (filtered subset)
            >>> datamodule = RealIAD(
            ...     json_path="realiad_jsons/realiad_jsons_fuiad_0.4/{category}.json"
            ... )

            >>> # Custom metadata
            >>> datamodule = RealIAD(
            ...     json_path="path/to/custom/metadata.json"
            ... )

        Select how the validation set is obtained (split ratios are not part of
        this constructor's signature)::

            >>> datamodule = RealIAD(
            ...     val_split_mode=ValSplitMode.SAME_AS_TEST
            ... )

    Notes:
        - The dataset contains both normal (OK) and anomalous (NG) samples
        - Each object is captured from 5 different camera viewpoints (C1-C5)
        - Images are available in multiple resolutions (256x256, 512x512, 1024x1024)
        - JSON metadata files provide additional information and different dataset splits
        - Segmentation masks are provided for anomalous samples
    """

    def __init__(
        self,
        root: Path | str = "./datasets/Real-IAD",
        category: str = "audiojack",
        resolution: str | int = 256,
        json_path: str | Path = "realiad_jsons/realiad_jsons/{category}.json",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        train_augmentations: Transform | None = None,
        val_augmentations: Transform | None = None,
        test_augmentations: Transform | None = None,
        augmentations: Transform | None = None,
        test_split_mode: TestSplitMode | str = TestSplitMode.NONE,
        val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
        seed: int | None = None,
    ) -> None:
        super().__init__(
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=num_workers,
            train_augmentations=train_augmentations,
            val_augmentations=val_augmentations,
            test_augmentations=test_augmentations,
            augmentations=augmentations,
            test_split_mode=test_split_mode,
            val_split_mode=val_split_mode,
            seed=seed,
        )

        # ``resolution`` is compared against, and later interpolated into,
        # string directory names, so integers are normalised to ``str`` first.
        if isinstance(resolution, int):
            resolution = str(resolution)

        # Validate before storing anything so a bad argument never leaves a
        # half-configured datamodule behind.
        if category not in CATEGORIES:
            msg = f"Category {category} not found in Real-IAD dataset. Available categories: {CATEGORIES}"
            raise ValueError(msg)

        if resolution not in RESOLUTIONS:
            msg = f"Resolution {resolution} not found in Real-IAD dataset. Available resolutions: {RESOLUTIONS}"
            raise ValueError(msg)

        self.root = Path(root)
        self.category = category
        self.resolution = resolution
        self.json_path = json_path

    def prepare_data(self) -> None:
        """Verify that the images for the selected resolution are available.

        The dataset cannot be downloaded automatically because access must be
        approved by the authors. This method therefore only checks that the
        directory ``<root>/realiad_<resolution>`` exists and, if it does not,
        fails with step-by-step download instructions.

        The Real-IAD dataset is available at:
        https://huggingface.co/datasets/REAL-IAD/Real-IAD

        Raises:
            RuntimeError: If the image directory for the configured resolution
                is missing. The message contains the download instructions.
        """
        # Check the directory that ``_setup`` will actually read from, rather
        # than accepting any one of several unrelated directories.
        image_dir = self.root / f"realiad_{self.resolution}"
        if not image_dir.is_dir():
            raise RuntimeError(get_download_instructions(self.root))

    def _setup(self, _stage: str | None = None) -> None:
        """Create the train and test datasets from the configured metadata file.

        Args:
            _stage: Unused; accepted for interface compatibility.
        """
        # Both datasets share every argument except the split.
        shared_kwargs = {
            "root": self.root,
            "category": self.category,
            "resolution": self.resolution,
            "json_path": self.json_path,
        }
        self.train_data = RealIADDataset(split=Split.TRAIN, **shared_kwargs)
        self.test_data = RealIADDataset(split=Split.TEST, **shared_kwargs)


def get_download_instructions(root_path: Path) -> str:
    """Get download instructions for the Real-IAD dataset.

    Args:
        root_path: Path where the dataset should be downloaded. It is
            interpolated into the message in several places.

    Returns:
        str: Formatted, dedented download instructions.
    """
    # The shell line continuations below are written as doubled backslashes.
    # In a non-raw string a single backslash followed by a newline is removed
    # by the parser, which would silently join the command onto one line.
    return dedent(f"""
    Real-IAD dataset not found in {root_path}

    The Real-IAD dataset requires approval from the authors. To get access:

    1. Create a Hugging Face account at https://huggingface.co
    2. Visit https://huggingface.co/datasets/REAL-IAD/Real-IAD
    3. Click "Access Repository" and fill out the form
    4. Wait for approval from the dataset authors
    5. Once approved, you have two options to download the dataset:

    Option 1: Using Hugging Face CLI (Recommended)
    --------------------------------------------
    a. Install the Hugging Face CLI:
        pip install huggingface_hub

    b. Login to Hugging Face:
        huggingface-cli login

    c. Download the dataset:
        huggingface-cli download \\
            --repo-type dataset \\
            --local-dir {root_path} REAL-IAD/Real-IAD \\
            --include="*" \\
            --token YOUR_HF_TOKEN

    Option 2: Manual Download
    -----------------------
    a. Visit https://huggingface.co/datasets/REAL-IAD/Real-IAD
    b. Download all files manually
    c. Extract the contents to: {root_path}

    Expected directory structure:
    {root_path}/
    ├── realiad_256/      # 256x256 resolution images
    ├── realiad_512/      # 512x512 resolution images
    ├── realiad_1024/     # 1024x1024 resolution images
    ├── realiad_raw/      # Original resolution images
    └── realiad_jsons/    # JSON metadata files
        ├── realiad_jsons/             # Base JSON metadata
        ├── realiad_jsons_sv/          # Single-view JSON metadata
        ├── realiad_jsons_fuiad_0.0/   # FUIAD v0.0 metadata
        ├── realiad_jsons_fuiad_0.1/   # FUIAD v0.1 metadata
        ├── realiad_jsons_fuiad_0.2/   # FUIAD v0.2 metadata
        └── realiad_jsons_fuiad_0.4/   # FUIAD v0.4 metadata

    Note: Replace YOUR_HF_TOKEN with your Hugging Face access token.
    To get your token, visit: https://huggingface.co/settings/tokens

    For more information about the dataset, see:
    - Paper: https://arxiv.org/abs/2401.02749
    - Code: https://github.com/REAL-IAD/REAL-IAD
    - Dataset: https://huggingface.co/datasets/REAL-IAD/Real-IAD
    """)
IMG_EXTENSIONS = (".jpg", ".png", ".PNG", ".JPG")
RESOLUTIONS = ("256", "512", "1024", "raw")
CATEGORIES = (
    "audiojack",
    "button_battery",
    "capacitor",
    "connector",
    "diode",
    "end_cap",
    "fuse",
    "ic",
    "inductor",
    "led",
    "pcb_finger",
    "plastic_nut",
    "potentiometer",
    "relay",
    "resistor",
    "rivet",
    "rubber_grommet",
    "screw",
    "spring",
    "switch",
    "terminal_block",
    "through_hole",
    "toggle_switch",
    "toy_brick",
    "transistor",
    "washer",
    "woodstick",
    "zipper",
    "toothbrush",
    "usb_adaptor",
)


class RealIADDataset(AnomalibDataset):
    """Real-IAD dataset class.

    Loads the sample list for one Real-IAD category from a JSON metadata file
    and resolves the image and mask paths against the directory of the chosen
    resolution (``<root>/realiad_<resolution>/<category>``).

    The dataset provides:
    - 30 industrial object categories
    - 5 camera viewpoints per object (C1-C5)
    - Multiple image resolutions (256x256, 512x512, 1024x1024)
    - Segmentation masks for anomalous samples
    - JSON metadata for flexible dataset organization

    Args:
        root (Path | str): Path to root directory containing the dataset.
            Defaults to ``"./datasets/Real-IAD"``.
        category (str): Category name, must be one of ``CATEGORIES``.
            Defaults to ``"audiojack"``.
        resolution (str | int): Image resolution, must be one of ``RESOLUTIONS`` or their integer equivalents.
            For example, both "256" and 256 are valid. Defaults to ``256``.
        augmentations (Transform, optional): Augmentations that should be applied to the input images.
            Defaults to ``None``.
        split (str | Split | None, optional): Dataset split - usually
            ``Split.TRAIN`` or ``Split.TEST``. ``None`` loads both splits.
            Defaults to ``None``.
        json_path (str | Path): Path to JSON metadata file, relative to root directory.
            A ``{category}`` placeholder is replaced with the category name for
            both ``str`` and ``Path`` values. Common paths are:

            - "realiad_jsons/realiad_jsons/{category}.json" - Base metadata (multi-view)
            - "realiad_jsons/realiad_jsons_sv/{category}.json" - Single-view metadata
            - "realiad_jsons/realiad_jsons_fuiad_0.4/{category}.json" - FUIAD v0.4 metadata

    Raises:
        ValueError: If ``category`` or ``resolution`` is not supported, or the
            JSON file is not a mapping containing a ``"train"`` or ``"test"`` key.
        FileNotFoundError: If the JSON metadata file does not exist.

    Example:
        >>> from pathlib import Path
        >>> from anomalib.data.datasets.image.realiad import RealIADDataset

        >>> # Using base JSON metadata (multi-view) with string resolution
        >>> dataset = RealIADDataset(
        ...     root=Path("./datasets/Real-IAD"),
        ...     category="audiojack",
        ...     resolution="1024",
        ...     split="train",
        ...     json_path="realiad_jsons/realiad_jsons/audiojack.json"
        ... )

        >>> # Using integer resolution
        >>> dataset = RealIADDataset(
        ...     category="button_battery",
        ...     resolution=512
        ... )

        >>> # Using single-view metadata
        >>> dataset = RealIADDataset(
        ...     json_path="realiad_jsons/realiad_jsons_sv/audiojack.json"
        ... )

        >>> # Using FUIAD v0.4 metadata (filtered subset)
        >>> dataset = RealIADDataset(
        ...     json_path="realiad_jsons/realiad_jsons_fuiad_0.4/audiojack.json"
        ... )

        >>> # Using custom JSON file
        >>> dataset = RealIADDataset(
        ...     json_path="path/to/custom/metadata.json"
        ... )

    Notes:
        - Normal samples are in the 'OK' directory, anomalous in 'NG'
        - Each sample has a unique ID (SXXXX) and camera view (CX)
        - Anomalous samples include defect type and segmentation masks
        - JSON metadata provides flexible dataset organization
        - The task (classification/segmentation) is determined by mask availability
    """

    def __init__(
        self,
        root: Path | str = "./datasets/Real-IAD",
        category: str = "audiojack",
        resolution: str | int = 256,
        augmentations: Transform | None = None,
        split: str | Split | None = None,
        json_path: str | Path = "realiad_jsons/realiad_jsons/{category}.json",
    ) -> None:
        """Initialize RealIAD dataset.

        Args:
            root: Path to root directory containing the dataset.
            category: Category name, must be one of ``CATEGORIES``.
            resolution: Image resolution, must be one of ``RESOLUTIONS`` or their integer equivalents.
                For example, both "256" and 256 are valid.
            augmentations: Augmentations that should be applied to the input images.
            split: Dataset split - usually ``Split.TRAIN`` or ``Split.TEST``.
            json_path: Path to JSON metadata file, relative to root directory.
                Can use {category} placeholder which will be replaced with the category name.
        """
        super().__init__(augmentations=augmentations)

        if category not in CATEGORIES:
            msg = f"Category {category} not found in Real-IAD dataset. Available categories: {CATEGORIES}"
            raise ValueError(msg)

        # Directory names use the string form, so integers are converted first.
        if isinstance(resolution, int):
            resolution = str(resolution)

        if resolution not in RESOLUTIONS:
            msg = f"Resolution {resolution} not found in Real-IAD dataset. Available resolutions: {RESOLUTIONS}"
            raise ValueError(msg)

        self.root = Path(root)
        self.category = category
        self.resolution = resolution
        self.split = split

        # Expand the placeholder on the string form so that ``Path`` arguments
        # are handled the same way as ``str`` arguments.
        json_file = self.root / str(json_path).format(category=category)

        if not json_file.exists():
            msg = f"JSON metadata file not found at {json_file}"
            raise FileNotFoundError(msg)

        with json_file.open(encoding="utf-8") as f:
            self.metadata = json.load(f)

        # At least one of the two split keys must be present in a JSON object.
        if not isinstance(self.metadata, dict) or not any(key in self.metadata for key in ["train", "test"]):
            msg = f"Invalid JSON structure in {json_file}. Must contain 'train' and/or 'test' keys."
            raise ValueError(msg)

        # Paths inside the metadata are resolved against this directory.
        self.root_category = self.root / f"realiad_{resolution}" / category

        self.samples = make_realiad_dataset(
            self.root_category,
            split=self.split,
            extensions=IMG_EXTENSIONS,
            metadata=self.metadata,
        )


def make_realiad_dataset(
    root: str | Path,
    split: str | Split | None = None,
    extensions: Sequence[str] | None = None,
    metadata: dict | None = None,
) -> DataFrame:
    """Create Real-IAD samples by parsing the JSON metadata.

    Args:
        root (Path | str): Directory against which the ``image_path`` and
            ``mask_path`` entries of the metadata are resolved.
        split (str | Split | None, optional): Dataset split. ``Split.TRAIN``
            selects the ``"train"`` entries; any other non-``None`` value
            selects the ``"test"`` entries; ``None`` selects both.
            Defaults to ``None``.
        extensions (Sequence[str] | None, optional): Valid file extensions.
            Accepted for interface compatibility; the sample list comes
            entirely from ``metadata`` and is not filtered by extension.
            Defaults to ``None``.
        metadata (dict | None, optional): JSON metadata containing dataset organization.
            Required despite the ``None`` default.

    Returns:
        DataFrame: Dataset samples with columns:
            - image_path: Path to image file
            - mask_path: Path to mask file, or ``""`` when the entry has none
            - label_index: ``LabelName.NORMAL`` when ``anomaly_class`` is
              ``"OK"``, otherwise ``LabelName.ABNORMAL``
            - split: ``"train"`` or ``"test"``, according to the metadata list
              the entry was read from

        ``attrs["task"]`` is ``"classification"`` when no sample has a mask and
        ``"segmentation"`` otherwise.

    Raises:
        ValueError: If ``metadata`` is ``None`` or the requested split is not
            present in it.
    """
    root = validate_path(root)

    if metadata is None:
        msg = "JSON metadata is required for RealIAD dataset"
        raise ValueError(msg)

    # Pair each group of entries with the split it came from. This avoids a
    # per-sample membership scan of the "train" list, and labels an entry
    # correctly even if an identical dict appears in both lists.
    if split is not None:
        split_key = "train" if split == Split.TRAIN else "test"
        if split_key not in metadata:
            msg = f"Split {split_key} not found in JSON metadata"
            raise ValueError(msg)
        grouped_entries = [(split_key, metadata[split_key])]
    else:
        grouped_entries = [(key, metadata.get(key, [])) for key in ("train", "test")]

    samples_list = []
    for split_name, entries in grouped_entries:
        for sample in entries:
            samples_list.append(
                {
                    "image_path": str(root / sample["image_path"]),
                    "mask_path": str(root / sample["mask_path"]) if sample.get("mask_path") else "",
                    "label_index": LabelName.NORMAL if sample["anomaly_class"] == "OK" else LabelName.ABNORMAL,
                    "split": split_name,
                },
            )

    # Explicit columns keep the frame well-formed when there are no samples;
    # otherwise the ``mask_path`` lookup below raises ``KeyError``.
    samples = DataFrame(samples_list, columns=["image_path", "mask_path", "label_index", "split"])

    samples.attrs["task"] = "classification" if (samples["mask_path"] == "").all() else "segmentation"

    return samples
DummyImageDatasetGenerator(data_format="realiad", root=_dataset_path) + dataset_generator.generate_dataset() + # 2. Create the dummy video datasets. for data_format in list(VideoDataFormat): dataset_generator = DummyVideoDatasetGenerator(data_format=data_format, root=_dataset_path) diff --git a/tests/helpers/data.py b/tests/helpers/data.py index eafcf723a1..d087ac1178 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -496,12 +496,74 @@ def _generate_dummy_kolektor_dataset(self) -> None: mask_filename = self.dataset_root / category / f"Part{i}_label.bmp" self.image_generator.generate_image(label, image_filename, mask_filename) - def _generate_dummy_visa_dataset(self) -> None: - """Generate dummy Visa dataset in directory using the same convention as Visa AD.""" - # Visa dataset on anomalib follows the same convention as MVTec AD. - # The only difference is that the root directory has a subdirectory called "visa_pytorch". - self.dataset_root = self.dataset_root.parent / "visa_pytorch" - self._generate_dummy_mvtecad_dataset(normal_dir="good", abnormal_dir="bad", image_extension=".jpg") + def _generate_dummy_realiad_dataset(self) -> None: + """Generate dummy RealIAD dataset in directory using the same convention as RealIAD.""" + import json + + # Create the resolution directory + resolution_dir = self.dataset_root / "realiad_256" + resolution_dir.mkdir(parents=True, exist_ok=True) + + # Create category directory + category = "audiojack" + category_dir = resolution_dir / category + category_dir.mkdir(parents=True, exist_ok=True) + + # Create jsons directory structure + jsons_dir = self.dataset_root / "realiad_jsons" / "realiad_jsons" + jsons_dir.mkdir(parents=True, exist_ok=True) + + # Generate images and create metadata + metadata = {"train": [], "test": []} + image_generator = DummyImageGenerator(image_shape=self.image_shape, rng=self.rng) + + # Generate normal train images + for i in range(self.num_train): + image, _ = 
image_generator.generate_normal_image() + filename = f"{category}_{i:04d}_OK_C0_0000.png" + image_path = category_dir / filename + image_generator.save_image(image_path, image) + + # Add to metadata - note: these are relative paths from category dir + metadata["train"].append({ + "image_path": filename, + "mask_path": "", + "anomaly_class": "OK", + "camera_view": "C0", + "timestamp": "0000", + }) + + # Generate abnormal test images with masks + for i in range(self.num_test): + image, mask = image_generator.generate_abnormal_image() + + # Save abnormal images + filename = f"{category}_{i:04d}_NG_C0_0000.png" + mask_filename = f"{category}_{i:04d}_NG_C0_0000_mask.png" + + image_path = category_dir / filename + mask_path = category_dir / mask_filename + + image_generator.save_image(image_path, image) + + # Convert mask to uint8 before saving + # Ensure mask is in range [0, 255] + mask = (mask * 255).astype(np.uint8) + image_generator.save_image(mask_path, mask) + + # Add to metadata - note: these are relative paths from category dir + metadata["test"].append({ + "image_path": filename, + "mask_path": mask_filename, + "anomaly_class": "NG", + "camera_view": "C0", + "timestamp": "0000", + }) + + # Save metadata JSON file + json_path = jsons_dir / f"{category}.json" + with json_path.open("w") as f: + json.dump(metadata, f, indent=2) def _generate_dummy_vad_dataset( self, @@ -531,6 +593,13 @@ def _generate_dummy_vad_dataset( image_filename = path / f"{i:03}{image_extension}" self.image_generator.generate_image(label, image_filename) + def _generate_dummy_visa_dataset(self) -> None: + """Generate dummy Visa dataset in directory using the same convention as Visa AD.""" + # Visa dataset on anomalib follows the same convention as MVTec AD. + # The only difference is that the root directory has a subdirectory called "visa_pytorch". 
class TestRealIAD(_TestAnomalibImageDatamodule):
    """RealIAD Datamodule Unit Tests."""

    @pytest.fixture()
    @staticmethod
    def datamodule(dataset_path: Path) -> RealIAD:
        """Build a RealIAD datamodule on the dummy dataset, prepared and set up."""
        # Small batches and no worker processes keep the fixture light.
        datamodule_kwargs = {
            "root": dataset_path / "realiad",
            "category": "audiojack",
            "resolution": 256,
            "train_batch_size": 4,
            "eval_batch_size": 4,
            "num_workers": 0,
            "augmentations": Resize((256, 256)),
        }
        realiad_datamodule = RealIAD(**datamodule_kwargs)
        realiad_datamodule.prepare_data()
        realiad_datamodule.setup()
        return realiad_datamodule

    @pytest.fixture()
    @staticmethod
    def fxt_data_config_path() -> str:
        """Return the path to the example RealIAD data config."""
        return "examples/configs/data/realiad.yaml"