65 changes: 61 additions & 4 deletions ArtExtract_Mingchun/embedding.py
@@ -1,6 +1,8 @@
import os, glob
import numpy as np
import torch
import json
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore', category=RuntimeWarning)

@@ -26,9 +28,19 @@ def __init__(self, images_dir, transform_img=None,
files.extend(glob.glob(os.path.join(images_dir, p), recursive=True))
self.images = sorted(files)

if len(self.images) == 0:
if len(self.images)==0:
raise ValueError(f"No images found under {images_dir} (extensions={ALLOWED_EXTS})")

# --- Dataset integrity check ---
valid_images = []
for path in self.images:
try:
Image.open(path).verify()
valid_images.append(path)
except Exception:
print(f"Skipping corrupted image: {path}")

self.images = valid_images
sample = Image.open(self.images[0]).convert('RGB')
if self.transform_img:
sample = self.transform_img(sample)
@@ -103,7 +115,7 @@ def build_dataset(images_dir, transform_img=None,
return ds, in_channels

def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'):
loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers = min(4, os.cpu_count()), pin_memory=True)
encoder.eval().to(device)
all_embs = []
all_filenames = []
@@ -123,8 +135,53 @@ def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'):
ids = np.array(all_filenames)

np.save('./embedding/embeddings.npy', X)
np.savetxt('./embedding/ids.csv', ids, fmt='%s', delimiter=',')
return X, ids

metadata = {
"num_embeddings": X.shape[0],
"embedding_dim": X.shape[1]
}

with open("./embedding/meta.json","w") as f:
json.dump(metadata,f)

def find_similar_embeddings(query_embedding, embeddings, top_k=5):
"""
Find the most similar embeddings using cosine similarity.

Args:
query_embedding (np.ndarray): Embedding vector for the query.
embeddings (np.ndarray): Matrix of stored embeddings.
top_k (int): Number of similar embeddings to retrieve.

Returns:
indices (np.ndarray): Indices of the most similar embeddings.
scores (np.ndarray): Similarity scores.
"""

sims = cosine_similarity(query_embedding.reshape(1, -1), embeddings)[0]
indices = sims.argsort()[::-1][:top_k]

return indices, sims[indices]

def search_similar_images(query_image_path, embeddings, ids, encoder, device="cuda", top_k=5):
"""
Compute embedding for a query image and retrieve similar images.
"""

img = Image.open(query_image_path).convert("RGB")
img_np = np.array(img)

graph_data, segments = image_to_graph_rgb(img_np)
graph_data = graph_data.to(device)

encoder.eval().to(device)

with torch.no_grad():
query_emb = encoder(graph_data).cpu().numpy()

indices, scores = find_similar_embeddings(query_emb, embeddings, top_k)

return ids[indices], scores


def main():
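The two retrieval helpers added above are easiest to read next to a usage example. The sketch below is illustrative only: it assumes the snippet runs from ArtExtract_Mingchun/, that extract_embeddings has already written ./embedding/embeddings.npy and ./embedding/ids.csv as in the code above, and that a trained graph encoder is available for the image-query path, which is therefore left commented out.

```python
# Hedged usage sketch, not part of this PR. Paths and the query index are illustrative.
import numpy as np
from embedding import find_similar_embeddings, search_similar_images

X = np.load("./embedding/embeddings.npy")                           # (N, D) stored embeddings
ids = np.loadtxt("./embedding/ids.csv", dtype=str, delimiter=",")   # matching filenames

# Query with an embedding that is already in the index; the top hit is the image itself.
indices, scores = find_similar_embeddings(X[0], X, top_k=5)
for name, score in zip(ids[indices], scores):
    print(f"{name}: cosine similarity {score:.3f}")

# Query with a new image on disk instead (requires the trained graph encoder):
# names, scores = search_similar_images("./dataset/val/query.jpg", X, ids, encoder, top_k=5)
```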
35 changes: 30 additions & 5 deletions ArtExtract_Mingchun/inference.py
@@ -1,8 +1,10 @@
import os
import torch
import warnings
warnings.filterwarnings("ignore")
import argparse

from utils.visulization import extract_hidden_art
from utils.visualization import extract_hidden_art
from utils.data_graph import load_inference_datasets
from model.extract_model import GATSiameseNetwork

@@ -11,14 +13,37 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
feature_dim = 14 # Example feature dimension, adjust as needed
model = GATSiameseNetwork(in_channels=feature_dim, hidden_channels=128, out_channels=32).to(device) # Ensure the model matches your architecture
# infer feature dimension dynamically
sample_loader = load_inference_datasets('./dataset/val', batch_size=1)
sample_data = next(iter(sample_loader))
feature_dim = sample_data.x.shape[1]

model = GATSiameseNetwork(
in_channels=feature_dim,
hidden_channels=128,
out_channels=32
).to(device)

# Load pre-trained model weights
model.load_state_dict(torch.load('./checkpoints/GAT/best_model.pth', map_location=device))
# Parse CLI arguments before any use of args
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", default="./dataset/val")
parser.add_argument("--checkpoint", default="./checkpoints/GAT/best_model.pth")
args = parser.parse_args()

checkpoint_path = args.checkpoint

if not os.path.exists(checkpoint_path):
raise FileNotFoundError(
f"Checkpoint not found: {checkpoint_path}. "
"Please train the model first."
)

model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Load validation dataset
val_loader = load_inference_datasets('./dataset/val', batch_size=1)

val_loader = load_inference_datasets(args.data_path, batch_size=1)

# Extract and visualize hidden art features
extract_hidden_art(model, val_loader, device, save_dir='./img', mode='diff', alpha=0.5)
46 changes: 34 additions & 12 deletions ArtExtract_Mingchun/model/embedding_model.py
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
from torch_geometric.utils import dropout_edge
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool
from torch_geometric.nn import GlobalAttention

class GATBackbone(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=0.1):
@@ -16,32 +18,52 @@ def __init__(self, in_channels, hidden_channels, out_channels, heads=4, dropout=
self.norm2 = nn.LayerNorm(out_channels * heads)
self.norm3 = nn.LayerNorm(out_channels)

# Attention pooling
self.att_pool = GlobalAttention(
gate_nn=nn.Sequential(
nn.Linear(out_channels, 1)
)
)

def forward(self, data):
x, edge_index, batch = data.x, data.edge_index, data.batch
# Edge dropout for regularization
edge_index, _ = dropout_edge(edge_index, p=0.1, training=self.training)

x = self.conv1(x, edge_index)
x = self.lrelu(self.norm1(x))
x = self.dropout(x)
# ---- Layer 1 ----
h = self.conv1(x, edge_index)
h = self.lrelu(self.norm1(h))
h = self.dropout(h)
x = h

x = self.conv2(x, edge_index)
x = self.lrelu(self.norm2(x))
x = self.dropout(x)
# ---- Layer 2 ----
h = self.conv2(x, edge_index)
h = self.lrelu(self.norm2(h))
h = self.dropout(h)
x = x + h if x.shape == h.shape else h # residual connection

x = self.conv3(x, edge_index)
x = self.lrelu(self.norm3(x))
x = self.dropout(x)
# ---- Layer 3 ----
h = self.conv3(x, edge_index)
h = self.lrelu(self.norm3(h))
h = self.dropout(h)
x = x + h if x.shape == h.shape else h

g_mean = global_mean_pool(x, batch)
g_max = global_max_pool(x, batch)
g = torch.cat([g_mean, g_max], dim=1)
g_max = global_max_pool(x, batch)

# attention-based pooling
g_att = self.att_pool(x, batch)

# combine multiple graph representations
g = torch.cat([g_mean, g_max, g_att], dim=1)

return g, x, edge_index

class GATSiameseNetworkEncoder(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=1):
super(GATSiameseNetworkEncoder, self).__init__()
self.gat = GATBackbone(in_channels, hidden_channels, out_channels, heads)
in_proj = 2 * out_channels
in_proj = 3 * out_channels
self.proj = nn.Sequential(
nn.Linear(in_proj, hidden_channels),
nn.ReLU(),
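Because the readout now concatenates mean, max, and attention pooling, the projection head's input width grows from 2 * out_channels to 3 * out_channels. A minimal shape check with made-up sizes, not the model's real graphs, illustrates why; recent PyTorch Geometric releases expose the same operator as AttentionalAggregation.

```python
# Minimal shape check (illustrative sizes, not the repository's real graphs).
import torch
from torch_geometric.nn import GlobalAttention, global_max_pool, global_mean_pool

out_channels = 32
x = torch.randn(10, out_channels)           # 10 node embeddings after conv3
batch = torch.zeros(10, dtype=torch.long)   # all 10 nodes belong to one graph

att_pool = GlobalAttention(gate_nn=torch.nn.Linear(out_channels, 1))

g = torch.cat([
    global_mean_pool(x, batch),   # (1, out_channels)
    global_max_pool(x, batch),    # (1, out_channels)
    att_pool(x, batch),           # (1, out_channels)
], dim=1)

print(g.shape)  # torch.Size([1, 96]), i.e. 3 * out_channels, hence in_proj = 3 * out_channels
```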
1 change: 0 additions & 1 deletion ArtExtract_Mingchun/train.py
@@ -115,7 +115,6 @@ def main():
val_losses = []
patience = 10 # Early stopping patience counter
patience_counter = 0
early_stopping = False

# Train the model
for epoch in range(50): # Example: train for 50 epochs
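The deleted early_stopping flag was never read; the patience counter alone decides when training stops. A self-contained sketch of that pattern follows, with fabricated loss values and a commented-out checkpoint save; the real loop in train.py may differ in details.

```python
# Patience-based early stopping in isolation (fake validation losses).
fake_val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66,
                   0.67, 0.68, 0.69, 0.70, 0.71, 0.72, 0.73]

best_val_loss = float("inf")
patience = 10          # stop after 10 epochs without improvement
patience_counter = 0

for epoch, val_loss in enumerate(fake_val_losses):
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0    # improvement: reset the counter
        # torch.save(model.state_dict(), "./checkpoints/GAT/best_model.pth")
    else:
        patience_counter += 1   # no improvement this epoch
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break
```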
34 changes: 32 additions & 2 deletions ArtExtract_Mingchun/utils/build_graph.py
@@ -1,5 +1,6 @@
import numpy as np
import torch
import matplotlib.pyplot as plt

from scipy import ndimage
from skimage import graph, segmentation, filters
@@ -104,6 +105,10 @@ def extract_node(image, segments, target_feature_dim=None):
min_val = np.min(region_pixels, axis=0)
max_val = np.max(region_pixels, axis=0)

# compute texture feature using gradient magnitude
grad = np.gradient(region_pixels.astype(float), axis=0)
texture_val = np.mean(np.abs(grad), axis=0)

# Handle NaN and Inf values
mean_val = np.nan_to_num(mean_val, nan=0.0, posinf=0.0, neginf=0.0)
std_val = np.nan_to_num(std_val, nan=0.0, posinf=0.0, neginf=0.0)
@@ -116,7 +121,14 @@
center_yx = np.nan_to_num(center_yx, nan=0.0, posinf=0.0, neginf=0.0)

# Construct the feature vector
feature_vec = np.concatenate([mean_val, std_val, min_val, max_val, center_yx]) # Multichannel image (H, W, C)
feature_vec = np.concatenate([
mean_val,
std_val,
min_val,
max_val,
texture_val,
center_yx
])
# Ensure the feature vector has the correct length
features[i, :feature_vec.shape[0]] = feature_vec[:feature_dim]

@@ -269,4 +281,22 @@ def image_to_graph_rgb(image, n_segments=5000, compactness=1, normalize_features
channel_axis = -1
segments = segmentation.slic(image_slic, n_segments=n_segments, compactness=compactness, channel_axis=channel_axis)

return image_to_graph_infer(image, segments, normalize_features, target_feature_dim), segments
return image_to_graph_infer(image, segments, normalize_features, target_feature_dim), segments

def visualize_segments(image, segments):
"""
Visualize the SLIC superpixel segmentation used for graph construction.
This utility overlays the segmentation boundaries on the original image
to help inspect how the painting is partitioned into graph nodes.
"""

plt.figure(figsize=(6, 6))
plt.imshow(image)

# overlay superpixel boundaries
plt.contour(segments, colors="red", linewidths=0.5)

plt.title("Superpixel Graph Segmentation")
plt.axis("off")

plt.show()
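A short usage sketch for the graph-construction path together with the new helper. The image path is a placeholder, the parameters are just the defaults above, and the snippet assumes it is run from ArtExtract_Mingchun/ so that utils.build_graph is importable.

```python
# Hedged usage sketch, not part of this PR.
import numpy as np
from PIL import Image
from utils.build_graph import image_to_graph_rgb, visualize_segments

img = np.array(Image.open("./dataset/val/example.jpg").convert("RGB"))  # placeholder path

# Build the superpixel graph used by the encoder, then inspect the segmentation.
graph_data, segments = image_to_graph_rgb(img, n_segments=5000, compactness=1)
print(graph_data.x.shape)   # (num_superpixels, feature_dim) node-feature matrix

visualize_segments(img, segments)
```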
16 changes: 12 additions & 4 deletions ArtExtract_Mingchun/utils/data_graph.py
@@ -55,7 +55,11 @@ def __getitem__(self, idx):
# --------RGB image--------
img_name = self.images[idx]
img_path = os.path.join(self.images_dir, img_name)
image = Image.open(img_path).convert('RGB')
image = Image.open(img_path)

# convert grayscale images to RGB
if image.mode != "RGB":
image = image.convert("RGB")
if self.transform_img:
image = self.transform_img(image) # (C, H, W)
if isinstance(image, torch.Tensor):
@@ -68,7 +72,7 @@
# --------Masks--------
mask_names = self.masks[img_name]
mask_datas = []
for mask_name in tqdm(mask_names):
for mask_name in mask_names:
mask_path = os.path.join(self.masks_dir, mask_name)
mask = Image.open(mask_path)
if mask.mode == 'I;16':
@@ -195,7 +199,10 @@ def __getitem__(self, idx):
def __getitem__(self, idx):
img_name = self.images[idx]
img_path = os.path.join(self.images_dir, img_name)
image = Image.open(img_path).convert('RGB')
try:
image = Image.open(img_path).convert('RGB')
except Exception as e:
raise RuntimeError(f"Failed to load image: {img_path}") from e
if self.transform_img:
image = self.transform_img(image)
if isinstance(image, torch.Tensor):
@@ -257,4 +264,5 @@ def load_inference_datasets(val_path, batch_size):
transform_mask=val_mask_transform)

val_loader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=inference_collate_fn)
return val_loader
return val_loader

22 changes: 20 additions & 2 deletions ArtExtract_Mingchun/utils/visulization.py
@@ -21,8 +21,12 @@ def overlay_node(image, segments, node_importance, alpha=0.5, cmap='jet'):
for node_idx, importance in enumerate(node_importance):
heatmap[segments == node_idx] = importance

# Normalize the heatmap to [0, 1] range
heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
# Improve numerical stability in heatmap normalization
denom = heatmap.max() - heatmap.min()
if denom < 1e-8:
heatmap_norm = np.zeros_like(heatmap)
else:
heatmap_norm = (heatmap - heatmap.min()) / denom

# Transform heatmap to RGB using the specified colormap
cmap_func = plt.get_cmap(cmap)
@@ -39,6 +43,20 @@

return overlay

def save_overlay(image, heatmap, save_path):
"""
Save an overlay visualization of the heatmap on the image.
Useful for inspecting hidden structures detected by the model.
"""

plt.figure(figsize=(6, 6))
plt.imshow(image)
plt.imshow(heatmap, cmap="jet", alpha=0.5)
plt.axis("off")

plt.savefig(save_path)
plt.close()

def extract_hidden_art(model, data_loader, device, save_dir=None, mode='diff', alpha=0.5):
"""Extract hidden art features from the model and visualize them.
Args:
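Finally, a hedged sketch of how the two helpers above compose: overlay_node blends per-node importance scores back onto the pixels each superpixel covers, while save_overlay writes a quick-look figure for a 2-D heatmap. All inputs below are synthetic, and the module path follows this file's current name.

```python
# Usage sketch with synthetic inputs (not the repository's real pipeline).
import numpy as np
from utils.visulization import overlay_node, save_overlay

image = np.random.rand(128, 128, 3)                    # stand-in RGB image in [0, 1]
segments = np.random.randint(0, 50, size=(128, 128))   # fake superpixel labels 0..49
node_importance = np.random.rand(50)                   # one importance score per node

# Blend per-node importance back onto the pixels each superpixel covers.
overlay = overlay_node(image, segments, node_importance, alpha=0.5, cmap="jet")

# save_overlay expects a 2-D heatmap; broadcasting node scores through the
# segment labels gives a per-pixel importance map.
heatmap = node_importance[segments]
save_overlay(image, heatmap, "overlay_example.png")    # filename is illustrative
```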