diff --git a/ArtExtract_Mingchun/embedding.py b/ArtExtract_Mingchun/embedding.py
index eccb6884..93da0727 100644
--- a/ArtExtract_Mingchun/embedding.py
+++ b/ArtExtract_Mingchun/embedding.py
@@ -103,7 +103,7 @@ def build_dataset(images_dir, transform_img=None,
     return ds, in_channels
 
 def extract_embeddings(encoder, dataset, batch_size=64, device='cuda'):
-    loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+    loader = GeoDataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=min(4, os.cpu_count() or 1), pin_memory=True)
     encoder.eval().to(device)
     all_embs = []
     all_filenames = []
diff --git a/ArtExtract_Mingchun/inference.py b/ArtExtract_Mingchun/inference.py
index 8f4243e9..2253cf9b 100644
--- a/ArtExtract_Mingchun/inference.py
+++ b/ArtExtract_Mingchun/inference.py
@@ -1,8 +1,9 @@
+import argparse
 import torch
 import warnings
 warnings.filterwarnings("ignore")
 
-from utils.visulization import extract_hidden_art
+from utils.visualization import extract_hidden_art
 from utils.data_graph import load_inference_datasets
 from model.extract_model import GATSiameseNetwork
 
@@ -11,14 +12,24 @@
 def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--data_path", default="./dataset/val")
+    args = parser.parse_args()
+
     # Initialize model
-    feature_dim = 14 # Example feature dimension, adjust as needed
-    model = GATSiameseNetwork(in_channels=feature_dim, hidden_channels=128, out_channels=32).to(device) # Ensure the model matches your architecture
+    # Infer the feature dimension from the user-supplied dataset and reuse the same loader below.
+    val_loader = load_inference_datasets(args.data_path, batch_size=1)
+    sample_data = next(iter(val_loader))
+    feature_dim = sample_data.x.shape[1]
+
+    model = GATSiameseNetwork(
+        in_channels=feature_dim,
+        hidden_channels=128,
+        out_channels=32
+    ).to(device)
 
     # Load pre-trained model weights
     model.load_state_dict(torch.load('./checkpoints/GAT/best_model.pth', map_location=device))
 
-    # Load validation dataset
-    val_loader = load_inference_datasets('./dataset/val', batch_size=1)
     # Extract and visualize hidden art features
     extract_hidden_art(model, val_loader, device, save_dir='./img', mode='diff', alpha=0.5)
diff --git a/ArtExtract_Mingchun/train.py b/ArtExtract_Mingchun/train.py
index 530756e6..6248d9fb 100644
--- a/ArtExtract_Mingchun/train.py
+++ b/ArtExtract_Mingchun/train.py
@@ -115,7 +115,6 @@ def main():
     val_losses = []
     patience = 10 # Early stopping patience counter
     patience_counter = 0
-    early_stopping = False
 
     # Train the model
     for epoch in range(50): # Example: train for 10 epochs
diff --git a/ArtExtract_Mingchun/utils/data_graph.py b/ArtExtract_Mingchun/utils/data_graph.py
index 43e6ece7..311bf8c7 100644
--- a/ArtExtract_Mingchun/utils/data_graph.py
+++ b/ArtExtract_Mingchun/utils/data_graph.py
@@ -68,7 +68,7 @@ def __getitem__(self, idx):
         # --------Masks--------
         mask_names = self.masks[img_name]
         mask_datas = []
-        for mask_name in tqdm(mask_names):
+        for mask_name in mask_names:
             mask_path = os.path.join(self.masks_dir, mask_name)
             mask = Image.open(mask_path)
             if mask.mode == 'I;16':
@@ -195,7 +195,10 @@ def __len__(self):
     def __getitem__(self, idx):
         img_name = self.images[idx]
         img_path = os.path.join(self.images_dir, img_name)
-        image = Image.open(img_path).convert('RGB')
+        try:
+            image = Image.open(img_path).convert('RGB')
+        except Exception as e:
+            raise RuntimeError(f"Failed to load image: {img_path}") from e
         if self.transform_img:
             image = self.transform_img(image)
         if isinstance(image, torch.Tensor):
@@ -257,4 +260,4 @@
                                          transform_mask=val_mask_transform)
     val_loader = DataLoader(val_dataset, batch_size, shuffle=False,
                             collate_fn=inference_collate_fn)
-    return val_loader
\ No newline at end of file
+    return val_loader
diff --git a/ArtExtract_Mingchun/utils/visulization.py b/ArtExtract_Mingchun/utils/visualization.py
rename from ArtExtract_Mingchun/utils/visulization.py
rename to ArtExtract_Mingchun/utils/visualization.py
index dd0285da..257143ab 100644
--- a/ArtExtract_Mingchun/utils/visulization.py
+++ b/ArtExtract_Mingchun/utils/visualization.py
@@ -21,8 +21,12 @@ def overlay_node(image, segments, node_importance, alpha=0.5, cmap='jet'):
     for node_idx, importance in enumerate(node_importance):
         heatmap[segments == node_idx] = importance
 
-    # Normalize the heatmap to [0, 1] range
-    heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
+    # Guard against a flat heatmap: map it to all-zeros instead of dividing by ~0
+    denom = heatmap.max() - heatmap.min()
+    if denom < 1e-8:
+        heatmap_norm = np.zeros_like(heatmap)
+    else:
+        heatmap_norm = (heatmap - heatmap.min()) / denom
 
     # Transform heatmap to RGB using the specified colormap
     cmap_func = plt.get_cmap(cmap)