Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions upscale.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import sys
import os.path
import glob
import cv2
import numpy
import torch
import architecture
import RRDBNet_arch as arch
import math

model_path = sys.argv[1] # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
model_path = 'models/RRDB_ESRGAN_x4.pth' # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu

inputDir = 'LR' # Input directory
Expand All @@ -17,7 +16,7 @@
upscalingAmount = 4 # Upscaling amount of the model

def upscaleImage( model, device, img ):
#Transpose
#Transpose
img = numpy.transpose( img[:, :, [2, 1, 0]], (2, 0, 1) )

imgTorch = torch.from_numpy( img ).float()
Expand All @@ -28,7 +27,7 @@ def upscaleImage( model, device, img ):
# Re-Transpose
return numpy.transpose(imgNumpy[[2, 1, 0], :, :], (1, 2, 0))

model = architecture.RRDB_Net( 3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
model = arch.RRDBNet( 3, 3, 64, 23, gc=32)
model.load_state_dict( torch.load( model_path ), strict=True )
model.eval()
for k, v in model.named_parameters():
Expand Down Expand Up @@ -86,4 +85,4 @@ def upscaleImage( model, device, img ):
imgOutput = ( imgOutput * 255.0 ).round()

# Save result
cv2.imwrite( '{:s}/{:s}.png'.format( outputDir, filename ), imgOutput )
cv2.imwrite( '{:s}/{:s}_upscaled.png'.format( outputDir, filename ), imgOutput )