diff --git a/ArtExtract_Mingchun/embedding.py b/ArtExtract_Mingchun/embedding.py index eccb6884..ef63218f 100644 --- a/ArtExtract_Mingchun/embedding.py +++ b/ArtExtract_Mingchun/embedding.py @@ -1,6 +1,8 @@ import os, glob import numpy as np import torch +import json +from sklearn.metrics.pairwise import cosine_similarity import warnings warnings.filterwarnings('ignore', category=RuntimeWarning) @@ -26,9 +28,21 @@ def __init__(self, images_dir, transform_img=None, files.extend(glob.glob(os.path.join(images_dir, p), recursive=True)) self.images = sorted(files) if len(self.images) == 0: raise ValueError(f"No images found under {images_dir} (extensions={ALLOWED_EXTS})") + # --- Dataset integrity check --- + valid_images = [] + for path in self.images: + try: + Image.open(path).verify() + valid_images.append(path) + except Exception: + print(f"Skipping corrupted image: {path}") + + self.images = valid_images + if len(self.images) == 0: + raise ValueError(f"All images under {images_dir} are corrupted") sample = Image.open(self.images[0]).convert('RGB') if self.transform_img: sample = self.transform_img(sample) @@ -103,7 +115,7 @@ def build_dataset(images_dir, transform_img=None, return ds, in_channels def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'): - loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=min(4, os.cpu_count() or 1), pin_memory=True) encoder.eval().to(device) all_embs = [] all_filenames = [] @@ -123,8 +135,56 @@ def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'): ids = np.array(all_filenames) np.save('./embedding/embeddings.npy', X) np.savetxt('./embedding/ids.csv', ids, fmt='%s', delimiter=',') - return X, ids + + metadata = { + "num_embeddings": X.shape[0], + "embedding_dim": X.shape[1] + } + + with open("./embedding/meta.json", "w") as f: + json.dump(metadata, f) + + return X, ids + +def find_similar_embeddings(query_embedding, embeddings, top_k=5): + """ + Find the 
most similar embeddings using cosine similarity. + + Args: + query_embedding (np.ndarray): Embedding vector for the query. + embeddings (np.ndarray): Matrix of stored embeddings. + top_k (int): Number of similar embeddings to retrieve. + + Returns: + indices (np.ndarray): Indices of the most similar embeddings. + scores (np.ndarray): Similarity scores. + """ + + sims = cosine_similarity(query_embedding.reshape(1, -1), embeddings)[0] + indices = sims.argsort()[::-1][:top_k] + + return indices, sims[indices] + +def search_similar_images(query_image_path, embeddings, ids, encoder, device="cuda", top_k=5): + """ + Compute embedding for a query image and retrieve similar images. + """ + + img = Image.open(query_image_path).convert("RGB") + img_np = np.array(img) + + graph_data, segments = image_to_graph_rgb(img_np) + graph_data = graph_data.to(device) + + encoder.eval().to(device) + + with torch.no_grad(): + query_emb = encoder(graph_data).cpu().numpy() + + indices, scores = find_similar_embeddings(query_emb, embeddings, top_k) + + return ids[indices], scores def main(): diff --git a/ArtExtract_Mingchun/inference.py b/ArtExtract_Mingchun/inference.py index 8f4243e9..fd4537d4 100644 --- a/ArtExtract_Mingchun/inference.py +++ b/ArtExtract_Mingchun/inference.py @@ -1,8 +1,10 @@ +import os import torch import warnings warnings.filterwarnings("ignore") +import argparse -from utils.visulization import extract_hidden_art +from utils.visualization import extract_hidden_art from utils.data_graph import load_inference_datasets from model.extract_model import GATSiameseNetwork @@ -11,14 +13,37 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize model - feature_dim = 14 # Example feature dimension, adjust as needed - model = GATSiameseNetwork(in_channels=feature_dim, hidden_channels=128, out_channels=32).to(device) # Ensure the model matches your architecture + # infer feature dimension dynamically + sample_loader = 
load_inference_datasets('./dataset/val', batch_size=1) + sample_data = next(iter(sample_loader)) + feature_dim = sample_data.x.shape[1] + + model = GATSiameseNetwork( + in_channels=feature_dim, + hidden_channels=128, + out_channels=32 + ).to(device) # Load pre-trained model weights - model.load_state_dict(torch.load('./checkpoints/GAT/best_model.pth', map_location=device)) + + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", default="./dataset/val") + parser.add_argument("--checkpoint", default="./checkpoints/GAT/best_model.pth") + args = parser.parse_args() + checkpoint_path = args.checkpoint + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Checkpoint not found: {checkpoint_path}. " + "Please train the model first." + ) + + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + model.eval() # Load validation dataset - val_loader = load_inference_datasets('./dataset/val', batch_size=1) + + val_loader = load_inference_datasets(args.data_path, batch_size=1) # Extract and visualize hidden art features extract_hidden_art(model, val_loader, device, save_dir='./img', mode='diff', alpha=0.5) diff --git a/ArtExtract_Mingchun/model/embedding_model.py b/ArtExtract_Mingchun/model/embedding_model.py index ad67e5ef..9012b08a 100644 --- a/ArtExtract_Mingchun/model/embedding_model.py +++ b/ArtExtract_Mingchun/model/embedding_model.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from torch_geometric.utils import dropout_edge from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool +from torch_geometric.nn import GlobalAttention class GATBackbone(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.1): @@ -16,24 +18,44 @@ def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout= self.norm2 = nn.LayerNorm(out_channels * heads) self.norm3 = nn.LayerNorm(out_channels) + # Attention pooling + self.att_pool = GlobalAttention( + gate_nn=nn.Sequential( + nn.Linear(out_channels, 1) + ) + ) + def 
forward(self, data): x, edge_index, batch = data.x, data.edge_index, data.batch + # Edge dropout for regularization + edge_index, _ = dropout_edge(edge_index, p=0.1, training=self.training) - x = self.conv1(x, edge_index) - x = self.lrelu(self.norm1(x)) - x = self.dropout(x) + # ---- Layer 1 ---- + h = self.conv1(x, edge_index) + h = self.lrelu(self.norm1(h)) + h = self.dropout(h) + x = h - x = self.conv2(x, edge_index) - x = self.lrelu(self.norm2(x)) - x = self.dropout(x) + # ---- Layer 2 ---- + h = self.conv2(x, edge_index) + h = self.lrelu(self.norm2(h)) + h = self.dropout(h) + x = x + h if x.shape == h.shape else h # residual connection - x = self.conv3(x, edge_index) - x = self.lrelu(self.norm3(x)) - x = self.dropout(x) + # ---- Layer 3 ---- + h = self.conv3(x, edge_index) + h = self.lrelu(self.norm3(h)) + h = self.dropout(h) + x = x + h if x.shape == h.shape else h g_mean = global_mean_pool(x, batch) - g_max = global_max_pool(x, batch) - g = torch.cat([g_mean, g_max], dim=1) + g_max = global_max_pool(x, batch) + + # attention-based pooling + g_att = self.att_pool(x, batch) + + # combine multiple graph representations + g = torch.cat([g_mean, g_max, g_att], dim=1) return g, x, edge_index @@ -41,7 +63,7 @@ class GATSiameseNetworkEncoder(nn.Module): def __init__(self, in_channels, hidden_channels, out_channels, heads=1): super(GATSiameseNetworkEncoder, self).__init__() self.gat = GATBackbone(in_channels, hidden_channels, out_channels, heads) - in_proj = 2 * out_channels + in_proj = 3 * out_channels self.proj = nn.Sequential( nn.Linear(in_proj, hidden_channels), nn.ReLU(), diff --git a/ArtExtract_Mingchun/train.py b/ArtExtract_Mingchun/train.py index 530756e6..6248d9fb 100644 --- a/ArtExtract_Mingchun/train.py +++ b/ArtExtract_Mingchun/train.py @@ -115,7 +115,6 @@ def main(): val_losses = [] patience = 10 # Early stopping patience counter patience_counter = 0 - early_stopping = False # Train the model for epoch in range(50): # Example: train for 10 epochs diff 
--git a/ArtExtract_Mingchun/utils/build_graph.py b/ArtExtract_Mingchun/utils/build_graph.py index 5622bbfd..8b16334b 100644 --- a/ArtExtract_Mingchun/utils/build_graph.py +++ b/ArtExtract_Mingchun/utils/build_graph.py @@ -1,5 +1,6 @@ import numpy as np import torch +import matplotlib.pyplot as plt from scipy import ndimage from skimage import graph, segmentation, filters @@ -104,6 +105,10 @@ def extract_node(image, segments, target_feature_dim=None): min_val = np.min(region_pixels, axis=0) max_val = np.max(region_pixels, axis=0) + # compute texture feature using gradient magnitude + grad = np.gradient(region_pixels.astype(float), axis=0) + texture_val = np.mean(np.abs(grad), axis=0) + # Handle NaN and Inf values mean_val = np.nan_to_num(mean_val, nan=0.0, posinf=0.0, neginf=0.0) std_val = np.nan_to_num(std_val, nan=0.0, posinf=0.0, neginf=0.0) @@ -116,7 +121,14 @@ def extract_node(image, segments, target_feature_dim=None): center_yx = np.nan_to_num(center_yx, nan=0.0, posinf=0.0, neginf=0.0) # Construct the feature vector - feature_vec = np.concatenate([mean_val, std_val, min_val, max_val, center_yx]) # Multichannel image (H, W, C) + feature_vec = np.concatenate([ + mean_val, + std_val, + min_val, + max_val, + texture_val, + center_yx + ]) # Ensure the feature vector has the correct length features[i, :feature_vec.shape[0]] = feature_vec[:feature_dim] @@ -269,4 +281,22 @@ def image_to_graph_rgb(image, n_segments=5000, compactness=1, normalize_features channel_axis = -1 segments = segmentation.slic(image_slic, n_segments=n_segments, compactness=compactness, channel_axis=channel_axis) - return image_to_graph_infer(image, segments, normalize_features, target_feature_dim), segments \ No newline at end of file + return image_to_graph_infer(image, segments, normalize_features, target_feature_dim), segments + +def visualize_segments(image, segments): + """ + Visualize the SLIC superpixel segmentation used for graph construction. 
+ This utility overlays the segmentation boundaries on the original image + to help inspect how the painting is partitioned into graph nodes. + """ + + plt.figure(figsize=(6, 6)) + plt.imshow(image) + + # overlay superpixel boundaries + plt.contour(segments, colors="red", linewidths=0.5) + + plt.title("Superpixel Graph Segmentation") + plt.axis("off") + + plt.show() \ No newline at end of file diff --git a/ArtExtract_Mingchun/utils/data_graph.py b/ArtExtract_Mingchun/utils/data_graph.py index 43e6ece7..5f428933 100644 --- a/ArtExtract_Mingchun/utils/data_graph.py +++ b/ArtExtract_Mingchun/utils/data_graph.py @@ -55,7 +55,11 @@ def __getitem__(self, idx): # --------RGB image-------- img_name = self.images[idx] img_path = os.path.join(self.images_dir, img_name) - image = Image.open(img_path).convert('RGB') + image = Image.open(img_path) + + # convert grayscale images to RGB + if image.mode != "RGB": + image = image.convert("RGB") if self.transform_img: image = self.transform_img(image) # (C, H, W) if isinstance(image, torch.Tensor): @@ -68,7 +72,7 @@ def __getitem__(self, idx): # --------Masks-------- mask_names = self.masks[img_name] mask_datas = [] - for mask_name in tqdm(mask_names): + for mask_name in mask_names: mask_path = os.path.join(self.masks_dir, mask_name) mask = Image.open(mask_path) if mask.mode == 'I;16': @@ -195,7 +199,10 @@ def __len__(self): def __getitem__(self, idx): img_name = self.images[idx] img_path = os.path.join(self.images_dir, img_name) - image = Image.open(img_path).convert('RGB') + try: + image = Image.open(img_path).convert('RGB') + except Exception as e: + raise RuntimeError(f"Failed to load image: {img_path}") from e if self.transform_img: image = self.transform_img(image) if isinstance(image, torch.Tensor): @@ -257,4 +264,5 @@ def load_inference_datasets(val_path, batch_size): transform_mask=val_mask_transform) val_loader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=inference_collate_fn) - return val_loader \ No 
newline at end of file + return val_loader + diff --git a/ArtExtract_Mingchun/utils/visulization.py b/ArtExtract_Mingchun/utils/visualization.py similarity index 90% rename from ArtExtract_Mingchun/utils/visulization.py rename to ArtExtract_Mingchun/utils/visualization.py index dd0285da..c2e25a15 100644 --- a/ArtExtract_Mingchun/utils/visulization.py +++ b/ArtExtract_Mingchun/utils/visualization.py @@ -21,8 +21,12 @@ def overlay_node(image, segments, node_importance, alpha=0.5, cmap='jet'): for node_idx, importance in enumerate(node_importance): heatmap[segments == node_idx] = importance - # Normalize the heatmap to [0, 1] range - heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8) + # Improve numerical stability in heatmap normalization + denom = heatmap.max() - heatmap.min() + if denom < 1e-8: + heatmap_norm = np.zeros_like(heatmap) + else: + heatmap_norm = (heatmap - heatmap.min()) / denom # Transform heatmap to RGB using the specified colormap cmap_func = plt.get_cmap(cmap) @@ -39,6 +43,20 @@ def overlay_node(image, segments, node_importance, alpha=0.5, cmap='jet'): return overlay +def save_overlay(image, heatmap, save_path): + """ + Save an overlay visualization of the heatmap on the image. + Useful for inspecting hidden structures detected by the model. + """ + + plt.figure(figsize=(6, 6)) + plt.imshow(image) + plt.imshow(heatmap, cmap="jet", alpha=0.5) + plt.axis("off") + + plt.savefig(save_path) + plt.close() + def extract_hidden_art(model, data_loader, device, save_dir=None, mode='diff', alpha=0.5): """Extract hidden art features from the model and visualize them. Args: