diff --git a/ArtExtract_Mingchun/retrival/searching_tool.py b/ArtExtract_Mingchun/retrival/searching_tool.py index f8ac3676..0cab9b8f 100644 --- a/ArtExtract_Mingchun/retrival/searching_tool.py +++ b/ArtExtract_Mingchun/retrival/searching_tool.py @@ -38,6 +38,44 @@ def load_index(path: str, metric: Metric, kind: str, dim: int, meta: Optional[di return IndexBundle(index=index, metric=metric, kind=kind, dim=dim, meta=meta or {}) +# ============ Change 1: normalize() + unnormalized warning ============ +def normalize(X: np.ndarray) -> np.ndarray: + """ + L2-normalize row vectors to unit length. + + Required before adding vectors to an ip (inner-product) index if you want + cosine similarity semantics. Without this, metric='ip' computes raw dot + product — NOT cosine — and silently returns wrong similarity rankings. + + Args: + X: shape (N, D) or (D,) — promoted to 2D internally. + + Returns: + L2-normalized float32 array, always 2-D: (N, D), or (1, D) for a 1-D input. + + Example: + X = normalize(X) + bundle = build_index_flat(X, metric="ip") # now correctly cosine + """ + X = np.atleast_2d(X).astype("float32") + norms = np.linalg.norm(X, axis=1, keepdims=True) + return X / (norms + 1e-8) + + +def _warn_if_unnormalized(X: np.ndarray, metric: Metric, context: str) -> None: + """Warn if metric='ip' but vectors are not unit-length. Checks at most the first 64 rows.""" + if metric != "ip": + return + sample = X[:min(64, len(X))] + norms = np.linalg.norm(sample, axis=1) + if not np.allclose(norms, 1.0, atol=1e-3): + print( + f"[searching_tool WARNING] {context}: metric='ip' but vectors are not " + f"L2-normalized (mean norm={norms.mean():.4f}). " + f"Call normalize(X) first for correct cosine similarity." + ) + + # ============ Building Index ============ def build_index_flat(X: np.ndarray, metric: Metric = "ip") -> IndexBundle: """ @@ -45,6 +83,7 @@ def build_index_flat(X: np.ndarray, metric: Metric = "ip") -> IndexBundle: IndexFlatL2: Euclidean distance. 
IndexFlatIP: Inner Product (for cosine similarity, vectors should be normalized first). """ + _warn_if_unnormalized(X, metric, "build_index_flat") N, D = X.shape if metric == "ip": index = faiss.IndexFlatIP(D) @@ -54,13 +93,15 @@ def build_index_flat(X: np.ndarray, metric: Metric = "ip") -> IndexBundle: return IndexBundle(index=index, metric=metric, kind="flat", dim=D, meta={"ntotal": index.ntotal}) -def build_index_ivfpq(X: np.ndarray, metric: Metric = "ip", nlist: Optional[int] = None,m: int = 16, nbits: int = 8, +def build_index_ivfpq(X: np.ndarray, metric: Metric = "ip", nlist: Optional[int] = None, m: int = 16, nbits: int = 8, train_samples: int = 200_000, nprobe: int = 16) -> IndexBundle: """ - Moderate scale (100K~10M):IVFPQ + Moderate scale (100K~10M): IVFPQ nlist: number of Voronoi cells (clusters); default: 4*sqrt(N), at least 64 m: number of sub-vectors (must divide D) + nprobe: cells visited at search time — can be overridden per-query in search() """ + _warn_if_unnormalized(X, metric, "build_index_ivfpq") N, D = X.shape if nlist is None: nlist = max(64, int(4 * math.sqrt(N))) @@ -86,8 +127,9 @@ def build_index_hnsw(X: np.ndarray, metric: Metric = "ip", M: int = 32, efC: int Large scale (>10M): HNSW M: number of neighbors per node (higher=M denser graph=better accuracy/slower) efC: construction parameter (higher=better accuracy/slower indexing) - efS: search parameter (higher=better accuracy/slower searching) + efS: search parameter (higher=better accuracy/slower searching) — can be overridden per-query in search() """ + _warn_if_unnormalized(X, metric, "build_index_hnsw") N, D = X.shape faiss_metric = faiss.METRIC_INNER_PRODUCT if metric == "ip" else faiss.METRIC_L2 index = faiss.IndexHNSWFlat(D, M, faiss_metric) @@ -105,14 +147,58 @@ def search( topk: int = 5, exclude_self: bool = True, exclude_indices: Optional[np.ndarray] = None, + nprobe: Optional[int] = None, + efsearch: Optional[int] = None, ) -> Tuple[np.ndarray, np.ndarray]: - """ Search the 
index with query vectors Q and return topk results """ + """ + Search the index with query vectors Q and return topk results. + + Args: + bundle: IndexBundle to search against. + Q: Query vectors, shape (D,) or (n_q, D). + 1D vectors are safely promoted to (1, D) internally. + topk: Number of nearest neighbors to return per query. + exclude_self: Exclude the query's own row index from results. + exclude_indices: Per-query indices to exclude, shape (n_q,) or scalar. + nprobe: IVFPQ only — overrides nprobe for this call only. + Higher = better recall, slower. Restored after search. + efsearch: HNSW only — overrides efSearch for this call only. + Higher = better recall, slower. Restored after search. + """ + # Change 4: 1D shape guard — promote to (1, D) so all downstream logic is safe + was_1d = Q.ndim == 1 + if was_1d: + Q = Q[None, :] if Q.dtype != np.float32: Q = Q.astype("float32") - n_q = Q.shape[0] - need = topk + 1 if (exclude_self or exclude_indices is not None) else topk - D_raw, I_raw = bundle.index.search(Q, need) + + # Changes 2 & 3: runtime nprobe / efSearch overrides + # Save originals and restore after search so the bundle config is unchanged. 
+ _orig_nprobe = None + _orig_efsearch = None + + if nprobe is not None: + if bundle.kind != "ivfpq": + raise ValueError(f"nprobe override is only valid for ivfpq bundles, got '{bundle.kind}'") + _orig_nprobe = bundle.index.nprobe + bundle.index.nprobe = nprobe + + if efsearch is not None: + if bundle.kind != "hnsw": + raise ValueError(f"efsearch override is only valid for hnsw bundles, got '{bundle.kind}'") + _orig_efsearch = bundle.index.hnsw.efSearch + bundle.index.hnsw.efSearch = efsearch + + try: + need = topk + 1 if (exclude_self or exclude_indices is not None) else topk + D_raw, I_raw = bundle.index.search(Q, need) + finally: + # always restore — even if search() raises + if _orig_nprobe is not None: + bundle.index.nprobe = _orig_nprobe + if _orig_efsearch is not None: + bundle.index.hnsw.efSearch = _orig_efsearch if exclude_indices is not None: exclude_indices = np.atleast_1d(exclude_indices) @@ -126,7 +212,7 @@ def search( mask = I_raw[r] != ex D_out[r] = D_raw[r][mask][:topk] I_out[r] = I_raw[r][mask][:topk] - return D_out, I_out + return (D_out[0], I_out[0]) if was_1d else (D_out, I_out) if exclude_self: D_out = np.empty((n_q, topk), dtype=D_raw.dtype) @@ -135,9 +221,11 @@ def search( mask = I_raw[r] != r D_out[r] = D_raw[r][mask][:topk] I_out[r] = I_raw[r][mask][:topk] - return D_out, I_out + return (D_out[0], I_out[0]) if was_1d else (D_out, I_out) + + D_out, I_out = D_raw[:, :topk], I_raw[:, :topk] + return (D_out[0], I_out[0]) if was_1d else (D_out, I_out) - return D_raw[:, :topk], I_raw[:, :topk] # ============ Re-ranking ============ def rerank(Q: np.ndarray, X_cands: np.ndarray) -> np.ndarray: @@ -157,7 +245,7 @@ def rerank(Q: np.ndarray, X_cands: np.ndarray) -> np.ndarray: # ============ Evaluating ============ -def recall_k(flat_bundle: IndexBundle,ann_bundle: IndexBundle, X: np.ndarray, k: int = 10, +def recall_k(flat_bundle: IndexBundle, ann_bundle: IndexBundle, X: np.ndarray, k: int = 10, nsamp: int = 200, seed: int = 42) -> float: """Compute 
recall@k between a flat index and an ANN index""" rng = np.random.default_rng(seed) @@ -176,4 +264,4 @@ def sweep_nprobe(ivfpq_bundle: IndexBundle, flat_bundle: IndexBundle, X: np.ndar ivfpq_bundle.index.nprobe = nprobe rec = recall_k(flat_bundle, ivfpq_bundle, X, k=10, nsamp=min(200, len(X))) results.append((nprobe, rec)) - return results \ No newline at end of file + return results diff --git a/ArtExtract_Soyoung/utils/data.py b/ArtExtract_Soyoung/utils/data.py index 34765154..45cb41db 100644 --- a/ArtExtract_Soyoung/utils/data.py +++ b/ArtExtract_Soyoung/utils/data.py @@ -1,76 +1,195 @@ from torch.utils.data import Dataset, DataLoader from torchvision import transforms -import torchvision.datasets as datasets from PIL import Image -import numpy as np import torch import os +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _validate_directory(path: str, label: str) -> None: + """Raise a clear error when a required directory is missing or not a dir.""" + if not os.path.exists(path): + raise FileNotFoundError(f"{label} directory does not exist: '{path}'") + if not os.path.isdir(path): + raise NotADirectoryError(f"{label} path is not a directory: '{path}'") + + +def _validate_image_mask_pairing(images: list, masks: dict, images_dir: str) -> None: + """ + Raise ValueError for any image that has no associated masks. + An unpaired image would otherwise only fail later, when __getitem__ calls torch.stack() on an empty list, + which is much harder to debug than a clear startup error. 
+ """ + unpaired = [img for img, mask_list in masks.items() if len(mask_list) == 0] + if unpaired: + raise ValueError( + f"The following images in '{images_dir}' have no corresponding masks:\n" + + "\n".join(f" {name}" for name in unpaired) + ) + + +# --------------------------------------------------------------------------- +# Dataset +# --------------------------------------------------------------------------- + class UNetDataset(Dataset): - def __init__(self, images_dir, masks_dir, transform=None): + def __init__(self, images_dir: str, masks_dir: str, transform=None): + # ------------------------------------------------------------------ + # 1. FILE PATH VALIDATION — fail fast with actionable messages + # ------------------------------------------------------------------ + _validate_directory(images_dir, "Images") + _validate_directory(masks_dir, "Masks") + self.images_dir = images_dir self.masks_dir = masks_dir self.transform = transform - self.images = [f for f in sorted(os.listdir(images_dir)) if f.endswith('RGB.bmp') or f.endswith('.png') or f.endswith('.jpg') or f.endswith('.JPG')] - - # Ensure each image has corresponding 8 masks - self.masks = {img_name: sorted([f for f in os.listdir(masks_dir) if f.startswith(img_name.split('_RGB')[0])]) for img_name in self.images} - + + # Collect valid image files + valid_exts = ('RGB.bmp', '.png', '.jpg', '.JPG') + self.images = sorted( + f for f in os.listdir(images_dir) + if any(f.endswith(ext) for ext in valid_exts) + ) + + if len(self.images) == 0: + raise FileNotFoundError( + f"No valid image files found in images directory: '{images_dir}'. " + "Expected files ending with 'RGB.bmp', '.png', '.jpg', or '.JPG'." 
+ ) + + # Map each image → its sorted list of mask files + self.masks = { + img_name: sorted( + f for f in os.listdir(masks_dir) + if f.startswith(img_name.split('_RGB')[0]) + ) + for img_name in self.images + } + + # Ensure every image actually has at least one mask + _validate_image_mask_pairing(self.images, self.masks, images_dir) + + # ------------------------------------------------------------------ + def __len__(self): return len(self.images) - + def __getitem__(self, idx): img_name = self.images[idx] img_path = os.path.join(self.images_dir, img_name) + + # Validate individual file existence at read time (handles deletions + # that occur after __init__ or symlinks that point nowhere) + if not os.path.isfile(img_path): + raise FileNotFoundError(f"Image file missing at runtime: '{img_path}'") + image = Image.open(img_path).convert('RGB') - - mask_names = self.masks[img_name] + masks = [] - for mask_name in mask_names: + for mask_name in self.masks[img_name]: mask_path = os.path.join(self.masks_dir, mask_name) + + if not os.path.isfile(mask_path): + raise FileNotFoundError(f"Mask file missing at runtime: '{mask_path}'") + mask = Image.open(mask_path) mode = mask.mode - - # If the mask image has more than one channel, convert it to grayscale - if mode == 'I;16': # Handle 16-bit images - mask = mask.point(lambda i: i * (1 / 255)).convert("L") - elif mode not in ['L', 'I']: # Convert non-grayscale masks to grayscale + + if mode == 'I;16': # 16-bit grayscale + mask = mask.point(lambda i: i * (1 / 255)).convert('L') + elif mode not in ('L', 'I'): mask = mask.convert('L') + masks.append(mask) if self.transform: image = self.transform(image) + # ------------------------------------------------------------------ + # 2. MASK NORMALIZATION FIX + # ToTensor() already scales uint8 PIL images from [0, 255] → [0, 1]. + # Applying an additional / 255.0 after the transform would push all + # values into [0, ~0.004], effectively zeroing out the masks. 
+ # Each mask is passed through the same self.transform as the image, + # so the ToTensor() inside that transform is the single normalisation pass. + # NOTE(review): random flips are re-drawn on every call, so the image and its masks may be flipped differently — confirm. + # ------------------------------------------------------------------ masks = [self.transform(mask) for mask in masks] + # At this point each mask is already a float tensor in [0, 1] + # courtesy of ToTensor() inside self.transform — no further + # division needed. + else: + # No transform supplied: convert to tensor and normalise once. + to_tensor = transforms.ToTensor() + masks = [to_tensor(mask) for mask in masks] + + masks = torch.stack(masks) # shape: (N, 1, H, W) + return image.float(), masks.float() + + +# --------------------------------------------------------------------------- +# DataLoader factory +# --------------------------------------------------------------------------- - # Convert masks to tensors and normalize pixel values - masks = torch.stack([mask.float() / 255.0 for mask in masks]) - return image.float(), masks +def load_datasets(train_path: str, val_path: str, seed: int = 42): + # ------------------------------------------------------------------ + # 1. 
FILE PATH VALIDATION for top-level dataset roots + # ------------------------------------------------------------------ + _validate_directory(train_path, "Training root") + _validate_directory(val_path, "Validation root") + train_images_dir = os.path.join(train_path, 'rgb_images') + train_masks_dir = os.path.join(train_path, 'ms_masks') + val_images_dir = os.path.join(val_path, 'rgb_images') + val_masks_dir = os.path.join(val_path, 'ms_masks') + + # Validate sub-directories before constructing datasets + _validate_directory(train_images_dir, "Train images") + _validate_directory(train_masks_dir, "Train masks") + _validate_directory(val_images_dir, "Validation images") + _validate_directory(val_masks_dir, "Validation masks") -def load_datasets(train_path, val_path): train_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), - transforms.ToTensor() + transforms.ToTensor(), # scales to [0, 1] ]) - - transform = transforms.Compose([ - transforms.Resize((256, 256)), - transforms.ToTensor(), + + val_transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), ]) - - train_images_dir = train_path + 'rgb_images/' - train_masks_dir = train_path + 'ms_masks/' - val_images_dir = val_path + 'rgb_images/' - val_masks_dir = val_path + 'ms_masks/' - - # Create custom datasets - train_dataset = UNetDataset(images_dir=train_images_dir, masks_dir=train_masks_dir, transform=train_transform) - val_dataset = UNetDataset(images_dir=val_images_dir, masks_dir=val_masks_dir, transform=transform) - - # Create data loaders - # Use larger batch size if the dataset is bigger - train_loader = DataLoader(train_dataset, batch_size=8, pin_memory=True,shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=8, pin_memory=True,shuffle=False) - return train_loader, val_loader \ No newline at end of file + + train_dataset = UNetDataset(train_images_dir, train_masks_dir, 
transform=train_transform) + val_dataset = UNetDataset(val_images_dir, val_masks_dir, transform=val_transform) + + # ------------------------------------------------------------------ + # 3. REPRODUCIBLE DATALOADER WITH num_workers + # A fixed Generator ensures shuffle order is deterministic across runs + # when the same seed is used, making experiments reproducible. + # num_workers overlaps disk I/O with GPU compute, reducing idle time. + # ------------------------------------------------------------------ + num_workers = min(4, os.cpu_count() or 1) + generator = torch.Generator().manual_seed(seed) + + train_loader = DataLoader( + train_dataset, + batch_size=8, + shuffle=True, + pin_memory=True, + num_workers=num_workers, + generator=generator, + persistent_workers=num_workers > 0, + ) + val_loader = DataLoader( + val_dataset, + batch_size=8, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + persistent_workers=num_workers > 0, + ) + + return train_loader, val_loader diff --git a/ArtExtract_Soyoung/utils/metrics.py b/ArtExtract_Soyoung/utils/metrics.py index 17b91515..e40961c4 100644 --- a/ArtExtract_Soyoung/utils/metrics.py +++ b/ArtExtract_Soyoung/utils/metrics.py @@ -24,6 +24,10 @@ def __init__(self, feature_extractor, size_average=True): self.ssim_metric = ssim(data_range=1.0) def psnr(self, output, target): + if output.shape != target.shape: + raise ValueError( + f"Shape mismatch: output has shape {output.shape}, target has shape {target.shape}" + ) # Compute MSE per channel mse_per_channel = torch.mean((target - output) ** 2, dim=[0, 2, 3]) # MSE per channel @@ -72,4 +76,4 @@ def forward(self, output, target): psnr_value = self.psnr(output, target) lpips_value = self.lpips(output, target) ssim_value = self.ssim(output, target) - return psnr_value, lpips_value, ssim_value \ No newline at end of file + return psnr_value, lpips_value, ssim_value