Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
32 changes: 32 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

import os
import glob
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from config import Config

class LOLDataset(Dataset):
def __init__(self, mode='train'):
all_low_folders = glob.glob("/kaggle/input/**/low", recursive=True)
if mode == 'train':
self.low_dir = [x for x in all_low_folders if 'our485' in x or 'train' in x][0]
else:
self.low_dir = [x for x in all_low_folders if 'eval15' in x or 'test' in x][0]
self.high_dir = self.low_dir.replace('low', 'high')
self.names = sorted(os.listdir(self.low_dir))
self.mode = mode

def __len__(self): return len(self.names)
def __getitem__(self, idx):
name = self.names[idx]
low = Image.open(os.path.join(self.low_dir, name)).convert('RGB')
high = Image.open(os.path.join(self.high_dir, name)).convert('RGB')
low = TF.resize(low, (Config.IMG_SIZE, Config.IMG_SIZE))
high = TF.resize(high, (Config.IMG_SIZE, Config.IMG_SIZE))
if self.mode == 'train':
if random.random() > 0.5: low = TF.hflip(low); high = TF.hflip(high)
if random.random() > 0.5: low = TF.vflip(low); high = TF.vflip(high)
return (TF.to_tensor(low) - 0.5) * 2, (TF.to_tensor(high) - 0.5) * 2
137 changes: 137 additions & 0 deletions run_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@

import os
import sys
import time
import glob
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from tqdm import tqdm

# Add src to system path
sys.path.append(os.path.join(os.getcwd(), 'src'))

from config import Config
from model import PureDiffusionUNet
from diffusion import DiffusionEngine

def run_inference():
# Setup paths
test_dir = 'test'
output_dir = 'results'
checkpoint_path = 'checkpoints/latest.pth'

os.makedirs(output_dir, exist_ok=True)

# Device config
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Override Config class attribute so DiffusionEngine picks it up
Config.DEVICE = device

# Load Config (just for other params)
conf = Config()

# Initialize Model
print("Initializing model...")
model = PureDiffusionUNet().to(device)

# Load Weights
if os.path.exists(checkpoint_path):
print(f"Loading checkpoint from {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location=device)
if 'model' in state_dict:
state_dict = state_dict['model']
model.load_state_dict(state_dict)
else:
print(f"Error: Checkpoint not found at {checkpoint_path}")
return

# Initialize Engine
engine = DiffusionEngine()
engine.device = device # Ensure engine uses the correct device

# Process images
image_paths = glob.glob(os.path.join(test_dir, '*.*'))
valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
image_paths = [p for p in image_paths if os.path.splitext(p)[1].lower() in valid_extensions]

if not image_paths:
print("No images found in test folder.")
return

print(f"Found {len(image_paths)} images to process.")

for img_path in tqdm(image_paths):
img_name = os.path.basename(img_path)
print(f"Processing {img_name}...")

try:
# Load and preprocess
original_img = Image.open(img_path).convert('RGB')
w, h = original_img.size

# Pad to multiples of 32 to ensure compatibility with UNet architecture
# The model has 3 downsampling layers -> 8x reduction, but 32 is safer.
multiple = 32
new_w = ((w + multiple - 1) // multiple) * multiple
new_h = ((h + multiple - 1) // multiple) * multiple

# We do NOT resize the image here. We will pad it.


# Re-reload to be sure
x_orig = TF.to_tensor(original_img).unsqueeze(0).to(device)

# Calculate padding
pad_w = new_w - w
pad_h = new_h - h

# Pad: (left, right, top, bottom)
# F.pad expects (len(last_dim)/2 pairs from last to first dim)
# Tensor is (B, C, H, W). We pad W and H.
import torch.nn.functional as F
if pad_w > 0 or pad_h > 0:
# pad last dim (W) with (0, pad_w), 2nd to last (H) with (0, pad_h)
x_padded = F.pad(x_orig, (0, pad_w, 0, pad_h), mode='reflect')
else:
x_padded = x_orig

low_light = (x_padded - 0.5) * 2

# Inference
# Start from noise? Current implementation of sample() starts from randn_like(low_light)
output = engine.sample(model, low_light)

# Post-process: (x + 1) / 2
output = (output + 1) * 0.5
output = torch.clamp(output, 0, 1)

# Crop back to original size
output = output[:, :, :h, :w]

# Save
output_pil = TF.to_pil_image(output.squeeze(0).cpu())
output_pil.save(os.path.join(output_dir, f"enhanced_{img_name}"))

# Release resources to keep system responsive
if device.type == 'mps':
torch.mps.empty_cache()
elif device.type == 'cuda':
torch.cuda.empty_cache()
time.sleep(1)

except RuntimeError as e:
if "out of memory" in str(e).lower():
print(f"OOM Processing {img_name} at resolution {original_img.size}. Skipping...")
torch.cuda.empty_cache() # If cuda
else:
print(f"Error processing {img_name}: {e}")
except Exception as e:
print(f"Error processing {img_name}: {e}")

print("Inference completed.")

if __name__ == "__main__":
run_inference()