"""Example showing how to use the MVTec AD 2 dataset with Anomalib.

This example demonstrates how to:
1. Load and visualize the MVTec AD 2 dataset
2. Create a datamodule and use it for training
3. Access different test sets (public, private, mixed)
4. Work with custom transforms and visualization
"""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision.transforms.v2 import Compose, Resize, ToDtype, ToImage

from anomalib.data import MVTecAD2
from anomalib.data.datasets.base.image import ImageItem
from anomalib.data.datasets.image.mvtecad2 import MVTecAD2Dataset, TestType
from anomalib.data.utils import Split

# Dataset location and category shared by every section of this example.
DATASET_ROOT = "./datasets/MVTec_AD_2"
CATEGORY = "sheet_metal"

# 1. Basic Usage
print("1. Basic Usage")
datamodule = MVTecAD2(
    root=DATASET_ROOT,
    category=CATEGORY,
    train_batch_size=32,
    eval_batch_size=32,
    num_workers=8,
)
datamodule.setup()  # Builds the train/val datasets and all three test datasets

# Report how many samples each split holds
print(f"Number of training samples: {len(datamodule.train_data)}")
print(f"Number of validation samples: {len(datamodule.val_data)}")
print(f"Number of test samples (public): {len(datamodule.test_public_data)}")
print(f"Number of test samples (private): {len(datamodule.test_private_data)}")
print(f"Number of test samples (private mixed): {len(datamodule.test_private_mixed_data)}")

# 2. Custom Transforms
print("\n2. Custom Transforms")
transform = Compose(
    [
        ToImage(),
        Resize((256, 256)),
        ToDtype(torch.float32, scale=True),
    ],
)

# Re-create the datamodule, applying the same transform to every stage
datamodule = MVTecAD2(
    root=DATASET_ROOT,
    category=CATEGORY,
    train_augmentations=transform,
    val_augmentations=transform,
    test_augmentations=transform,
)
datamodule.setup()

# 3. Different Test Sets
print("\n3. Accessing Different Test Sets")

# One loader per test set, then the first batch of each loader
test_loaders = {
    kind: datamodule.test_dataloader(test_type=kind)
    for kind in (TestType.PUBLIC, TestType.PRIVATE, TestType.PRIVATE_MIXED)
}
first_batches = {kind: next(iter(loader)) for kind, loader in test_loaders.items()}

print("Public test batch shape:", first_batches[TestType.PUBLIC].image.shape)
print("Private test batch shape:", first_batches[TestType.PRIVATE].image.shape)
print("Private mixed test batch shape:", first_batches[TestType.PRIVATE_MIXED].image.shape)

# 4. Advanced Usage - Direct Dataset Access
print("\n4. Advanced Usage")

# Arguments common to the directly constructed datasets
dataset_kwargs = {"root": DATASET_ROOT, "category": CATEGORY, "augmentations": transform}

train_dataset = MVTecAD2Dataset(split=Split.TRAIN, **dataset_kwargs)
test_dataset = MVTecAD2Dataset(
    split=Split.TEST,
    test_type=TestType.PUBLIC,  # Use public test set
    **dataset_kwargs,
)

# Wrap the datasets in plain PyTorch dataloaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=train_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=test_dataset.collate_fn)

# Pull one batch from each loader
train_samples = next(iter(train_loader))
test_samples = next(iter(test_loader))

print("Train Dataset:")
print(f"- Number of samples: {len(train_dataset)}")
print(f"- Image shape: {train_samples.image.shape}")
print(f"- Labels: {train_samples.gt_label}")

print("\nTest Dataset:")
print(f"- Number of samples: {len(test_dataset)}")
print(f"- Image shape: {test_samples.image.shape}")
print(f"- Labels: {test_samples.gt_label}")
if getattr(test_samples, "gt_mask", None) is not None:
    print(f"- Mask shape: {test_samples.gt_mask.shape}")
def show_samples(samples: ImageItem, title: str, max_samples: int = 4) -> None:
    """Plot up to ``max_samples`` images of a batch side by side.

    Each image is drawn in its own subplot. When ground-truth masks are
    present they are overlaid in red, and the subplot title shows whether the
    sample is labelled normal or anomalous.

    Args:
        samples: Batch whose ``image`` is indexed along its first dimension and
            whose per-sample ``gt_label`` is ``0`` for normal samples.
        title: Figure title.
        max_samples: Upper bound on the number of images shown. Defaults to ``4``.

    Raises:
        ValueError: If the batch has no image or label data, or is empty.
    """
    if samples.image is None or samples.gt_label is None:
        msg = "Samples must have image and label data"
        raise ValueError(msg)

    # Size the grid to the batch so a short (e.g. final) batch does not raise IndexError.
    num_samples = min(max_samples, len(samples.image))
    if num_samples < 1:
        msg = "Samples must contain at least one image"
        raise ValueError(msg)

    # squeeze=False keeps ``axes`` two-dimensional even when only one image is shown.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 4), squeeze=False)
    fig.suptitle(title)

    has_masks = hasattr(samples, "gt_mask") and samples.gt_mask is not None
    for i, axis in enumerate(axes[0]):
        # Move the channel dimension last, as expected by ``imshow``.
        img = samples.image[i].permute(1, 2, 0).numpy()
        axis.imshow(img)
        axis.axis("off")
        if has_masks:
            mask = samples.gt_mask[i].squeeze().numpy()
            axis.imshow(mask, alpha=0.3, cmap="Reds")
        label = "Normal" if samples.gt_label[i] == 0 else "Anomaly"
        axis.set_title(label)

    plt.tight_layout()
    plt.show()


# Show training samples (normal only)
show_samples(train_samples, "Training Samples (Normal)")

# Show test samples (mix of normal and anomalous)
show_samples(test_samples, "Test Samples (Normal + Anomalous)")

if __name__ == "__main__":
    print("\nMVTec AD 2 Dataset example completed successfully!")
Kolektor, + MVTec, + MVTecAD, + MVTecAD2, + MVTecLOCO, + Visa, +) from .datamodules.video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat # Datasets @@ -165,6 +177,7 @@ def get_datamodule(config: DictConfig | ListConfig | dict) -> AnomalibDataModule "MVTec", # Include MVTec for backward compatibility "MVTecAD", "MVTecADDataset", + "MVTecAD2", "MVTecLOCO", "MVTecLOCODataset", "VAD", diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index 887349b02c..81b6e2140e 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -317,6 +317,9 @@ def _create_val_split(self) -> None: This handles sampling from train/test sets and optionally creating synthetic anomalies. """ + if self.val_split_mode == ValSplitMode.FROM_DIR: + # If the validation split mode is FROM_DIR, we don't need to create a validation set + return if self.val_split_mode == ValSplitMode.FROM_TRAIN: # randomly sample from train set self.train_data, self.val_data = random_split( diff --git a/src/anomalib/data/datamodules/image/__init__.py b/src/anomalib/data/datamodules/image/__init__.py index fb0d6a07e1..f556c0bd9b 100644 --- a/src/anomalib/data/datamodules/image/__init__.py +++ b/src/anomalib/data/datamodules/image/__init__.py @@ -8,6 +8,7 @@ - ``Folder``: Custom folder structure with normal/abnormal images - ``Kolektor``: Kolektor Surface-Defect Dataset - ``MVTecAD``: MVTec Anomaly Detection Dataset +- ``MVTecAD2``: MVTec Anomaly Detection Dataset 2 - ``MVTecLOCO``: MVTec LOCO Dataset with logical and structural anomalies - ``VAD``: Valeo Anomaly Detection Dataset - ``Visa``: Visual Anomaly Dataset @@ -33,6 +34,7 @@ from .kolektor import Kolektor from .mvtec_loco import MVTecLOCO from .mvtecad import MVTec, MVTecAD +from .mvtecad2 import MVTecAD2 from .vad import VAD from .visa import Visa @@ -48,6 +50,7 @@ class ImageDataFormat(str, Enum): - ``FOLDER_3D``: Custom folder structure for 3D images - 
``KOLEKTOR``: Kolektor Surface-Defect Dataset - ``MVTEC_AD``: MVTec AD Dataset + - ``MVTEC_AD_2``: MVTec AD 2 Dataset - ``MVTEC_3D``: MVTec 3D AD Dataset - ``MVTEC_LOCO``: MVTec LOCO Dataset - ``VAD``: Valeo Anomaly Detection Dataset @@ -60,6 +63,7 @@ class ImageDataFormat(str, Enum): FOLDER_3D = "folder_3d" KOLEKTOR = "kolektor" MVTEC_AD = "mvtecad" + MVTEC_AD_2 = "mvtecad2" MVTEC_3D = "mvtec_3d" MVTEC_LOCO = "mvtec_loco" VAD = "vad" @@ -71,8 +75,9 @@ class ImageDataFormat(str, Enum): "Datumaro", "Folder", "Kolektor", + "MVTec", # Include MVTec for backward compatibility "MVTecAD", - "MVTec", # Include both for backward compatibility + "MVTecAD2", "MVTecLOCO", "VAD", "Visa", diff --git a/src/anomalib/data/datamodules/image/mvtecad2.py b/src/anomalib/data/datamodules/image/mvtecad2.py new file mode 100644 index 0000000000..ba7934207a --- /dev/null +++ b/src/anomalib/data/datamodules/image/mvtecad2.py @@ -0,0 +1,227 @@ +"""MVTec AD 2 Lightning Data Module. + +This module implements a PyTorch Lightning DataModule for the MVTec AD 2 dataset. +The module handles downloading, loading, and preprocessing of the dataset for +training and evaluation. + +The dataset provides three different test sets: + - Public test set (test_public/): Contains both normal and anomalous samples with ground truth masks + for facilitating local testing and initial performance estimation + - Private test set (test_private/): Official unseen test set without ground truth + for entering the leaderboard + - Private mixed test set (test_private_mixed/): Contains unseen test samples captured + under seen and unseen lighting conditions (mixed randomly) without ground truth + +The public test set is meant for local evaluation, while the private test sets +are the official test sets for entering the leaderboard on the evaluation server +(https://benchmark.mvtec.com/). 
class MVTecAD2(AnomalibDataModule):
    """Lightning data module for the MVTec AD 2 dataset.

    Args:
        root (str | Path): Path to the dataset root directory.
            Defaults to ``"./datasets/MVTec_AD_2"``.
        category (str): Name of the MVTec AD 2 category to load.
            Defaults to ``"sheet_metal"``.
        train_batch_size (int, optional): Training batch size.
            Defaults to ``32``.
        eval_batch_size (int, optional): Validation and test batch size.
            Defaults to ``32``.
        num_workers (int, optional): Number of workers for data loading.
            Defaults to ``8``.
        train_augmentations (Transform | None): Augmentations to apply to the training images.
            Defaults to ``None``.
        val_augmentations (Transform | None): Augmentations to apply to the validation images.
            Defaults to ``None``.
        test_augmentations (Transform | None): Augmentations to apply to the test images.
            Defaults to ``None``.
        augmentations (Transform | None): General augmentations to apply if stage-specific
            augmentations are not provided. Defaults to ``None``.
        test_type (str | TestType): Test set returned by :meth:`test_dataloader`
            when it is called without an argument:
            - ``"public"``: Test set with ground truth for local evaluation and initial
              performance estimation
            - ``"private"``: Official test set without ground truth for leaderboard submission
            - ``"private_mixed"``: Official test set with mixed lighting conditions (seen and
              unseen, randomly mixed) for leaderboard submission
            Defaults to ``TestType.PUBLIC``.
        seed (int | None, optional): Random seed for reproducibility.
            Defaults to ``None``.

    Example:
        >>> from anomalib.data import MVTecAD2
        >>> datamodule = MVTecAD2(
        ...     root="./datasets/MVTec_AD_2",
        ...     category="sheet_metal",
        ...     train_batch_size=32,
        ...     eval_batch_size=32,
        ...     num_workers=8,
        ... )

        To use private test set:
        >>> datamodule = MVTecAD2(
        ...     root="./datasets/MVTec_AD_2",
        ...     category="sheet_metal",
        ...     test_type="private",
        ... )

        Access different test sets:
        >>> datamodule.setup()
        >>> public_loader = datamodule.test_dataloader()  # returns loader based on test_type
        >>> private_loader = datamodule.test_dataloader(test_type="private")
        >>> mixed_loader = datamodule.test_dataloader(test_type="private_mixed")
    """

    def __init__(
        self,
        root: str | Path = "./datasets/MVTec_AD_2",
        category: str = "sheet_metal",
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        train_augmentations: Transform | None = None,
        val_augmentations: Transform | None = None,
        test_augmentations: Transform | None = None,
        augmentations: Transform | None = None,
        test_type: str | TestType = TestType.PUBLIC,
        seed: int | None = None,
    ) -> None:
        """Initialize MVTec AD 2 datamodule."""
        super().__init__(
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            num_workers=num_workers,
            train_augmentations=train_augmentations,
            val_augmentations=val_augmentations,
            test_augmentations=test_augmentations,
            augmentations=augmentations,
            seed=seed,
        )

        # Plain strings are normalised to the enum; enum members pass through.
        if isinstance(test_type, str):
            test_type = TestType(test_type)

        self.root = Path(root)
        self.category = category
        self.test_type = test_type

    def prepare_data(self) -> None:
        """Prepare the dataset.

        MVTec AD 2 dataset needs to be downloaded manually from MVTec AD 2 website.
        """
        # NOTE: For now, users need to manually download the dataset.

    def _setup(self, _stage: str | None = None) -> None:
        """Build the train, validation and all three test datasets.

        Args:
            _stage: str | None: Optional argument for compatibility with pytorch
                lightning. Defaults to None.
        """

        def build(split: Split, augmentations: Transform | None, **extra: TestType) -> MVTecAD2Dataset:
            # Every dataset shares root and category; only split, augmentations
            # and (for test sets) the test type differ.
            return MVTecAD2Dataset(
                root=self.root,
                category=self.category,
                split=split,
                augmentations=augmentations,
                **extra,
            )

        self.train_data = build(Split.TRAIN, self.train_augmentations)

        # MVTec AD 2 has a dedicated validation set
        self.val_data = build(Split.VAL, self.val_augmentations)

        # Create datasets for all test types
        self.test_public_data = build(Split.TEST, self.test_augmentations, test_type=TestType.PUBLIC)
        self.test_private_data = build(Split.TEST, self.test_augmentations, test_type=TestType.PRIVATE)
        self.test_private_mixed_data = build(
            Split.TEST,
            self.test_augmentations,
            test_type=TestType.PRIVATE_MIXED,
        )

        # Always set test_data to public test set for standard evaluation
        self.test_data = self.test_public_data

    def test_dataloader(self, test_type: str | TestType | None = None) -> EVAL_DATALOADERS:
        """Get test dataloader for the specified test type.

        Args:
            test_type (str | TestType | None, optional): Type of test set to use:
                - ``"public"``: Test set with ground truth for local evaluation
                - ``"private"``: Official test set without ground truth for leaderboard
                - ``"private_mixed"``: Official test set with mixed lighting conditions
                If None, uses the test_type specified in __init__.
                Defaults to None.

        Example:
            >>> datamodule.setup()
            >>> public_loader = datamodule.test_dataloader()  # returns loader based on test_type
            >>> private_loader = datamodule.test_dataloader(test_type="private")
            >>> mixed_loader = datamodule.test_dataloader(test_type="private_mixed")

        Returns:
            EVAL_DATALOADERS: Test dataloader for the specified test type.

        Raises:
            ValueError: If ``test_type`` does not name a known test set.
        """
        test_type = test_type or self.test_type
        if isinstance(test_type, str):
            test_type = TestType(test_type)

        # Attribute holding the dataset for each test type; resolved lazily so
        # only the requested dataset is looked up.
        dataset_attribute = {
            TestType.PUBLIC: "test_public_data",
            TestType.PRIVATE: "test_private_data",
            TestType.PRIVATE_MIXED: "test_private_mixed_data",
        }
        if test_type not in dataset_attribute:
            msg = f"Invalid test type: {test_type}. Must be one of {TestType.__members__.keys()}."
            raise ValueError(msg)
        dataset = getattr(self, dataset_attribute[test_type])

        return DataLoader(
            dataset=dataset,
            shuffle=False,
            batch_size=self.eval_batch_size,
            num_workers=self.num_workers,
            collate_fn=dataset.collate_fn,
        )
- Example: >>> dataset = AnomalibDataset() >>> item = dataset[0] @@ -267,29 +264,44 @@ def __getitem__(self, index: int) -> DatasetItem: mask_path = self.samples.iloc[index].mask_path label_index = self.samples.iloc[index].label_index + # Read the image image = read_image(image_path, as_tensor=True) - item = {"image_path": image_path, "gt_label": label_index} - - if self.task == TaskType.CLASSIFICATION: - item["image"] = self.augmentations(image) if self.augmentations else image - elif self.task == TaskType.SEGMENTATION: - # Only Anomalous (1) images have masks in anomaly datasets - # Therefore, create empty mask for Normal (0) images. - mask = ( - Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) - if label_index == LabelName.NORMAL - else read_mask(mask_path, as_tensor=True) - ) - item["image"], item["gt_mask"] = self.augmentations(image, mask) if self.augmentations else (image, mask) - - else: - msg = f"Unknown task type: {self.task}" - raise ValueError(msg) + # Initialize mask as None + gt_mask = None + + # Process based on task type + if self.task == TaskType.SEGMENTATION: + if label_index == LabelName.NORMAL: + # Create zero mask for normal samples + gt_mask = Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) + elif label_index == LabelName.ABNORMAL: + # Read mask for anomalous samples + gt_mask = read_mask(mask_path, as_tensor=True) + # For UNKNOWN, gt_mask remains None + + # Apply augmentations if available + if self.augmentations: + if self.task == TaskType.CLASSIFICATION: + image = self.augmentations(image) + elif self.task == TaskType.SEGMENTATION: + # For augmentations that require both image and mask: + # - Use a temporary zero mask for UNKNOWN samples + # - But preserve the final gt_mask as None for UNKNOWN + temp_mask = gt_mask if gt_mask is not None else Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) + image, augmented_mask = self.augmentations(image, temp_mask) + # Only update gt_mask if it wasn't None before augmentations + if gt_mask is 
not None: + gt_mask = augmented_mask + + # Create gt_label tensor (None for UNKNOWN) + gt_label = None if label_index == LabelName.UNKNOWN else torch.tensor(label_index) + + # Return the dataset item return ImageItem( - image=item["image"], - gt_mask=item.get("gt_mask"), - gt_label=int(label_index), + image=image, + gt_mask=gt_mask, + gt_label=gt_label, image_path=image_path, mask_path=mask_path, ) diff --git a/src/anomalib/data/datasets/image/__init__.py b/src/anomalib/data/datasets/image/__init__.py index 49d176d1b0..ff49768982 100644 --- a/src/anomalib/data/datasets/image/__init__.py +++ b/src/anomalib/data/datasets/image/__init__.py @@ -30,6 +30,7 @@ from .kolektor import KolektorDataset from .mvtec_loco import MVTecLOCODataset from .mvtecad import MVTecADDataset, MVTecDataset +from .mvtecad2 import MVTecAD2Dataset from .vad import VADDataset from .visa import VisaDataset @@ -38,9 +39,10 @@ "DatumaroDataset", "FolderDataset", "KolektorDataset", + "MVTecDataset", "MVTecADDataset", + "MVTecAD2Dataset", "MVTecLOCODataset", "VADDataset", "VisaDataset", - "MVTecDataset", ] diff --git a/src/anomalib/data/datasets/image/mvtecad.py b/src/anomalib/data/datasets/image/mvtecad.py index 192f8260e6..7f303716c9 100644 --- a/src/anomalib/data/datasets/image/mvtecad.py +++ b/src/anomalib/data/datasets/image/mvtecad.py @@ -190,7 +190,7 @@ def make_mvtec_ad_dataset( ) # assign mask paths to anomalous test images - samples["mask_path"] = "" + samples["mask_path"] = None samples.loc[ (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), "mask_path", diff --git a/src/anomalib/data/datasets/image/mvtecad2.py b/src/anomalib/data/datasets/image/mvtecad2.py new file mode 100644 index 0000000000..1070225703 --- /dev/null +++ b/src/anomalib/data/datasets/image/mvtecad2.py @@ -0,0 +1,293 @@ +"""MVTec AD 2 Dataset. + +This module provides PyTorch Dataset implementation for the MVTec AD 2 dataset. 
class TestType(str, Enum):
    """Type of test set to use.

    The MVTec AD 2 dataset provides three different test sets:
        - PUBLIC: Test set with ground truth masks for facilitating local testing and initial performance estimation
        - PRIVATE: Official unseen test set without ground truth for entering the leaderboard
        - PRIVATE_MIXED: Official unseen test set captured under seen and unseen lighting conditions (mixed randomly)

    Official evaluation server: https://benchmark.mvtec.com/
    """

    PUBLIC = "public"  # Test set with ground truth for local evaluation
    PRIVATE = "private"  # Official private test set without ground truth
    PRIVATE_MIXED = "private_mixed"  # Official private test set with mixed lighting conditions


# File suffixes (including the dot) that are treated as images.
IMG_EXTENSIONS = (".png", ".PNG")
# Category directory names of the dataset.
CATEGORIES = (
    "can",
    "fabric",
    "fruit_jelly",
    "rice",
    "sheet_metal",
    "vial",
    "wallplugs",
    "walnuts",
)


class MVTecAD2Dataset(AnomalibDataset):
    """MVTec AD 2 dataset class.

    Args:
        root (Path | str): Path to the root of the dataset.
            Defaults to ``"./datasets/MVTec_AD_2"``.
        category (str): Category name, e.g. ``"sheet_metal"``.
            Defaults to ``"sheet_metal"``.
        augmentations (Transform, optional): Augmentations that should be applied to the input images.
            Defaults to ``None``.
        split (str | Split | None): Dataset split - usually ``Split.TRAIN``, ``Split.VAL``,
            or ``Split.TEST``. Defaults to ``None``.
        test_type (str | TestType): Type of test set to use - only used when split is ``Split.TEST``:
            - ``"public"``: Test set with ground truth for local evaluation and initial performance estimation
            - ``"private"``: Official test set without ground truth for leaderboard submission
            - ``"private_mixed"``: Official test set with mixed lighting conditions (seen and unseen lighting)
            Defaults to ``TestType.PUBLIC``.

    Example:
        Create training dataset::

            >>> from pathlib import Path
            >>> dataset = MVTecAD2Dataset(
            ...     root=Path("./datasets/MVTec_AD_2"),
            ...     category="sheet_metal",
            ...     split="train"
            ... )

        Create validation dataset::

            >>> val_dataset = MVTecAD2Dataset(
            ...     root=Path("./datasets/MVTec_AD_2"),
            ...     category="sheet_metal",
            ...     split="val"
            ... )

        Create test datasets::

            >>> # Public test set (with ground truth)
            >>> test_dataset = MVTecAD2Dataset(
            ...     root=Path("./datasets/MVTec_AD_2"),
            ...     category="sheet_metal",
            ...     split="test",
            ...     test_type="public"
            ... )

            >>> # Private test set (without ground truth)
            >>> private_dataset = MVTecAD2Dataset(
            ...     root=Path("./datasets/MVTec_AD_2"),
            ...     category="sheet_metal",
            ...     split="test",
            ...     test_type="private"
            ... )

            >>> # Private mixed test set (without ground truth)
            >>> mixed_dataset = MVTecAD2Dataset(
            ...     root=Path("./datasets/MVTec_AD_2"),
            ...     category="sheet_metal",
            ...     split="test",
            ...     test_type="private_mixed"
            ... )

    Notes:
        - The public test set contains both normal and anomalous samples with ground truth masks
        - Private test sets (private and private_mixed) contain samples without ground truth
        - Private test samples are labeled as "unknown" with label_index=-1
    """

    def __init__(
        self,
        root: Path | str = "./datasets/MVTec_AD_2",
        category: str = "sheet_metal",
        augmentations: Transform | None = None,
        split: str | Split | None = None,
        test_type: TestType | str = TestType.PUBLIC,
    ) -> None:
        super().__init__(augmentations=augmentations)

        self.root_category = Path(root) / Path(category)
        self.split = split
        self.test_type = TestType(test_type) if isinstance(test_type, str) else test_type
        self.samples = make_mvtec2_dataset(
            self.root_category,
            split=self.split,
            test_type=self.test_type,
            extensions=IMG_EXTENSIONS,
        )


def _list_images(directory: Path, extensions: Sequence[str]) -> list[Path]:
    """Return the files directly inside ``directory`` whose suffix is in ``extensions``.

    A missing directory yields an empty list. Results are sorted so that the
    sample order does not depend on the file system's listing order.
    """
    if not directory.is_dir():
        return []
    return sorted(path for path in directory.iterdir() if path.is_file() and path.suffix in extensions)


def _unmasked_rows(
    root: Path,
    directory: Path,
    split_name: str,
    label: str,
    label_index: int,
    extensions: Sequence[str],
) -> list[tuple[str, str, str, str, str | None, int]]:
    """Build one sample row without a mask path for every image in ``directory``."""
    return [
        (str(root), split_name, label, str(image_path), None, label_index)
        for image_path in _list_images(directory, extensions)
    ]


def make_mvtec2_dataset(
    root: str | Path,
    split: str | Split | None = None,
    test_type: TestType = TestType.PUBLIC,
    extensions: Sequence[str] | None = None,
) -> DataFrame:
    """Create MVTec AD 2 samples by parsing the data directory structure.

    The files are expected to follow this structure::

        root/
        ├── test_private/
        ├── test_private_mixed/
        ├── test_public/
        │   ├── bad/
        │   ├── good/
        │   └── ground_truth/
        │       └── bad/
        ├── train/
        │   └── good/
        └── validation/
            └── good/

    Images are selected by file suffix. The mask of an anomalous public test
    image ``bad/<stem><suffix>`` is expected at
    ``ground_truth/bad/<stem>_mask<suffix>``.

    Args:
        root (str | Path): Path to the category directory of the dataset
        split (str | Split | None, optional): Dataset split (train, val, test). Defaults to None,
            which keeps the samples of all splits.
        test_type (TestType, optional): Type of test set to use for testing:
            - PUBLIC: Test set with ground truth (for local evaluation)
            - PRIVATE: Official test set without ground truth (for leaderboard)
            - PRIVATE_MIXED: Official test set with mixed lighting conditions (for leaderboard)
            Defaults to TestType.PUBLIC.
        extensions (Sequence[str] | None, optional): Image file suffixes to include, each with
            its leading dot. Defaults to None, which selects ``IMG_EXTENSIONS``.

    Returns:
        DataFrame: Dataset samples with columns:
            - path: Path to the category directory (``root``)
            - split: Dataset split (train/val/test)
            - label: Class label (good/bad/unknown)
            - image_path: Path to image file
            - mask_path: Path to mask file, or None when no mask is available
            - label_index: Numeric label (0=normal, 1=abnormal, -1=unknown)

    Example:
        >>> root = Path("./datasets/MVTec_AD_2/sheet_metal")
        >>> samples = make_mvtec2_dataset(root, split="train")
        >>> samples.head()
                                       path  split label          image_path mask_path  label_index
        0  datasets/MVTec_AD_2/sheet_metal  train  good  [...]/good/017.png      None            0
        1  datasets/MVTec_AD_2/sheet_metal  train  good  [...]/good/105.png      None            0

    Raises:
        RuntimeError: If no valid images are found
        MisMatchError: If anomalous images and masks don't match
    """
    if extensions is None:
        extensions = IMG_EXTENSIONS

    root = validate_path(root)

    # Fail fast when the category directory does not contain a single image.
    if not any(f.suffix in extensions for f in root.glob("**/*")):
        msg = f"Found 0 images in {root}"
        raise RuntimeError(msg)

    samples_list: list[tuple[str, str, str, str, str | None, int]] = []

    # Training and validation samples are normal only.
    # NOTE: files are matched by suffix. The former glob pattern
    # ``*[.png.PNG]`` was a character class and matched any file name ending
    # in one of those characters (for example ``.json`` or ``.jpg`` files).
    samples_list.extend(_unmasked_rows(root, root / "train" / "good", "train", "good", 0, extensions))
    samples_list.extend(_unmasked_rows(root, root / "validation" / "good", "val", "good", 0, extensions))

    # Test samples depend on the requested test set.
    if test_type == TestType.PUBLIC:
        test_path = root / "test_public"
        samples_list.extend(_unmasked_rows(root, test_path / "good", "test", "good", 0, extensions))

        # Every anomalous image must come with a ground-truth mask.
        mask_dir = test_path / "ground_truth" / "bad"
        for image_path in _list_images(test_path / "bad", extensions):
            mask_path = mask_dir / f"{image_path.stem}_mask{image_path.suffix}"
            if not mask_path.exists():
                msg = f"Missing mask for anomalous image: {image_path}"
                raise MisMatchError(msg)
            samples_list.append(
                (str(root), "test", "bad", str(image_path), str(mask_path), 1),
            )
    elif test_type == TestType.PRIVATE:
        # All samples in private test set are treated as unknown
        samples_list.extend(_unmasked_rows(root, root / "test_private", "test", "unknown", -1, extensions))
    elif test_type == TestType.PRIVATE_MIXED:
        # All samples in private mixed test set are treated as unknown
        samples_list.extend(_unmasked_rows(root, root / "test_private_mixed", "test", "unknown", -1, extensions))

    samples = DataFrame(
        samples_list,
        columns=["path", "split", "label", "image_path", "mask_path", "label_index"],
    )

    # Filter by split if specified
    if split:
        split = Split(split) if isinstance(split, str) else split
        samples = samples[samples.split == split.value]

    samples.attrs["task"] = "segmentation"
    return samples
def _generate_dummy_mvtecad2_dataset(
    self,
    normal_dir: str = "good",
    abnormal_dir: str = "bad",
    image_extension: str = ".png",
    mask_suffix: str = "_mask",
    mask_extension: str = ".png",
) -> None:
    """Write a dummy dataset laid out like MVTec AD 2 under ``self.dataset_root / "dummy"``.

    Every file is produced by ``self.image_generator.generate_image``. Image
    names follow ``{index:03d}_regular{image_extension}`` in every folder,
    including the abnormal one. The generated tree is::

        dummy/
            train/<normal_dir>/                        self.num_train normal images
            validation/<normal_dir>/                   self.num_test normal images
            test_public/<normal_dir>/                  self.num_test normal images
            test_public/<abnormal_dir>/                self.num_test abnormal images
            test_public/ground_truth/<abnormal_dir>/   one mask per abnormal image
            test_private/                              self.num_test normal images
            test_private_mixed/                        self.num_test normal images

    Args:
        normal_dir (str, optional): Folder name used for normal images. Defaults to "good".
        abnormal_dir (str, optional): Folder name used for abnormal images and their masks.
            Defaults to "bad".
        image_extension (str, optional): File extension of the generated images. Defaults to ".png".
        mask_suffix (str, optional): Text inserted between the image stem and the mask extension.
            Defaults to "_mask".
        mask_extension (str, optional): File extension of the generated masks. Defaults to ".png".
    """
    # The dummy dataset consists of a single category folder.
    category_root = self.dataset_root / "dummy"
    public_root = category_root / "test_public"

    # Normal-only folders generated before the abnormal samples, in generation order.
    leading_normal_dirs = (
        (category_root / "train" / normal_dir, self.num_train),
        (category_root / "validation" / normal_dir, self.num_test),
        (public_root / normal_dir, self.num_test),
    )
    for directory, count in leading_normal_dirs:
        for index in range(count):
            self.image_generator.generate_image(
                label=LabelName.NORMAL,
                image_filename=directory / f"{index:03d}_regular{image_extension}",
            )

    # Public abnormal samples: each image is generated together with its mask.
    abnormal_root = public_root / abnormal_dir
    mask_root = public_root / "ground_truth" / abnormal_dir
    for index in range(self.num_test):
        stem = f"{index:03d}_regular"
        self.image_generator.generate_image(
            label=LabelName.ABNORMAL,
            image_filename=abnormal_root / f"{stem}{image_extension}",
            mask_filename=mask_root / f"{stem}{mask_suffix}{mask_extension}",
        )

    # Private splits are flat folders; their images are generated with the NORMAL label.
    for split_name in ("test_private", "test_private_mixed"):
        directory = category_root / split_name
        for index in range(self.num_test):
            self.image_generator.generate_image(
                label=LabelName.NORMAL,
                image_filename=directory / f"{index:03d}_regular{image_extension}",
            )
"""Unit tests - MVTec AD 2 Datamodule."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path

import pytest
from torchvision.transforms.v2 import Resize

from anomalib.data import MVTecAD2
from anomalib.data.datasets.image.mvtecad2 import TestType
from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule


class TestMVTecAD2(_TestAnomalibImageDatamodule):
    """Unit tests for the MVTec AD 2 datamodule, built on the shared image-datamodule test base."""

    @pytest.fixture()
    @staticmethod
    def datamodule(dataset_path: Path) -> MVTecAD2:
        """Build the datamodule for the "dummy" category and run ``setup()`` on it.

        Batches hold 4 samples and every image is resized to 256x256.
        """
        module = MVTecAD2(
            root=dataset_path / "mvtecad2",
            category="dummy",
            train_batch_size=4,
            eval_batch_size=4,
            augmentations=Resize((256, 256)),
        )
        module.setup()
        return module

    @pytest.fixture()
    @staticmethod
    def fxt_data_config_path() -> str:
        """Return the location of the MVTec AD 2 data config file."""
        return "examples/configs/data/mvtecad2.yaml"

    @staticmethod
    def test_test_types(datamodule: MVTecAD2) -> None:
        """Check that every test type yields a loader and that an unknown type is rejected."""
        expected_shape = (4, 3, 256, 256)

        # Public, private and private-mixed sets are checked in that order.
        for test_type in (TestType.PUBLIC, TestType.PRIVATE, TestType.PRIVATE_MIXED):
            loader = datamodule.test_dataloader(test_type=test_type)
            assert loader is not None
            first_batch = next(iter(loader))
            assert first_batch.image.shape == expected_shape

        # A string that is not a TestType value must raise.
        with pytest.raises(ValueError, match="'invalid' is not a valid TestType"):
            datamodule.test_dataloader(test_type="invalid")