diff --git a/README.md b/README.md
index f2f8578..be487a6 100644
--- a/README.md
+++ b/README.md
@@ -1,20 +1,26 @@
 # Fast SR-UNet
-This repository contains the implementation of [1]. It is an architecture comprised with a GAN-based training procedure for obtaining
-a fast neural network which enable better bitrate performances respect to the H.265 codec for the same quality, or better quality at the same
-bitrate.
+
+This repository contains the implementation of [1]: an architecture and a GAN-based training procedure for obtaining
+a fast neural network which enables a lower bitrate than the H.265 codec at the same quality, or better quality at the
+same bitrate.
 
 #### Requirements:
+
 - Installing CUDA with torchvision and torch: `$ conda install pytorch torchvision cudatoolkit=10.2 -c pytorch -c`
 - [LPIPS](https://github.com/richzhang/PerceptualSimilarity): `$ pip install lpips`
-- FFMpeg compiled with H.265 codec and also VMAF metric. My version is included in the `helper/` directory but it won't likely work.
-  For references check [the official compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu) and
+- FFmpeg compiled with the H.265 codec and the VMAF metric. My build is included in the `helper/` directory, but it
+  will likely not work on other systems. For reference,
+  check [the official compilation guide](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu) and
   the [VMAF GitHub Repository](https://github.com/Netflix/vmaf).
-
 #### The dataset:
-First, the dataset we use for training is the [BVI-DVC](https://arxiv.org/pdf/2003.13552). For preparing the dataset there are two helper script,
-`compress_train_videos.sh` for spatially compressing and encoding each video, then with `extract_train_frames.sh` the dataset can be prepared.
+
+The dataset we use for training is [BVI-DVC](https://arxiv.org/pdf/2003.13552). Two helper scripts prepare it:
+`compress_train_videos.sh` spatially compresses and encodes each video, and `extract_train_frames.sh` then extracts
+the training frames.
 The train dataset should follow this naming scheme (assuming the videos are encoded with CRF 23):
+
 ```
 [DATASET_DIR]/
     frames_HQ/
@@ -38,23 +44,32 @@ The train dataset should follow this naming scheme (assuming the videos are enco
         [clipNameN]/
     ...
 ```
+
 #### Training the model:
-To train the SR-ResNet described in the paper for 2x Super Resolution (as used in the model for the 540p -> 1080p upscaling), you can use this command.
+To train the SR-UNet described in the paper for 2x Super Resolution (as used for the 540p -> 1080p upscaling), you can
+use this command:
+
 ```
 $ python train.py --arch srunet --device 0 --upscale 2 --export [EXPORT_DIR] \
  --epochs 80 --dataset [DATASET_DIR] --crf 23
 ```
-Or, since most of these arguments are defaults, simply
+
+Or, since most of these arguments are defaults, simply
+
 ```
 $ python train.py --dataset [DATASET_DIR]
 ```
+
 For more information about the other parameters, inspect `utils.py` or try
+
 ```
 $ python train.py -h
 ```
-However, in the bandwidth experiments we employed a lighter model, trained on a range of CRFs for performing an easier 1.5x upscale (720p -> 1080p). It is obtainable with the following command:
+However, in the bandwidth experiments we employed a lighter model, trained on a range of CRFs and performing an easier
+1.5x upscale (720p -> 1080p). It can be obtained with the following command:
+
 ```
 $ python train.py --arch srunet --layer_mult 0.7 --n_filters 48 --downsample 0.75 --device 0 \
  --upscale 2 --export [EXPORT_DIR] --epochs 80 --dataset [DATASET_DIR] --crf [CRF]
 ```
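+
+The net 1.5x factor comes from combining `--upscale 2` with `--downsample 0.75` (2 x 0.75 = 1.5), so a 720p input is
+brought back to 1080p. As a rough sketch of how these flags map onto the `SRUnet` constructor (mirroring the way
+`render.py` builds the model; defaults are the ones defined in `utils.py`):
+
+```
+from pytorch_unet import SRUnet
+
+# 2x model with the default hyper-parameters (540p -> 1080p)
+model_2x = SRUnet(3, residual=True, scale_factor=2, n_filters=64, downsample=1.0, layer_multiplier=1.0)
+
+# lighter 1.5x model used in the bandwidth experiments (720p -> 1080p)
+model_15x = SRUnet(3, residual=True, scale_factor=2, n_filters=48, downsample=0.75, layer_multiplier=0.7)
+```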
@@ -63,8 +78,9 @@ $ python train.py --arch srunet --layer_multiplier 0.7 --n_filters 48 --downsamp
 
 #### Testing the models:
 You may want to test your models. In our paper we tested on the 1080p clips available from the [Derf's Collection](https://media.xiph.org/video/derf/)
- in Y4M format. For preparing the test set (of encoded clips) you can use the `compress_test_videos.sh` helper script.
+ in Y4M format. For preparing the test set (of encoded clips) you can use the `compress_test_videos.sh` helper script.
 This time, the test set will be structured as follows, and there is no need of extracting each frame:
+
 ```
 [TEST_DIR]/
     encoded540CRF23/
@@ -79,21 +95,46 @@ This time, the test set will be structured as follows, and there is no need of e
     ...
     touchdown_pass_1080p.y4m
 ```
-Finally, for testing a model (e.g. the one performing 1.5x upscale) which name is _[MODEL_NAME]_ you can use the command:
+
+Finally, to test a model (e.g. the one performing the 1.5x upscale) whose name is _[MODEL_NAME]_, you can use the
+command:
+
 ```
-$ python evaluate_model.py --model [MODEL_NAME] --arch srunet --layer_multiplier 0.7 --n_filters 48 \
+$ python evaluate_model.py --model [MODEL_NAME] --arch srunet --layer_mult 0.7 --n_filters 48 \
 --downsample 0.75 --device 0 --upscale 2 --crf 23 --testdir [TEST_DIR] --testinputres 720 --testoutputres 1080
 ```
+
 The experimental results are printed on screen and also saved to a .csv file.
 
+#### Inference with the model using `render.py`
+
+You can use the `render.py` script to upscale your clips with the model in real time. Examples:
+
+- For 2x upscaling
+  ```
+  $ python render.py --clipname path/to/clip.mp4 --model models/srunet_2x_crf23.pth
+  ```
+- For 1.5x upscaling
+  ```
+  $ python render.py --clipname path/to/clip.mp4 --model models/srunet_1.5x_crf23.pth --layer_mult 0.7 --n_filters 48 --downsample 0.75
+  ```
+
+By default the output is split into two halves: the bicubic-upscaled input is shown on the left and the network output
+on the right. You can show only the upscaled version by adding the `--show-only-upscaled` flag.
+
+_About performance: my GTX 1080 Ti is enough for rendering at 30 fps when upscaling 540p -> 1080p and 25 fps when
+upscaling 720p -> 1080p. Note that in the paper we also employed [Nvidia Apex](https://nvidia.github.io/apex/) to speed
+up inference times._
+
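+For a single frame, the same steps `render.py` performs in real time can be sketched as follows. This is illustrative
+only: the frame and checkpoint paths are placeholders, and it reuses the normalization helpers from `data_loader.py`,
+which map images to the [-1, 1] range used by the network:
+
+```
+import cv2
+import torch
+import data_loader as dl
+from pytorch_unet import SRUnet
+
+# lighter 1.5x model (use the same flags you trained with)
+model = SRUnet(3, residual=True, scale_factor=2, n_filters=48, downsample=0.75, layer_multiplier=0.7)
+model.load_state_dict(torch.load('models/srunet_1.5x_crf23.pth'))
+model = model.cuda()
+model.reparametrize()  # as in render.py, reparametrize before inference
+model = model.eval()
+
+# 1280x720 is already a multiple of 16, so no padding is needed here (render.py pads other sizes)
+frame = cv2.cvtColor(cv2.imread('frame_720p.png'), cv2.COLOR_BGR2RGB)
+x = dl.normalize_img(torch.Tensor(frame / 255).cuda().permute(2, 0, 1).unsqueeze(0))
+with torch.no_grad():
+    y = torch.clamp(model(x), min=-1, max=1)
+out = (dl.de_normalize(y.squeeze(0)).permute(1, 2, 0) * 255).byte().cpu().numpy()
+cv2.imwrite('frame_1080p.png', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
+```
+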
 ## Examples
 ![TajMahal](pics/tajmahal.png)
 Check [this link](https://bit.ly/3aGPzMW) for the complete clip.
 ![Venice](pics/venice.png)
-
 #### References:
+
 This code is the implementation of my Master's Degree thesis, from which my supervisors and I wrote the paper:
-- [1] Fast video visual quality and resolution improvement using SR-UNet.
-  Authors Federico Vaccaro, Marco Bertini, Tiberio Uricchio, and Alberto Del Bimbo (accepted at ACM MM '21)
\ No newline at end of file
+
+- [1] Fast video visual quality and resolution improvement using SR-UNet. Federico Vaccaro, Marco Bertini,
+  Tiberio Uricchio, and Alberto Del Bimbo (accepted at ACM MM '21).
\ No newline at end of file
diff --git a/pytorch_unet.py b/pytorch_unet.py
index 0f44c6f..a0f1d07 100644
--- a/pytorch_unet.py
+++ b/pytorch_unet.py
@@ -143,8 +143,8 @@ def __init__(self, in_dim=3, n_class=3, n_filters=32, downsample=None, residual=
         else:
             self.conv_last = nn.Conv2d(n_filters, 3, kernel_size=1)
 
-        if downsample is not None:
-            self.downsample = 'interp'  # nn.UpsamplingBilinear2d(scale_factor=downsample)
+        if downsample is not None and downsample != 1.0:
+            self.downsample = nn.Upsample(scale_factor=downsample, mode='bicubic', align_corners=True)
         else:
             self.downsample = nn.Identity()
         self.layers = [self.dconv_down1, self.dconv_down2, self.dconv_down3, self.dconv_down4, self.dconv_up3,
@@ -190,8 +190,7 @@ def forward(self, input):
                               mode='bicubic')
 
         x = torch.clamp(x, min=-1, max=1)
-        return x
-        # return torch.clamp(F.interpolate(x, mode='bicubic', scale_factor=0.75, align_corners=True), min=-1, max=1)# self.downsample(x)
+        return torch.clamp(self.downsample(x), min=-1, max=1)
 
     def reparametrize(self):
         for layer in self.layers:
@@ -244,7 +243,7 @@ def __init__(self, in_dim=3, n_class=3, downsample=None, residual=False, batchno
                                        n_blocks=3 * layer_multiplier)
 
         self.maxpool = nn.MaxPool2d(2)
-        if downsample is not None and downsample != scale_factor:
+        if downsample is not None and downsample != 1.0:
             self.downsample = nn.Upsample(scale_factor=downsample, mode='bicubic', align_corners=True)
         else:
             self.downsample = nn.Identity()
diff --git a/render.py b/render.py
new file mode 100644
index 0000000..d4c5e20
--- /dev/null
+++ b/render.py
@@ -0,0 +1,182 @@
+import time
+from threading import Thread
+import data_loader as dl
+import torch
+
+torch.backends.cudnn.benchmark = True
+import numpy as np
+from models import *
+import utils
+from tqdm import tqdm
+import cv2
+from pytorch_unet import UNet, SRUnet, SimpleResNet
+from queue import Queue
+
+
+# from apex import amp
+
+def save_with_cv(pic, imname):
+    pic = dl.de_normalize(pic.squeeze(0))
+    npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) * 255
+    npimg = cv2.cvtColor(npimg, cv2.COLOR_BGR2RGB)
+
+    cv2.imwrite(imname, npimg)
+
+
+def write_to_video(pic, writer):
+    pic = dl.de_normalize(pic.squeeze(0))
+    npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) * 255
+    npimg = npimg.astype('uint8')
+    npimg = cv2.cvtColor(npimg, cv2.COLOR_BGR2RGB)
+
+    font = cv2.FONT_HERSHEY_SIMPLEX
+    cv2.putText(npimg, '540p CRF 23 + bicubic', (50, 1030), font, 1, (10, 10, 10), 2, cv2.LINE_AA)
+    cv2.putText(npimg, 'SR-Unet (ours)', (1920 // 2 + 50, 1020), font, 1, (10, 10, 10), 2, cv2.LINE_AA)
+
+    writer.write(npimg)
+
+
+def get_padded_dim(H_x, W_x, border=0, mod=16):
+    # pad height/width up to the next multiple of (mod + border) so the input fits the network
+    modH, modW = H_x % (mod + border), W_x % (mod + border)
+    padW = ((mod + border) - modW) % (mod + border)
+    padH = ((mod + border) - modH) % (mod + border)
+
+    new_H = H_x + padH
+    new_W = W_x + padW
+
+    return new_H, new_W, padH, padW
+
+
+def pad_input(x, padH, padW):
+    x = F.pad(x, [0, padW, 0, padH])
+    return x
+
+
+def cv2toTorch(im):
+    im = im / 255
+    im = torch.Tensor(im).cuda()
+    im = im.permute(2, 0, 1).unsqueeze(0)
+    im = dl.normalize_img(im)
+    return im
+
+
+def torchToCv2(pic, rescale_factor=1.0):
+    if rescale_factor != 1.0:
+        pic = F.interpolate(pic, scale_factor=rescale_factor, align_corners=True, mode='bicubic')
+    pic = dl.de_normalize(pic.squeeze(0))
+    pic = pic.permute(1, 2, 0) * 255
+    npimg = pic.byte().cpu().numpy()
+    npimg = cv2.cvtColor(npimg, cv2.COLOR_BGR2RGB)
+
+    return npimg
+
+
+def blend_images(i1, i2):
+    w = i1.shape[-1]
+    w_4 = w // 4
+    i1 = i1[:, :, :, w_4:w_4 * 3]
+    i2 = i2[:, :, :, w_4:w_4 * 3]
+    out = torch.cat([i1, i2], dim=3)
+    return out
+
+
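+# Main loop: frames are decoded and normalized in a reader thread, padded to a multiple of 16
+# so they fit the network, super-resolved by the model below, and displayed by a consumer
+# thread. Unless --show-only-upscaled is given, blend_images() puts the bicubic upscale on the
+# left half of the output and the network output on the right half.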
+if __name__ == '__main__':
+    args = utils.ARArgs()
+    enable_write_to_video = False
+    arch_name = args.ARCHITECTURE
+    dataset_upscale_factor = args.UPSCALE_FACTOR
+
+    if arch_name == 'srunet':
+        model = SRUnet(3, residual=True, scale_factor=dataset_upscale_factor, n_filters=args.N_FILTERS,
+                       downsample=args.DOWNSAMPLE, layer_multiplier=args.LAYER_MULTIPLIER)
+    elif arch_name == 'unet':
+        model = UNet(3, residual=True, scale_factor=dataset_upscale_factor, n_filters=args.N_FILTERS)
+    elif arch_name == 'srgan':
+        model = SRResNet()
+    elif arch_name == 'espcn':
+        model = SimpleResNet(n_filters=64, n_blocks=6)
+    else:
+        raise Exception("Unknown architecture. Select one between:", args.archs)
+
+    model_path = args.MODEL_NAME
+    model.load_state_dict(torch.load(model_path))
+
+    model = model.cuda()
+    model.reparametrize()
+
+    path = args.CLIPNAME
+    cap = cv2.VideoCapture(path)
+    reader = torchvision.io.VideoReader(path, "video")
+
+    if enable_write_to_video:
+        fourcc = cv2.VideoWriter_fourcc(*'MP4V')
+        hr_video_writer = cv2.VideoWriter('rendered.mp4', fourcc, 30, (1920, 1080))
+
+    metadata = reader.get_metadata()
+
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height_fix, width_fix, padH, padW = get_padded_dim(height, width)
+
+    frame_queue = Queue(1)
+    out_queue = Queue(1)
+
+    reader.seek(0)
+
+
+    def read_pic(cap, q):
+        count = 0
+        start = time.time()
+        while True:
+            cv2_im = next(cap)['data']  # .cuda().float()
+            cv2_im = cv2_im.cuda().float()
+
+            x = dl.normalize_img(cv2_im / 255.).unsqueeze(0)
+
+            x_bicubic = torch.clip(F.interpolate(x, scale_factor=args.UPSCALE_FACTOR * args.DOWNSAMPLE, mode='bicubic'),
+                                   min=-1, max=1)
+
+            x = F.pad(x, [0, padW, 0, padH])
+            count += 1
+            q.put((x, x_bicubic))
+
+
+    def show_pic(cap, q):
+        while True:
+            out = q.get()
+            scale = 1
+            cv2_out = torchToCv2(out, rescale_factor=scale)
+            cv2.imshow('rendering', cv2_out)
+            cv2.waitKey(1)
+
+
+    t1 = Thread(target=read_pic, args=(reader, frame_queue)).start()
+    t2 = Thread(target=show_pic, args=(cap, out_queue)).start()
+    target_fps = cap.get(cv2.CAP_PROP_FPS)
+    target_frametime = 1000 / target_fps
+
+    model = model.eval()
+    with torch.no_grad():
+        tqdm_ = tqdm(range(frame_count))
+        for i in tqdm_:
+            t0 = time.time()
+
+            x, x_bicubic = frame_queue.get()
+            out = model(x)[:, :, :int(height) * 2, :int(width) * 2]
+
+            out_true = i // (target_fps * 3) % 2 == 0
+
+            if not args.SHOW_ONLY_HQ:
+                out = blend_images(x_bicubic, out)
+            out_queue.put(out)
+            frametime = time.time() - t0
+            if frametime < target_frametime * 1e-3:
+                time.sleep(target_frametime * 1e-3 - frametime)
+
+            if enable_write_to_video:
+                write_to_video(out, hr_video_writer)
+                if i == 30 * 10:
+                    hr_video_writer.release()
+                    print("Releasing video")
+
+            tqdm_.set_description("frame time: {}; fps: {}; {}".format(frametime * 1e3, 1000 / frametime, out_true))
diff --git a/utils.py b/utils.py
index 30d7f6c..e638ae2 100644
--- a/utils.py
+++ b/utils.py
@@ -23,7 +23,7 @@ def __init__(self, args=None):
         ap.add_argument("-e", "--epochs", type=int, default=80,
                         help="Number of epochs you want to train the model.")
         ap.add_argument("--clipname", type=str, default="",
-                        help="Optimize the network specifically for a clip")
+                        help="[RENDER.PY ONLY] path to the clip you want to upscale")
upscale") ap.add_argument("--arch", type=str, default="srunet", choices=archs, help="Which network architecture to train.") ap.add_argument("--w0", type=float, default=1.0, @@ -34,14 +34,16 @@ def __init__(self, args=None): help="Adversarial Component Weight") ap.add_argument("--upscale", type=int, default=2, help="Default upscale factor, obbtained as resolution ratio between LQ and HQ samples") - ap.add_argument("--layermult", type=float, default=1.0, help="Layer multiplier - SR UNet only") - ap.add_argument("--nfilters", type=int, default=64, help="Net Number of filters param - SR UNet and UNet only") - ap.add_argument("--downsample", type=float, default=None, help="Downsample factor, SR Unet and UNet only") + ap.add_argument("--layer_mult", type=float, default=1.0, help="Layer multiplier - SR UNet only") + ap.add_argument("--n_filters", type=int, default=64, help="Net Number of filters param - SR UNet and UNet only") + ap.add_argument("--downsample", type=float, default=1.0, help="Downsample factor, SR Unet and UNet only") ap.add_argument("--testdir", type=str, default="test", help="[TEST ONLY] Where the test clips are contained.") ap.add_argument("--testinputres", type=int, default=540, help="[TEST ONLY] Input testing resolution") ap.add_argument("--testoutputres", type=int, default=1080, help="[TEST ONLY] Output testing resolution") ap.add_argument("--crf", type=int, default=23, help="Reference compression CRF") + ap.add_argument('--show-only-upscaled', dest='show-only-upscaled', action='store_true', + help="[RENDER.PY ONLY] If you want to show only the neural net upscaled version of the video") if args is None: args = vars(ap.parse_args()) @@ -62,13 +64,14 @@ def __init__(self, args=None): self.W1 = args['w1'] self.L0 = args['l0'] self.UPSCALE_FACTOR = args['upscale'] - self.LAYER_MULTIPLIER = args['layermult'] - self.N_FILTERS = args['nfilters'] + self.LAYER_MULTIPLIER = args['layer_mult'] + self.N_FILTERS = args['n_filters'] self.DOWNSAMPLE = args['downsample'] self.TEST_INPUT_RES = args['testinputres'] self.TEST_OUTPUT_RES = args['testoutputres'] self.CRF = args['crf'] self.TEST_DIR = args['testdir'] + self.SHOW_ONLY_HQ = args['show-only-upscaled'] self.archs = archs