Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ArtExtract_Mingchun/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def build_dataset(images_dir, transform_img=None,
return ds, in_channels

def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'):
loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers = min(4, os.cpu_count()), pin_memory=True)
encoder.eval().to(device)
all_embs = []
all_filenames = []
Expand Down
23 changes: 19 additions & 4 deletions ArtExtract_Mingchun/inference.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import torch
import warnings
warnings.filterwarnings("ignore")

from utils.visulization import extract_hidden_art
from utils.visualization import extract_hidden_art
from utils.data_graph import load_inference_datasets
from model.extract_model import GATSiameseNetwork

Expand All @@ -11,14 +12,28 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
feature_dim = 14 # Example feature dimension, adjust as needed
model = GATSiameseNetwork(in_channels=feature_dim, hidden_channels=128, out_channels=32).to(device) # Ensure the model matches your architecture
# infer feature dimension dynamically
sample_loader = load_inference_datasets('./dataset/val', batch_size=1)
sample_data = next(iter(sample_loader))
feature_dim = sample_data.x.shape[1]

model = GATSiameseNetwork(
in_channels=feature_dim,
hidden_channels=128,
out_channels=32
).to(device)

# Load pre-trained model weights
model.load_state_dict(torch.load('./checkpoints/GAT/best_model.pth', map_location=device))

# Load validation dataset
val_loader = load_inference_datasets('./dataset/val', batch_size=1)
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data_path", default="./dataset/val")
args = parser.parse_args()

val_loader = load_inference_datasets(args.data_path, batch_size=1)

# Extract and visualize hidden art features
extract_hidden_art(model, val_loader, device, save_dir='./img', mode='diff', alpha=0.5)
Expand Down
1 change: 0 additions & 1 deletion ArtExtract_Mingchun/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def main():
val_losses = []
patience = 10 # Early stopping patience limit (max epochs without improvement)
patience_counter = 0
early_stopping = False

# Train the model
for epoch in range(50): # Train for up to 50 epochs (early stopping may end sooner)
Expand Down
10 changes: 7 additions & 3 deletions ArtExtract_Mingchun/utils/data_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __getitem__(self, idx):
# --------Masks--------
mask_names = self.masks[img_name]
mask_datas = []
for mask_name in tqdm(mask_names):
for mask_name in mask_names:
mask_path = os.path.join(self.masks_dir, mask_name)
mask = Image.open(mask_path)
if mask.mode == 'I;16':
Expand Down Expand Up @@ -195,7 +195,10 @@ def __len__(self):
def __getitem__(self, idx):
img_name = self.images[idx]
img_path = os.path.join(self.images_dir, img_name)
image = Image.open(img_path).convert('RGB')
try:
image = Image.open(img_path).convert('RGB')
except Exception as e:
raise RuntimeError(f"Failed to load image: {img_path}") from e
if self.transform_img:
image = self.transform_img(image)
if isinstance(image, torch.Tensor):
Expand Down Expand Up @@ -257,4 +260,5 @@ def load_inference_datasets(val_path, batch_size):
transform_mask=val_mask_transform)

val_loader = DataLoader(val_dataset, batch_size, shuffle=False, collate_fn=inference_collate_fn)
return val_loader
return val_loader

8 changes: 6 additions & 2 deletions ArtExtract_Mingchun/utils/visulization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ def overlay_node(image, segments, node_importance, alpha=0.5, cmap='jet'):
for node_idx, importance in enumerate(node_importance):
heatmap[segments == node_idx] = importance

# Normalize the heatmap to [0, 1] range
heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
# Normalize the heatmap to [0, 1], guarding against a zero value range
denom = heatmap.max() - heatmap.min()
if denom < 1e-8:
heatmap_norm = np.zeros_like(heatmap)
else:
heatmap_norm = (heatmap - heatmap.min()) / denom

# Transform heatmap to RGB using the specified colormap
cmap_func = plt.get_cmap(cmap)
Expand Down