class LOLDataset(Dataset):
    """Paired low-light / normal-light images discovered under a root folder.

    The constructor searches ``root`` recursively for directories named
    ``low``, picks the one belonging to the requested split, and pairs it
    with the sibling ``high`` directory. Files are matched by name, so both
    directories are expected to contain identically named images.

    Args:
        mode: ``'train'`` selects a ``low`` folder whose path (relative to
            ``root``) contains ``our485`` or ``train``; any other value
            selects one containing ``eval15`` or ``test``. Random flips are
            only applied when ``mode == 'train'``.
        root: Directory to search. Defaults to the original hard-coded
            ``/kaggle/input`` location.

    Raises:
        IndexError: If no matching ``low`` folder exists under ``root``
            (same exception type the previous ``[...][0]`` lookup raised,
            now with an explanatory message).
    """

    _SPLIT_KEYWORDS = {
        'train': ('our485', 'train'),
        'eval': ('eval15', 'test'),
    }

    def __init__(self, mode='train', root='/kaggle/input'):
        # Sort so the selected folder does not depend on filesystem order.
        all_low_folders = sorted(
            glob.glob(os.path.join(root, '**', 'low'), recursive=True)
        )
        keywords = self._SPLIT_KEYWORDS['train' if mode == 'train' else 'eval']
        # Match keywords against the path *below* root so that a root such as
        # "/data/test_runs" cannot accidentally satisfy the split filter.
        candidates = [
            folder for folder in all_low_folders
            if any(key in os.path.relpath(folder, root) for key in keywords)
        ]
        if not candidates:
            raise IndexError(
                f"No 'low' folder matching {keywords} found under {root!r} "
                f"for mode={mode!r}"
            )
        self.low_dir = candidates[0]
        # Swap only the final path component. A plain str.replace('low', 'high')
        # would also rewrite 'low' inside parent folder names
        # (e.g. ".../low-light-set/..." -> ".../high-light-set/...").
        self.high_dir = os.path.join(os.path.dirname(self.low_dir), 'high')
        self.names = sorted(os.listdir(self.low_dir))
        self.mode = mode

    def __len__(self):
        """Return the number of image names found in the ``low`` folder."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load the ``idx``-th (low, high) pair as tensors.

        Both images are converted to RGB and resized to
        ``Config.IMG_SIZE`` x ``Config.IMG_SIZE``. In training mode the same
        random horizontal and vertical flips are applied to both images so
        the pair stays aligned.

        Returns:
            A ``(low, high)`` tuple of tensors, each computed as
            ``(to_tensor(img) - 0.5) * 2``.
        """
        name = self.names[idx]
        low = Image.open(os.path.join(self.low_dir, name)).convert('RGB')
        high = Image.open(os.path.join(self.high_dir, name)).convert('RGB')

        target_size = (Config.IMG_SIZE, Config.IMG_SIZE)
        low = TF.resize(low, target_size)
        high = TF.resize(high, target_size)

        if self.mode == 'train':
            # One random draw per flip direction, shared by both images.
            if random.random() > 0.5:
                low = TF.hflip(low)
                high = TF.hflip(high)
            if random.random() > 0.5:
                low = TF.vflip(low)
                high = TF.vflip(high)

        return (TF.to_tensor(low) - 0.5) * 2, (TF.to_tensor(high) - 0.5) * 2
def _select_device():
    """Pick CUDA if available, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps does not exist on older torch builds; probe defensively.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


def _release_cached_memory(device):
    """Ask the accelerator backend in use (if any) to drop cached memory."""
    if device.type == 'mps':
        torch.mps.empty_cache()
    elif device.type == 'cuda':
        torch.cuda.empty_cache()


def _pad_to_multiple(x, multiple):
    """Pad the last two dims of ``x`` (right/bottom) up to a multiple of ``multiple``."""
    height, width = x.shape[-2:]
    pad_w = (-width) % multiple
    pad_h = (-height) % multiple
    if pad_w == 0 and pad_h == 0:
        return x
    # Reflection padding requires the pad to be smaller than the dimension
    # being padded; fall back to edge replication for very small images.
    mode = 'reflect' if pad_w < width and pad_h < height else 'replicate'
    # F.pad takes pairs from the last dim backwards: (W_left, W_right, H_top, H_bottom).
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode=mode)


def run_inference(test_dir='test', output_dir='results',
                  checkpoint_path='checkpoints/latest.pth'):
    """Enhance every image in ``test_dir`` and write results to ``output_dir``.

    Loads ``PureDiffusionUNet`` weights from ``checkpoint_path``, runs
    ``DiffusionEngine.sample`` on each image at its native resolution (padded
    to a multiple of 32 and cropped back afterwards), and saves the result as
    ``enhanced_<original name>``. Per-image failures are reported and skipped
    so one bad file does not abort the batch.

    Args:
        test_dir: Folder containing the input images.
        output_dir: Folder for the enhanced images; created if missing.
        checkpoint_path: Checkpoint file. Either a raw state dict or a dict
            holding the state dict under the ``'model'`` key.

    Returns:
        None. Returns early (after printing a message) when the checkpoint
        is missing or no supported images are found.
    """
    os.makedirs(output_dir, exist_ok=True)

    device = _select_device()
    print(f"Using device: {device}")

    # Override the Config class attribute so DiffusionEngine picks it up.
    Config.DEVICE = device
    # The instance itself is unused; the call is kept in case constructing
    # Config has side effects elsewhere in the project.
    Config()

    print("Initializing model...")
    model = PureDiffusionUNet().to(device)

    if not os.path.exists(checkpoint_path):
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        return

    print(f"Loading checkpoint from {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device)
    if 'model' in state_dict:
        state_dict = state_dict['model']
    model.load_state_dict(state_dict)
    # Inference only: put layers such as dropout/batch-norm (if the model has
    # any) into evaluation behaviour.
    model.eval()

    engine = DiffusionEngine()
    engine.device = device  # Ensure the engine uses the selected device.

    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    # Sorted so the processing order is reproducible between runs.
    image_paths = sorted(
        path for path in glob.glob(os.path.join(test_dir, '*.*'))
        if os.path.splitext(path)[1].lower() in valid_extensions
    )

    if not image_paths:
        print("No images found in test folder.")
        return

    print(f"Found {len(image_paths)} images to process.")

    # Per the original author's note: the model has 3 downsampling layers
    # (8x reduction), but padding to a multiple of 32 is the safer choice.
    multiple = 32

    for img_path in tqdm(image_paths):
        img_name = os.path.basename(img_path)
        print(f"Processing {img_name}...")

        # Tracked separately so the OOM message never touches an unbound name.
        resolution = None
        try:
            original_img = Image.open(img_path).convert('RGB')
            resolution = original_img.size
            w, h = resolution

            # No gradients are needed for sampling; disabling autograd keeps
            # memory from accumulating across the sampling steps.
            with torch.no_grad():
                x_orig = TF.to_tensor(original_img).unsqueeze(0).to(device)
                # The image is padded rather than resized, then mapped from
                # [0, 1] to [-1, 1].
                low_light = (_pad_to_multiple(x_orig, multiple) - 0.5) * 2

                # NOTE(review): per the original comment, sample() starts from
                # randn_like(low_light) -- confirm against DiffusionEngine.
                output = engine.sample(model, low_light)

                # Map back to [0, 1] and crop away the padding.
                output = torch.clamp((output + 1) * 0.5, 0, 1)
                output = output[:, :, :h, :w]

            output_pil = TF.to_pil_image(output.squeeze(0).cpu())
            output_pil.save(os.path.join(output_dir, f"enhanced_{img_name}"))

            # Release resources and pause briefly to keep the system responsive.
            _release_cached_memory(device)
            time.sleep(1)

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                print(f"OOM Processing {img_name} at resolution {resolution}. Skipping...")
                # Free the cache on whichever backend is actually in use.
                _release_cached_memory(device)
            else:
                print(f"Error processing {img_name}: {e}")
        except Exception as e:
            # Deliberate best-effort: report the failure and move on.
            print(f"Error processing {img_name}: {e}")

    print("Inference completed.")


if __name__ == "__main__":
    run_inference()