Skip to content

Commit a7b104b

Browse files
committed
Add: posibility that if is classification to not add the mask_dir
Signed-off-by: Bepitic <[email protected]>
1 parent 865935f commit a7b104b

File tree

1 file changed

+29
-28
lines changed

1 file changed

+29
-28
lines changed

src/anomalib/data/video/folder_video.py

+29-28
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
def make_folder_video_dataset(
3636
root: str | Path,
3737
normal_dir: str | Path = "",
38-
mask_dir: str | Path = "",
38+
mask_dir: str | Path | None = "",
3939
test_dir: str | Path = "",
4040
split: str | Split | None = None,
4141
) -> DataFrame:
@@ -122,24 +122,27 @@ def _extract_samples(root: Path, path: Path) -> list:
122122
root = validate_path(root)
123123
normal_dir = validate_path(root / normal_dir)
124124
test_dir = validate_path(root / test_dir)
125-
mask_dir = validate_path(root / mask_dir)
125+
126+
samples_list_labels = []
127+
if mask_dir is not None:
128+
mask_dir = validate_path(root / mask_dir)
129+
samples_list_labels.extend(
130+
[
131+
filename.parts[-1]
132+
for filename in sorted(mask_dir.glob("./*"))
133+
if (
134+
_contains_files(path=filename, extensions=FOLDER_IMAGE_EXTENSIONS)
135+
and not filename.name.startswith(".")
136+
)
137+
or filename.suffix in [".npy", ".pt"]
138+
],
139+
)
140+
126141
samples_list = []
127142
samples_list.extend(_extract_samples(root, normal_dir))
128143

129144
samples_list.extend(_extract_samples(root, test_dir))
130145

131-
samples_list_labels = []
132-
samples_list_labels.extend(
133-
[
134-
filename.parts[-1]
135-
for filename in sorted(mask_dir.glob("./*"))
136-
if (
137-
_contains_files(path=filename, extensions=FOLDER_IMAGE_EXTENSIONS) and not filename.name.startswith(".")
138-
)
139-
or filename.suffix in [".npy", ".pt"]
140-
],
141-
)
142-
143146
samples = DataFrame(samples_list, columns=["root", "folder", "image_path"])
144147

145148
# Remove DS_Store
@@ -150,10 +153,12 @@ def _extract_samples(root: Path, path: Path) -> list:
150153
samples.loc[samples.folder == normal_dir.parts[-1], "split"] = "train"
151154
samples.loc[samples.folder == test_dir.parts[-1], "split"] = "test"
152155
samples_list_labels = [str(item) for item in samples_list_labels if ".DS_Store" not in str(item)]
153-
samples.loc[samples.folder == test_dir.parts[-1], "mask_path"] = samples_list_labels
154-
if samples_list_labels == []:
156+
157+
if mask_dir is not None:
158+
samples.loc[samples.folder == test_dir.parts[-1], "mask_path"] = samples_list_labels
159+
samples["mask_path"] = str(mask_dir) + "/" + samples.mask_path.astype(str)
160+
else:
155161
samples.loc[samples.folder == test_dir.parts[-1], "mask_path"] = ""
156-
samples["mask_path"] = str(mask_dir) + "/" + samples.mask_path.astype(str)
157162
samples.loc[samples.folder == normal_dir.parts[-1], "mask_path"] = ""
158163

159164
if split:
@@ -239,7 +244,7 @@ class FolderDataset(AnomalibVideoDataset):
239244
root (Path | str): Path to the dataset.
240245
normal_dir (Path | str): Path to the training videos of the dataset (.avi / .mp4 / imgages as frames).
241246
test_dir (Path | str): Path to the testing videos of the dataset.
242-
mask_dir (Path | str): Path to the masks for the training videos of the dataset (.npy/.pt/ images)
247+
mask_dir (Path | str | None): Path to the masks for the training videos of the dataset (.npy/.pt/ images)
243248
clip_length_in_frames (int, optional): Number of video frames in each clip.
244249
frames_between_clips (int, optional): Number of frames between each consecutive video clip.
245250
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
@@ -252,7 +257,7 @@ def __init__(
252257
task: TaskType,
253258
split: Split,
254259
root: Path | str,
255-
mask_dir: Path | str = "test_labels",
260+
mask_dir: Path | str | None = None,
256261
normal_dir: Path | str = "train_dir",
257262
test_dir: Path | str = "test_dir",
258263
clip_length_in_frames: int = 2,
@@ -307,11 +312,11 @@ class FolderVideo(AnomalibVideoDataModule):
307312
root (Path | str): Path to the root of the dataset
308313
normal_dir (Path | str): Path to the training videos of the dataset (.avi / .mp4 / imgages as frames).
309314
test_dir (Path | str): Path to the testing videos of the dataset.
310-
mask_dir (Path | str): Path to the masks for the training videos of the dataset (.npy/.pt/ images)
315+
mask_dir (Path | str | None): Path to the masks for the training videos of the dataset (.npy/.pt/ images)
311316
clip_length_in_frames (int, optional): Number of video frames in each clip.
312317
frames_between_clips (int, optional): Number of frames between each consecutive video clip.
313318
target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval
314-
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
319+
task (TaskType | str): Task type, 'classification', 'detection' or 'segmentation'
315320
image_size (tuple[int, int], optional): Size to which input images should be resized.
316321
Defaults to ``None``.
317322
transform (Transform, optional): Transforms that should be applied to the input images.
@@ -331,13 +336,13 @@ class FolderVideo(AnomalibVideoDataModule):
331336
def __init__(
332337
self,
333338
root: Path | str,
334-
mask_dir: Path | str = "test_labels",
339+
mask_dir: Path | str | None = None,
335340
normal_dir: Path | str = "train_vid",
336341
test_dir: Path | str = "test_vid",
337342
clip_length_in_frames: int = 2,
338343
frames_between_clips: int = 1,
339344
target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
340-
task: TaskType = TaskType.SEGMENTATION,
345+
task: TaskType | str = TaskType.SEGMENTATION,
341346
image_size: tuple[int, int] | None = None,
342347
transform: Transform | None = None,
343348
train_transform: Transform | None = None,
@@ -366,11 +371,7 @@ def __init__(
366371
self.root = Path(root)
367372
self.normal_dir = Path(normal_dir)
368373
self.test_dir = Path(test_dir)
369-
if mask_dir is not None:
370-
self.mask_dir = Path(mask_dir)
371-
else:
372-
self.mask_dir = Path()
373-
374+
self.mask_dir = mask_dir if mask_dir is None else Path(mask_dir)
374375
self.clip_length_in_frames = clip_length_in_frames
375376
self.frames_between_clips = frames_between_clips
376377
self.target_frame = VideoTargetFrame(target_frame)

0 commit comments

Comments
 (0)