diff --git a/Makefile b/Makefile index 5d3fbbe..dcbd718 100644 --- a/Makefile +++ b/Makefile @@ -16,4 +16,9 @@ build: .PHONY: vs vs: rm -f encoded.mkv - vspipe -c y4m example/vapoursynth.py - | ffmpeg -i - -vcodec libx265 -crf 16 encoded.mkv + vspipe -c y4m example/sr_vs.py - | ffmpeg -i - -vcodec libx264 encoded.mp4 + +.PHONY: dev +dev: + docker compose -f cccv-docker-compose.yml down + docker compose -f cccv-docker-compose.yml up -d diff --git a/README.md b/README.md index ab2fba2..f63bc63 100644 --- a/README.md +++ b/README.md @@ -30,9 +30,7 @@ import numpy as np from cccv import AutoModel, ConfigType, SRBaseModel -model: SRBaseModel = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, -) +model: SRBaseModel = AutoModel.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x) img = cv2.imdecode(np.fromfile("test.jpg", dtype=np.uint8), cv2.IMREAD_COLOR) img = model.inference_image(img) @@ -47,10 +45,10 @@ a simple example to use the VapourSynth to process a video import vapoursynth as vs from vapoursynth import core -from cccv import AutoModel, VSRBaseModel, ConfigType +from cccv import AutoModel, ConfigType, SRBaseModel -model: VSRBaseModel = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, +model: SRBaseModel = AutoModel.from_pretrained( + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, tile=None ) @@ -61,7 +59,13 @@ clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) clip.set_output() ``` -See more examples in the [example](./example) directory, cccv can register custom configurations and models to extend the functionality +See more examples in the [example](./example) directory, including: + +- SISR (Single Image Super-Resolution) +- VSR (Video Super-Resolution) +- VFI (Video Frame Interpolation) + +cccv can register custom configurations and models to extend the functionality ### Current Support diff --git a/assets/vfi/test_i0.jpg b/assets/vfi/test_i0.jpg new file mode 100644 index 0000000..5a64674 Binary files /dev/null and b/assets/vfi/test_i0.jpg differ diff --git a/assets/vfi/test_i1.jpg b/assets/vfi/test_i1.jpg new file mode 100644 index 0000000..ab122dd Binary files /dev/null and b/assets/vfi/test_i1.jpg differ diff --git a/assets/vfi/test_i2.jpg b/assets/vfi/test_i2.jpg new file mode 100644 index 0000000..eba2628 Binary files /dev/null and b/assets/vfi/test_i2.jpg differ diff --git a/assets/vfi/test_out_drba_0.jpg b/assets/vfi/test_out_drba_0.jpg new file mode 100644 index 0000000..71d7bef Binary files /dev/null and b/assets/vfi/test_out_drba_0.jpg differ diff --git a/assets/vfi/test_out_drba_1.jpg b/assets/vfi/test_out_drba_1.jpg new file mode 100644 index 0000000..9f8376a Binary files /dev/null and b/assets/vfi/test_out_drba_1.jpg differ diff --git a/assets/vfi/test_out_drba_2.jpg b/assets/vfi/test_out_drba_2.jpg new file mode 100644 index 0000000..c789e02 Binary files /dev/null and b/assets/vfi/test_out_drba_2.jpg differ diff --git a/assets/vfi/test_out_drba_3.jpg b/assets/vfi/test_out_drba_3.jpg new file mode 100644 index 0000000..8af8147 Binary files /dev/null and b/assets/vfi/test_out_drba_3.jpg differ diff --git a/assets/vfi/test_out_drba_4.jpg b/assets/vfi/test_out_drba_4.jpg new file mode 100644 index 0000000..469a8aa Binary files /dev/null and b/assets/vfi/test_out_drba_4.jpg differ diff --git a/assets/vfi/test_out_rife.jpg b/assets/vfi/test_out_rife.jpg new file mode 100644 index 0000000..9f7c6cb Binary files /dev/null and b/assets/vfi/test_out_rife.jpg differ diff --git a/cccv-docker-compose.yml b/cccv-docker-compose.yml new file mode 100644 index 0000000..1baa77e --- /dev/null +++ b/cccv-docker-compose.yml @@ -0,0 +1,33 @@ +version: "3.8" + +name: cccv + +networks: + backend: + driver: bridge + +services: + playground-cuda: + image: lychee0/vs-playground:cuda-dev + restart: always + ports: + - "1145:8888" + - "1022:22" + volumes: + - ./:/cccv + environment: + - JUPYTER_TOKEN=114514 + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: + - "0" + capabilities: + - gpu + - utility + - compute + - video + networks: + - backend diff --git a/cccv/__init__.py b/cccv/__init__.py index 6d0250a..7a602b2 100644 --- a/cccv/__init__.py +++ b/cccv/__init__.py @@ -26,6 +26,6 @@ from cccv.arch import ARCH_REGISTRY from cccv.auto import AutoConfig, AutoModel -from cccv.config import CONFIG_REGISTRY, BaseConfig, SRBaseConfig, VSRBaseConfig -from cccv.model import MODEL_REGISTRY, AuxiliaryBaseModel, CCBaseModel, SRBaseModel, VSRBaseModel +from cccv.config import CONFIG_REGISTRY, BaseConfig, SRBaseConfig, VFIBaseConfig, VSRBaseConfig +from cccv.model import MODEL_REGISTRY, AuxiliaryBaseModel, CCBaseModel, SRBaseModel, VFIBaseModel, VSRBaseModel from cccv.type import ArchType, BaseModelInterface, ConfigType, ModelType diff --git a/cccv/arch/__init__.py b/cccv/arch/__init__.py index e6c11bb..d29c1ba 100644 --- a/cccv/arch/__init__.py +++ b/cccv/arch/__init__.py @@ -23,3 +23,7 @@ from cccv.arch.vsr.edvr_arch import EDVR from cccv.arch.vsr.msrswvsr_arch import MSRSWVSR + +# Video Frame Interpolation +from cccv.arch.vfi.ifnet_arch import IFNet +from cccv.arch.vfi.drba_arch import DRBA diff --git a/cccv/arch/vfi/__init__.py b/cccv/arch/vfi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cccv/arch/vfi/drba_arch.py b/cccv/arch/vfi/drba_arch.py new file mode 100644 index 0000000..a4c27ee --- /dev/null +++ b/cccv/arch/vfi/drba_arch.py @@ -0,0 +1,312 @@ +# type: ignore +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cccv.arch import ARCH_REGISTRY +from cccv.arch.vfi.vfi_utils.warplayer import warp +from cccv.type import ArchType +from cccv.util.misc import distance_calculator + + +@ARCH_REGISTRY.register(name=ArchType.DRBA) +class DRBA(nn.Module): + def __init__(self): + super(DRBA, self).__init__() + self.block0 = IFBlock(7 + 32, c=192) + self.block1 = IFBlock(8 + 4 + 8 + 32, c=128) + self.block2 = IFBlock(8 + 4 + 8 + 32, c=96) + self.block3 = IFBlock(8 + 4 + 8 + 32, c=64) + self.block4 = IFBlock(8 + 4 + 8 + 32, c=32) + self.encode = Head() + + support_cupy = True + try: + import cupy + + if cupy.cuda.get_cuda_path() is None: + support_cupy = False + except Exception: + support_cupy = False + + if support_cupy: + from cccv.arch.vfi.vfi_utils.softsplat import softsplat as fwarp + else: + from cccv.arch.vfi.vfi_utils.softsplat_torch import softsplat as fwarp + + self.fwarp = fwarp + + def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=False, f0=None, f1=None): + if scale_list is None: + scale_list = [16, 8, 4, 2, 1] + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + f0 = self.encode(img0[:, :3]) if f0 is None else f0 + f1 = self.encode(img1[:, :3]) if f1 is None else f1 + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + block = [self.block0, self.block1, self.block2, self.block3, self.block4] + for i in range(5): + if flow is None: + flow, mask, feat = block[i]( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i] + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0, feat = block[i]( + torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask, feat), 1), + flow, + scale=scale_list[i], + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) + if not fastmode: + print("contextnet is removed") + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[4] = torch.clamp(merged[4] + res, 0, 1) + """ + return merged[4], flow_list + + def calc_flow(self, a, b, scale, f0=None, f1=None): + scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale] + # calc flow at the lowest resolution (significantly faster with almost no quality loss). + timestep = (a[:, :1].clone() * 0 + 1) * 0.5 + f0 = self.encode(a[:, :3]) if f0 is None else f0 + f1 = self.encode(b[:, :3]) if f1 is None else f1 + flow, _, _ = self.block0(torch.cat((a[:, :3], b[:, :3], f0, f1, timestep), 1), None, scale=scale_list[0]) + + # get flow flow0.5 -> 0/1 + flow50, flow51 = flow[:, :2], flow[:, 2:] + + warp_method = "avg" + + # qvi + # flow05, norm2 = fwarp(flow50, flow50) + # flow05[norm2]... + # flow05 = -flow05 + + flow05 = -1 * self.fwarp(flow50, flow50, None, warp_method) + flow15 = -1 * self.fwarp(flow51, flow51, None, warp_method) + + ones_mask = flow05.clone() * 0 + 1 + mask05 = self.fwarp(ones_mask, flow50, None, warp_method) + mask15 = self.fwarp(ones_mask, flow51, None, warp_method) + + gap05 = mask05 < 0.999 + gap15 = mask15 < 0.999 + + flow05[gap05] = (ones_mask * max(flow05.shape[2], flow05.shape[3]))[gap05] + flow15[gap15] = (ones_mask * max(flow15.shape[2], flow15.shape[3]))[gap15] + + flow01 = flow05 * 2 + flow10 = flow15 * 2 + + return flow01, flow10, f0, f1 + + def forward(self, x, minus_t, zero_t, plus_t, _left_scene, _right_scene, _scale, _reuse=None): + _I0, _I1, _I2 = x[:, 0], x[:, 1], x[:, 2] + flow10, flow01, f1, f0 = self.calc_flow(_I1, _I0, _scale) if not _reuse else _reuse + if _reuse is None: + flow12, flow21, f1, f2 = self.calc_flow(_I1, _I2, _scale) + else: + flow12, flow21, f1, f2 = self.calc_flow(_I1, _I2, _scale, f0=_reuse[2]) + + # Compute the distance using the optical flow and distance calculator + d10 = distance_calculator(flow10) + 1e-4 + d12 = distance_calculator(flow12) + 1e-4 + + # Calculate the distance ratio map + drm10 = d10 / (d10 + d12) + drm12 = d12 / (d10 + d12) + + ones_mask = torch.ones_like(drm10, device=drm10.device) + + def calc_drm_rife(_t): + # The distance ratio map (drm) is initially aligned with I1. + # To align it with I0 and I2, we need to warp the drm maps. + # Note: 1. To reverse the direction of the drm map, use 1 - drm and then warp it. + # 2. For RIFE, drm should be aligned with the time corresponding to the intermediate frame. + _drm01r = self.fwarp(1 - drm10, flow10 * ((1 - drm10) * 2) * _t, None, strMode="avg") + _drm21r = self.fwarp(1 - drm12, flow12 * ((1 - drm12) * 2) * _t, None, strMode="avg") + + self.warped_ones_mask01r = self.fwarp(ones_mask, flow10 * ((1 - drm10) * 2) * _t, None, strMode="avg") + self.warped_ones_mask21r = self.fwarp(ones_mask, flow12 * ((1 - drm12) * 2) * _t, None, strMode="avg") + + holes01r = self.warped_ones_mask01r < 0.999 + holes21r = self.warped_ones_mask21r < 0.999 + + _drm01r[holes01r] = _drm01r[holes01r] + _drm21r[holes21r] = _drm21r[holes21r] + + return _drm01r, _drm21r + + output1, output2 = [], [] + + if _left_scene: + for i in range(len(minus_t)): + minus_t[i] = -1 + + if _right_scene: + for _ in plus_t: + zero_t = np.append(zero_t, 0) + plus_t = [] + + disable_drm = False + if (_left_scene and not _right_scene) or (not _left_scene and _right_scene): + drm01r, drm21r = (ones_mask.clone() * 0.5 for _ in range(2)) + drm01r = torch.nn.functional.interpolate(drm01r, size=_I0.shape[2:], mode="bilinear", align_corners=False) + drm21r = torch.nn.functional.interpolate(drm21r, size=_I0.shape[2:], mode="bilinear", align_corners=False) + disable_drm = True + + for t in minus_t: + t = -t + if t == 1: + output1.append(_I0) + continue + if not disable_drm: + drm01r, _ = calc_drm_rife(t) + output1.append( + self.inference( + torch.cat((_I1, _I0), 1), + timestep=t * (2 * drm01r), + scale_list=[16 / _scale, 8 / _scale, 4 / _scale, 2 / _scale, 1 / _scale], + )[0] + ) + for _ in zero_t: + output1.append(_I1) + for t in plus_t: + if t == 1: + output2.append(_I2) + continue + if not disable_drm: + _, drm21r = calc_drm_rife(t) + output2.append( + self.inference( + torch.cat((_I1, _I2), 1), + timestep=t * (2 * drm21r), + scale_list=[16 / _scale, 8 / _scale, 4 / _scale, 2 / _scale, 1 / _scale], + )[0] + ) + + _output = output1 + output2 + + # next flow10, flow01 = reverse(current flow12, flow21) + return _output, (flow21, flow12, f2, f1) + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True + ), + nn.LeakyReLU(0.2, True), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True), + ) + + +class Head(nn.Module): + def __init__(self): + super(Head, self).__init__() + self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1) + self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(16, 16, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + feat = tmp[:, 5:] + return flow, mask, feat diff --git a/cccv/arch/vfi/ifnet_arch.py b/cccv/arch/vfi/ifnet_arch.py new file mode 100644 index 0000000..3d7ab51 --- /dev/null +++ b/cccv/arch/vfi/ifnet_arch.py @@ -0,0 +1,168 @@ +# type: ignore +import torch +import torch.nn as nn +import torch.nn.functional as F + +from cccv.arch import ARCH_REGISTRY +from cccv.arch.vfi.vfi_utils.warplayer import warp +from cccv.type import ArchType + + +@ARCH_REGISTRY.register(name=ArchType.IFNET) +class IFNet(nn.Module): + def __init__(self): + super(IFNet, self).__init__() + self.block0 = IFBlock(7 + 32, c=192) + self.block1 = IFBlock(8 + 4 + 8 + 32, c=128) + self.block2 = IFBlock(8 + 4 + 8 + 32, c=96) + self.block3 = IFBlock(8 + 4 + 8 + 32, c=64) + self.block4 = IFBlock(8 + 4 + 8 + 32, c=32) + self.encode = Head() + + def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=False): + if scale_list is None: + scale_list = [16, 8, 4, 2, 1] + channel = x.shape[1] // 2 + img0 = x[:, :channel] + img1 = x[:, channel:] + if not torch.is_tensor(timestep): + timestep = (x[:, :1].clone() * 0 + 1) * timestep + f0 = self.encode(img0[:, :3]) + f1 = self.encode(img1[:, :3]) + flow_list = [] + merged = [] + mask_list = [] + warped_img0 = img0 + warped_img1 = img1 + flow = None + mask = None + block = [self.block0, self.block1, self.block2, self.block3, self.block4] + for i in range(5): + if flow is None: + flow, mask, feat = block[i]( + torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i] + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, m0, feat = block[i]( + torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask, feat), 1), + flow, + scale=scale_list[i], + ) + if ensemble: + print("warning: ensemble is not supported since RIFEv4.21") + else: + mask = m0 + flow = flow + fd + mask_list.append(mask) + flow_list.append(flow) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + merged.append((warped_img0, warped_img1)) + mask = torch.sigmoid(mask) + merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) + if not fastmode: + print("contextnet is removed") + """ + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[4] = torch.clamp(merged[4] + res, 0, 1) + """ + return merged[4] + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=True + ), + nn.LeakyReLU(0.2, True), + ) + + +def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return nn.Sequential( + nn.Conv2d( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=False, + ), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, True), + ) + + +class Head(nn.Module): + def __init__(self): + super(Head, self).__init__() + self.cnn0 = nn.Conv2d(3, 16, 3, 2, 1) + self.cnn1 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn2 = nn.Conv2d(16, 16, 3, 1, 1) + self.cnn3 = nn.ConvTranspose2d(16, 16, 4, 2, 1) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x, feat=False): + x0 = self.cnn0(x) + x = self.relu(x0) + x1 = self.cnn1(x) + x = self.relu(x1) + x2 = self.cnn2(x) + x = self.relu(x2) + x3 = self.cnn3(x) + if feat: + return [x0, x1, x2, x3] + return x3 + + +class ResConv(nn.Module): + def __init__(self, c, dilation=1): + super(ResConv, self).__init__() + self.conv = nn.Conv2d(c, c, 3, 1, dilation, dilation=dilation, groups=1) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1)), requires_grad=True) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(self.conv(x) * self.beta + x) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64): + super(IFBlock, self).__init__() + self.conv0 = nn.Sequential( + conv(in_planes, c // 2, 3, 2, 1), + conv(c // 2, c, 3, 2, 1), + ) + self.convblock = nn.Sequential( + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ResConv(c), + ) + self.lastconv = nn.Sequential(nn.ConvTranspose2d(c, 4 * 13, 4, 2, 1), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False) * 1.0 / scale + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear", align_corners=False) + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + feat = tmp[:, 5:] + return flow, mask, feat diff --git a/cccv/arch/vfi/vfi_utils/__init__.py b/cccv/arch/vfi/vfi_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cccv/arch/vfi/vfi_utils/softsplat.py b/cccv/arch/vfi/vfi_utils/softsplat.py new file mode 100644 index 0000000..5f82e98 --- /dev/null +++ b/cccv/arch/vfi/vfi_utils/softsplat.py @@ -0,0 +1,614 @@ +# type: ignore + +import collections +import os +import re +import typing + +import cupy +import torch + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): + if "device" not in objCudacache: + objCudacache["device"] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif isinstance(objValue, int): + strKey += str(objValue) + + elif isinstance(objValue, float): + strKey += str(objValue) + + elif isinstance(objValue, bool): + strKey += str(objValue) + + elif isinstance(objValue, str): + strKey += objValue + + elif isinstance(objValue, torch.Tensor): + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + + # end + # end + + strKey += objCudacache["device"] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif isinstance(objValue, int): + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif isinstance(objValue, float): + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif isinstance(objValue, bool): + strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue)) + + elif isinstance(objValue, str): + strKernel = strKernel.replace("{{" + strVariable + "}}", objValue) + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.uint8: + strKernel = strKernel.replace("{{type}}", "unsigned char") + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.float16: + strKernel = strKernel.replace("{{type}}", "half") + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.float32: + strKernel = strKernel.replace("{{type}}", "float") + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.float64: + strKernel = strKernel.replace("{{type}}", "double") + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.int32: + strKernel = strKernel.replace("{{type}}", "int") + + elif isinstance(objValue, torch.Tensor) and objValue.dtype == torch.int64: + strKernel = strKernel.replace("{{type}}", "long") + + elif isinstance(objValue, torch.Tensor): + print(strVariable, objValue.dtype) + + elif True: + print(strVariable, type(objValue)) + + # end + # end + + while True: + objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace( + objMatch.group(), + str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) is False else intSizes[intArg].item()), + ) + # end + + while True: + objMatch = re.search(r"(OFFSET_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) is False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "OFFSET_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", "(" + str.join("+", strIndex) + ")" + ) + # end + + while True: + objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == "(" else 0 + intParentheses -= 1 if strKernel[intStop] == ")" else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(",") + + assert intArgs == len(strArgs) - 1 + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append( + "((" + + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + + ")*" + + str( + intStrides[intArg] + if torch.is_tensor(intStrides[intArg]) is False + else intStrides[intArg].item() + ) + + ")" + ) + # end + + strKernel = strKernel.replace( + "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")", + strTensor + "[" + str.join("+", strIndex) + "]", + ) + # end + + objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel} + # end + + return strKey + + +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey: str): + if "CUDA_HOME" not in os.environ: + os.environ["CUDA_HOME"] = cupy.cuda.get_cuda_path() + # end + + return cupy.RawModule( + code=objCudacache[strKey]["strKernel"], + options=("-I " + os.environ.get("CUDA_HOME"), "-I " + os.environ.get("CUDA_HOME") + "/include"), + ).get_function(objCudacache[strKey]["strFunction"]) + + +# end + + +########################################################## + + +def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str): + output_dtype = tenIn.dtype + + tenIn, tenFlow = tenIn.float(), tenFlow.float() + + if tenMetric is not None: + tenMetric = tenMetric.float() + + assert strMode.split("-")[0] in ["sum", "avg", "linear", "soft"] + + if strMode == "sum": + assert tenMetric is None + # if strMode == 'avg': assert (tenMetric is None) + if strMode.split("-")[0] == "linear": + assert tenMetric is not None + if strMode.split("-")[0] == "soft": + assert tenMetric is not None + + if strMode == "avg": + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split("-")[0] == "linear": + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split("-")[0] == "soft": + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split("-")[0] in ["avg", "linear", "soft"]: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split("-")) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "addeps": + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split("-")[1] == "zeroeps": + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split("-")[1] == "clipeps": + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut.to(output_dtype) + + +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda is True: + cuda_launch( + cuda_kernel( + "softsplat_out", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + """, + {"tenIn": tenIn, "tenFlow": tenFlow, "tenOut": tenOut}, + ) + )( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), # noqa + block=tuple([512, 1, 1]), # noqa + args=(cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()), + stream=collections.namedtuple("Stream", "ptr")(torch.cuda.current_stream().cuda_stream), + ) + + elif tenIn.is_cuda is not True: + pass + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous() + assert tenOutgrad.is_cuda is True + + tenIngrad = ( + tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + if self.needs_input_grad[0] is True + else None + ) + tenFlowgrad = ( + tenFlow.new_zeros([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) + if self.needs_input_grad[1] is True + else None + ) + + if tenIngrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_ingrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), # noqa + block=tuple([512, 1, 1]), # noqa + args=( + cuda_int32(tenIngrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), + None, + ), + stream=collections.namedtuple("Stream", "ptr")(torch.cuda.current_stream().cuda_stream), + ) + # end + + if tenFlowgrad is not None: + cuda_launch( + cuda_kernel( + "softsplat_flowgrad", + """ + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + """, + { + "tenIn": tenIn, + "tenFlow": tenFlow, + "tenOutgrad": tenOutgrad, + "tenIngrad": tenIngrad, + "tenFlowgrad": tenFlowgrad, + }, + ) + )( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), # noqa + block=tuple([512, 1, 1]), # noqa + args=( + cuda_int32(tenFlowgrad.nelement()), + tenIn.data_ptr(), + tenFlow.data_ptr(), + tenOutgrad.data_ptr(), + None, + tenFlowgrad.data_ptr(), + ), + stream=collections.namedtuple("Stream", "ptr")(torch.cuda.current_stream().cuda_stream), + ) + # end + + return tenIngrad, tenFlowgrad + + # end + + +# end diff --git a/cccv/arch/vfi/vfi_utils/softsplat_torch.py b/cccv/arch/vfi/vfi_utils/softsplat_torch.py new file mode 100644 index 0000000..142cee5 --- /dev/null +++ b/cccv/arch/vfi/vfi_utils/softsplat_torch.py @@ -0,0 +1,183 @@ +# type: ignore +# torch fallback for softsplat inference +# https://github.com/98mxr/GMFSS_Fortuna/pull/11/files +# author: TNTwise +# https://github.com/TNTwise + +import torch + +########################################################## + +grid_cache = {} +batch_cache = {} +torch.set_float32_matmul_precision("medium") +torch.set_grad_enabled(False) + + +########################################################## + + +@torch.inference_mode() +def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str): + mode_parts = strMode.split("-") + mode_main = mode_parts[0] + mode_sub = mode_parts[1] if len(mode_parts) > 1 else None + + assert mode_main in ["sum", "avg", "linear", "soft"] + if mode_main in ["sum", "avg"]: + assert tenMetric is None + if mode_main in ["linear", "soft"]: + assert tenMetric is not None + + mode_to_operation = { + "avg": lambda: torch.cat( + [ + tenIn, + tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]]), + ], + 1, + ), + "linear": lambda: torch.cat([tenIn * tenMetric, tenMetric], 1), + "soft": lambda: torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1), + } + + if mode_main in mode_to_operation: + tenIn = mode_to_operation[mode_main]() + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if mode_main in ["avg", "linear", "soft"]: + tenNormalize = tenOut[:, -1:, :, :] + + normalize_modes = { + None: lambda x: x + 0.0000001, + "addeps": lambda x: x + 0.0000001, + "zeroeps": lambda x: torch.where(x == 0.0, torch.tensor(1.0, device=x.device), x), + "clipeps": lambda x: x.clip(0.0000001, None), + } + + if mode_sub in normalize_modes: + tenNormalize = normalize_modes[mode_sub](tenNormalize) + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + + return tenOut + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.inference_mode() + def forward(ctx, tenIn, tenFlow): + """ + Forward pass of the Softsplat function. + + Parameters: + tenIn (torch.Tensor): Input tensor of shape [N, C, H, W] + tenFlow (torch.Tensor): Flow tensor of shape [N, 2, H, W] + + Returns: + torch.Tensor: Output tensor of shape [N, C, H, W] + """ + N, C, H, W = tenIn.size() + device = tenIn.device + origdtype = tenIn.dtype + + # Initialize output tensor + tenOut = torch.zeros_like(tenIn) + + key = (H, W, device, origdtype) + if key not in grid_cache: + # Create meshgrid of pixel coordinates + gridY, gridX = torch.meshgrid( + torch.arange(H, device=device, dtype=origdtype), + torch.arange(W, device=device, dtype=origdtype), + indexing="ij", + ) # [H, W] + # Cache the grids + grid_cache[key] = ( + gridY.unsqueeze(0).unsqueeze(0).expand(N, 1, H, W), + gridX.unsqueeze(0).unsqueeze(0).expand(N, 1, H, W), + ) + + if key not in batch_cache: + batch_cache[key] = torch.arange(N, device=device).view(N, 1, 1).expand(N, H, W).reshape(-1) + + gridY, gridX = grid_cache[key] + batch_indices = batch_cache[key] + + # Compute fltX and fltY + fltX = gridX + tenFlow[:, 0:1, :, :] + fltY = gridY + tenFlow[:, 1:2, :, :] + + # Flatten variables + fltX_flat = fltX.reshape(-1) + fltY_flat = fltY.reshape(-1) + tenIn_flat = tenIn.permute(0, 2, 3, 1).reshape(-1, C) + + # Finite mask + finite_mask = torch.isfinite(fltX_flat) & torch.isfinite(fltY_flat) + if not finite_mask.any(): + return tenOut + + fltX_flat = fltX_flat[finite_mask] + fltY_flat = fltY_flat[finite_mask] + tenIn_flat = tenIn_flat[finite_mask] + batch_indices = batch_indices[finite_mask] + + # Compute integer positions + intNW_X = torch.floor(fltX_flat).to(dtype=torch.int32) + intNW_Y = torch.floor(fltY_flat).to(dtype=torch.int32) + intNE_X = intNW_X + 1 + intNE_Y = intNW_Y + intSW_X = intNW_X + intSW_Y = intNW_Y + 1 + intSE_X = intNW_X + 1 + intSE_Y = intNW_Y + 1 + + # Compute weights + fltNW = (intSE_X - fltX_flat) * (intSE_Y - fltY_flat) + fltNE = (fltX_flat - intSW_X) * (intSW_Y - fltY_flat) + fltSW = (intNE_X - fltX_flat) * (fltY_flat - intNE_Y) + fltSE = (fltX_flat - intNW_X) * (fltY_flat - intNW_Y) + + # Prepare output tensor flat + tenOut_flat = tenOut.permute(0, 2, 3, 1).reshape(-1, C) + + # Define positions and weights + positions = [ + (intNW_X, intNW_Y, fltNW), + (intNE_X, intNE_Y, fltNE), + (intSW_X, intSW_Y, fltSW), + (intSE_X, intSE_Y, fltSE), + ] + + H, W = int(H), int(W) + + for intX, intY, weight in positions: + # Valid indices within image bounds + valid_mask = (intX >= 0) & (intX < W) & (intY >= 0) & (intY < H) + if not valid_mask.any(): + continue + + idx_b = batch_indices[valid_mask] + idx_x = intX[valid_mask] + idx_y = intY[valid_mask] + w = weight[valid_mask] + vals = tenIn_flat[valid_mask] * w.unsqueeze(1) + + # Compute linear indices + idx_NHW = idx_b * H * W + idx_y * W + idx_x + + # Accumulate values using index_add_ + tenOut_flat.index_add_(0, idx_NHW, vals) + + # Reshape tenOut back to [N, C, H, W] + tenOut = tenOut_flat.view(N, H, W, C).permute(0, 3, 1, 2) + + return tenOut + + # end + + # end + + # end diff --git a/cccv/arch/vfi/vfi_utils/warplayer.py b/cccv/arch/vfi/vfi_utils/warplayer.py new file mode 100644 index 0000000..b2a75f2 --- /dev/null +++ b/cccv/arch/vfi/vfi_utils/warplayer.py @@ -0,0 +1,33 @@ +# type: ignore +import torch + +backwarp_tenGrid = {} + + +def warp(tenInput, tenFlow): + k = (str(tenFlow.device), str(tenFlow.size())) + if k not in backwarp_tenGrid: + tenHorizontal = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device) + .view(1, 1, 1, tenFlow.shape[3]) + .expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) + ) + tenVertical = ( + torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device) + .view(1, 1, tenFlow.shape[2], 1) + .expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) + ) + backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1).to(tenFlow.device) + + tenFlow = torch.cat( + [ + tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), + tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), + ], + 1, + ) + + g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1).to(tenInput.dtype) + return torch.nn.functional.grid_sample( + input=tenInput, grid=g.to(tenInput.device), mode="bilinear", padding_mode="border", align_corners=True + ) diff --git a/cccv/config/__init__.py b/cccv/config/__init__.py index 4b0f156..283a196 100644 --- a/cccv/config/__init__.py +++ b/cccv/config/__init__.py @@ -3,7 +3,7 @@ CONFIG_REGISTRY: RegistryConfigInstance = RegistryConfigInstance("CONFIG") -from cccv.config.base_config import BaseConfig, SRBaseConfig, VSRBaseConfig +from cccv.config.base_config import BaseConfig, SRBaseConfig, VSRBaseConfig, VFIBaseConfig # Auxiliary Network @@ -24,3 +24,7 @@ from cccv.config.vsr.edvr_config import EDVRConfig from cccv.config.vsr.animesr_config import AnimeSRConfig + +# Video Frame Interpolation +from cccv.config.vfi.rife_config import RIFEConfig +from cccv.config.vfi.drba_config import DRBAConfig diff --git a/cccv/config/auxnet/spynet_config.py b/cccv/config/auxnet/spynet_config.py index deda309..4210199 100644 --- a/cccv/config/auxnet/spynet_config.py +++ b/cccv/config/auxnet/spynet_config.py @@ -1,10 +1,11 @@ from typing import Union -from cccv.config import CONFIG_REGISTRY, BaseConfig +from cccv.config import CONFIG_REGISTRY +from cccv.config.base_config import AuxiliaryBaseConfig from cccv.type import ArchType, ConfigType, ModelType -class SpyNetConfig(BaseConfig): +class SpyNetConfig(AuxiliaryBaseConfig): arch: Union[ArchType, str] = ArchType.SPYNET model: Union[ModelType, str] = ModelType.SpyNet diff --git a/cccv/config/base_config.py b/cccv/config/base_config.py index da9b60a..e34b75f 100644 --- a/cccv/config/base_config.py +++ b/cccv/config/base_config.py @@ -15,9 +15,17 @@ class BaseConfig(BaseModel): model: Union[ModelType, str] +class AuxiliaryBaseConfig(BaseConfig): + pass + + class SRBaseConfig(BaseConfig): scale: int class VSRBaseConfig(SRBaseConfig): num_frame: int + + +class VFIBaseConfig(BaseConfig): + num_frame: int diff --git a/cccv/config/vfi/drba_config.py b/cccv/config/vfi/drba_config.py new file mode 100644 index 0000000..b41a438 --- /dev/null +++ b/cccv/config/vfi/drba_config.py @@ -0,0 +1,22 @@ +from typing import Union + +from cccv.config import CONFIG_REGISTRY +from cccv.config.base_config import VFIBaseConfig +from cccv.type import ArchType, ConfigType, ModelType + + +class DRBAConfig(VFIBaseConfig): + arch: Union[ArchType, str] = ArchType.DRBA + model: Union[ModelType, str] = ModelType.DRBA + num_frame: int = 3 + + +DRBAConfigs = [ + DRBAConfig( + name=ConfigType.DRBA_IFNet, + hash="4cc518e172156ad6207b9c7a43364f518832d83a4325d484240493a9e2980537", + ) +] + +for cfg in DRBAConfigs: + CONFIG_REGISTRY.register(cfg) diff --git a/cccv/config/vfi/rife_config.py b/cccv/config/vfi/rife_config.py new file mode 100644 index 0000000..a941d2b --- /dev/null +++ b/cccv/config/vfi/rife_config.py @@ -0,0 +1,22 @@ +from typing import Union + +from cccv.config import CONFIG_REGISTRY +from cccv.config.base_config import VFIBaseConfig +from cccv.type import ArchType, ConfigType, ModelType + + +class RIFEConfig(VFIBaseConfig): + arch: Union[ArchType, str] = ArchType.IFNET + model: Union[ModelType, str] = ModelType.RIFE + num_frame: int = 2 + + +RIFEConfigs = [ + RIFEConfig( + name=ConfigType.RIFE_IFNet_v426_heavy, + hash="4cc518e172156ad6207b9c7a43364f518832d83a4325d484240493a9e2980537", + ) +] + +for cfg in RIFEConfigs: + CONFIG_REGISTRY.register(cfg) diff --git a/cccv/model/__init__.py b/cccv/model/__init__.py index 69490ba..16763ab 100644 --- a/cccv/model/__init__.py +++ b/cccv/model/__init__.py @@ -8,6 +8,7 @@ from cccv.model.auxiliary_base_model import AuxiliaryBaseModel from cccv.model.sr_base_model import SRBaseModel from cccv.model.vsr_base_model import VSRBaseModel +from cccv.model.vfi_base_model import VFIBaseModel # Auxiliary Network @@ -20,3 +21,7 @@ # Video Super-Resolution from cccv.model.vsr.edvr_model import EDVRModel + +# Video Frame Interpolation +from cccv.model.vfi.rife_model import RIFEModel +from cccv.model.vfi.drba_model import DRBAModel diff --git a/cccv/model/base_model.py b/cccv/model/base_model.py index f42ba65..7c6b333 100644 --- a/cccv/model/base_model.py +++ b/cccv/model/base_model.py @@ -44,6 +44,9 @@ def __init__( # extra config self.one_frame_out: bool = False # for vsr model type + # load_state_dict + self.load_state_dict_strict: bool = True + # --- self.config = config self.device: Optional[torch.device] = device @@ -139,6 +142,12 @@ def load_model(self) -> Any: """ Auto load the model from config + These params in nn.Module.load_state_dict can be overridden in post_init_hook if needed: + + - self.load_state_dict_strict -> strict + + - self.load_state_dict_assign -> assign + :return: The initialized model with weights loaded """ cfg: BaseConfig = self.config @@ -155,7 +164,7 @@ def load_model(self) -> Any: # print(f"[CCCV] net_kw: {net_kw}") model = net(**net_kw) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=self.load_state_dict_strict) model.eval().to(self.device) return model diff --git a/cccv/model/vfi/__init__.py b/cccv/model/vfi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cccv/model/vfi/drba_model.py b/cccv/model/vfi/drba_model.py new file mode 100644 index 0000000..681776a --- /dev/null +++ b/cccv/model/vfi/drba_model.py @@ -0,0 +1,102 @@ +from typing import Any, List + +import cv2 +import numpy as np +import torch +from torch import Tensor +from torchvision import transforms + +from cccv.model import MODEL_REGISTRY +from cccv.model.vfi_base_model import VFIBaseModel +from cccv.type import ModelType +from cccv.util.misc import de_resize, resize + + +@MODEL_REGISTRY.register(name=ModelType.DRBA) +class DRBAModel(VFIBaseModel): + def post_init_hook(self) -> None: + self.load_state_dict_strict = False + + def transform_state_dict(self, state_dict: Any) -> Any: + def _convert(param: Any) -> Any: + return {k.replace("module.", ""): v for k, v in param.items() if "module." in k} + + return _convert(state_dict) + + @torch.inference_mode() # type: ignore + def inference( + self, + imgs: torch.Tensor, + minus_t: list[float], + zero_t: list[float], + plus_t: list[float], + left_scene_change: bool, + right_scene_change: bool, + scale: float, + reuse: Any, + *args: Any, + **kwargs: Any, + ) -> tuple[Tensor, Any]: + """ + Inference with the model + + :param imgs: The input frames (B, 3, C, H, W) + :param minus_t: Timestep between -1 and 0 (I0 and I1) + :param zero_t: Timestep of 0, if not empty, preserve I1 (I1) + :param plus_t: Timestep between 0 and 1 (I1 and I2) + :param left_scene_change: True if there is a scene change between I0 and I1 (I0 and I1) + :param right_scene_change: True if there is a scene change between I1 and I2 (I1 and I2) + :param scale: Flow scale. + :param reuse: Reusable output from model with last frame pair. + + :return: All immediate frames between I0~I2 and reusable contents. + """ + + I0, I1, I2 = imgs[:, 0], imgs[:, 1], imgs[:, 2] + _, _, h, w = I0.shape + I0 = resize(I0, scale).unsqueeze(0) + I1 = resize(I1, scale).unsqueeze(0) + I2 = resize(I2, scale).unsqueeze(0) + + inp = torch.cat([I0, I1, I2], dim=1) + + results, reuse = self.model(inp, minus_t, zero_t, plus_t, left_scene_change, right_scene_change, scale, reuse) + + results = torch.cat(tuple(de_resize(result, h, w).unsqueeze(0) for result in results), dim=1) + + return results, reuse + + @torch.inference_mode() # type: ignore + def inference_image_list(self, img_list: List[np.ndarray], *args: Any, **kwargs: Any) -> List[np.ndarray]: + """ + Inference numpy image list with the model + + :param img_list: 3 input frames (img0, img1, img2) + + :return: 5 output frames (img0, img0_1, img1, img1_2, img2) + """ + if len(img_list) != 3: + raise ValueError("DRBA img_list must contain 3 images") + + new_img_list = [] + for img in img_list: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = transforms.ToTensor()(img).unsqueeze(0).to(self.device) + new_img_list.append(img) + + # b, n, c, h, w + img_tensor_stack = torch.stack(new_img_list, dim=1) + if self.fp16: + img_tensor_stack = img_tensor_stack.half() + + results, _ = self.inference(img_tensor_stack, [-1, -0.5], [0], [0.5, 1], False, False, 1.0, None) + + results_list = [] + for i in range(results.shape[1]): + img = results[0, i, :, :, :] + img = img.permute(1, 2, 0).cpu().numpy() + img = (img * 255).clip(0, 255).astype("uint8") + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + results_list.append(img) + + return results_list diff --git a/cccv/model/vfi/rife_model.py b/cccv/model/vfi/rife_model.py new file mode 100644 index 0000000..0e8605e --- /dev/null +++ b/cccv/model/vfi/rife_model.py @@ -0,0 +1,85 @@ +from typing import Any, List + +import cv2 +import numpy as np +import torch +from torchvision import transforms + +from cccv.model import MODEL_REGISTRY +from cccv.model.vfi_base_model import VFIBaseModel +from cccv.type import ModelType +from cccv.util.misc import de_resize, resize + + +@MODEL_REGISTRY.register(name=ModelType.RIFE) +class RIFEModel(VFIBaseModel): + def post_init_hook(self) -> None: + self.load_state_dict_strict = False + + def transform_state_dict(self, state_dict: Any) -> Any: + def _convert(param: Any) -> Any: + return {k.replace("module.", ""): v for k, v in param.items() if "module." in k} + + return _convert(state_dict) + + @torch.inference_mode() # type: ignore + def inference(self, imgs: torch.Tensor, timestep: float, scale: float, *args: Any, **kwargs: Any) -> torch.Tensor: + """ + Inference with the model + + :param imgs: The input frames (B, 2, C, H, W) + :param timestep: Timestep between 0 and 1 (img0 and img1) + :param scale: Flow scale. + + :return: an immediate frame between I0 and I1 + """ + + I0, I1 = imgs[:, 0], imgs[:, 1] + _, _, h, w = I0.shape + I0 = resize(I0, scale) + I1 = resize(I1, scale) + + inp = torch.cat([I0, I1], dim=1) + scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale] + + result = self.model(inp, timestep, scale_list) + + result = de_resize(result, h, w) + + return result + + @torch.inference_mode() # type: ignore + def inference_image_list(self, img_list: List[np.ndarray], *args: Any, **kwargs: Any) -> List[np.ndarray]: + """ + Inference numpy image list with the model + + :param img_list: 2 input frames (img0, img1) + + :return: 1 output frames (img0_1) + """ + if len(img_list) != 2: + raise ValueError("IFNet img_list must contain 2 images") + + new_img_list = [] + for img in img_list: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = transforms.ToTensor()(img).unsqueeze(0).to(self.device) + new_img_list.append(img) + + # b, n, c, h, w + img_tensor_stack = torch.stack(new_img_list, dim=1) + if self.fp16: + img_tensor_stack = img_tensor_stack.half() + + out = self.inference(img_tensor_stack, timestep=0.5, scale=1.0) + + # Convert to numpy image list + results_list = [] + + img = out.squeeze(0).permute(1, 2, 0).cpu().numpy() + img = (img * 255).clip(0, 255).astype("uint8") + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + results_list.append(img) + + return results_list diff --git a/cccv/model/vfi_base_model.py b/cccv/model/vfi_base_model.py new file mode 100644 index 0000000..a726775 --- /dev/null +++ b/cccv/model/vfi_base_model.py @@ -0,0 +1,55 @@ +from typing import Any, List + +import numpy as np +import torch + +from cccv.config import VFIBaseConfig +from cccv.model import MODEL_REGISTRY +from cccv.model.base_model import CCBaseModel +from cccv.type import ModelType + + +@MODEL_REGISTRY.register(name=ModelType.VFIBaseModel) +class VFIBaseModel(CCBaseModel): + def inference(self, *args: Any, **kwargs: Any) -> torch.Tensor: + raise NotImplementedError + + def inference_image_list(self, img_list: List[np.ndarray], *args: Any, **kwargs: Any) -> List[np.ndarray]: + raise NotImplementedError + + @torch.inference_mode() # type: ignore + def inference_video( + self, + clip: Any, + scale: float = 1.0, + tar_fps: float = 60, + scdet: bool = True, + scdet_threshold: float = 0.3, + *args: Any, + **kwargs: Any, + ) -> Any: + """ + Inference the video with the model, the clip should be a vapoursynth clip + + :param clip: vs.VideoNode + :param scale: The flow scale factor + :param tar_fps: The fps of the interpolated video + :param scdet: Enable SSIM scene change detection + :param scdet_threshold: SSIM scene change detection threshold (greater is sensitive) + :return: + """ + + from cccv.vs import inference_vfi + + cfg: VFIBaseConfig = self.config + + return inference_vfi( + inference=self.inference, + clip=clip, + scale=scale, + tar_fps=tar_fps, + num_frame=cfg.num_frame, + scdet=scdet, + scdet_threshold=scdet_threshold, + device=self.device, + ) diff --git a/cccv/model/vsr_base_model.py b/cccv/model/vsr_base_model.py index c16fc10..487e5a0 100644 --- a/cccv/model/vsr_base_model.py +++ b/cccv/model/vsr_base_model.py @@ -48,6 +48,8 @@ def inference_image_list(self, img_list: List[np.ndarray], *args: Any, **kwargs: # b, n, c, h, w img_tensor_stack = torch.stack(new_img_list, dim=1) + if self.fp16: + img_tensor_stack = img_tensor_stack.half() out = self.inference(img_tensor_stack) @@ -89,7 +91,7 @@ def inference_video(self, clip: Any, *args: Any, **kwargs: Any) -> Any: inference=self.inference, clip=clip, scale=cfg.scale, - length=cfg.num_frame, + num_frame=cfg.num_frame, device=self.device, one_frame_out=self.one_frame_out, ) diff --git a/cccv/type/arch.py b/cccv/type/arch.py index 112a000..2307645 100644 --- a/cccv/type/arch.py +++ b/cccv/type/arch.py @@ -23,3 +23,8 @@ class ArchType(str, Enum): EDVR = "EDVR" MSRSWVSR = "MSRSWVSR" + + # ------------------------------------- Video Frame Interpolation -------------------------------------------------- + + IFNET = "IFNET" + DRBA = "DRBA" diff --git a/cccv/type/config.py b/cccv/type/config.py index 2f32529..aca5fb0 100644 --- a/cccv/type/config.py +++ b/cccv/type/config.py @@ -108,3 +108,11 @@ class ConfigType(str, Enum): # AnimeSR AnimeSR_v1_PaperModel_4x = "AnimeSR_v1_PaperModel_4x.pth" AnimeSR_v2_4x = "AnimeSR_v2_4x.pth" + + # ------------------------------------- Video Frame Interpolation -------------------------------------------------- + + # RIFE + RIFE_IFNet_v426_heavy = "RIFE_IFNet_v426_heavy.pth" + + # DRBA + DRBA_IFNet = "DRBA_IFNet.pth" diff --git a/cccv/type/model.py b/cccv/type/model.py index a8e4c41..3db79a4 100644 --- a/cccv/type/model.py +++ b/cccv/type/model.py @@ -25,3 +25,9 @@ class ModelType(str, Enum): EDVR = "EDVR" AnimeSR = "AnimeSR" + + # ------------------------------------- Video Frame Interpolation -------------------------------------------------- + VFIBaseModel = "VFIBaseModel" + + RIFE = "RIFE" + DRBA = "DRBA" diff --git a/cccv/util/misc.py b/cccv/util/misc.py index e7cdf66..5d32094 100644 --- a/cccv/util/misc.py +++ b/cccv/util/misc.py @@ -1,7 +1,12 @@ +import math import random +from math import exp +from typing import Any import numpy as np import torch +import torch.nn.functional as F +from torch import Tensor def set_random_seed(seed: int = 0) -> None: @@ -14,3 +19,148 @@ def set_random_seed(seed: int = 0) -> None: torch.cuda.manual_seed_all(seed) except Exception: pass + + +def resize(img: Tensor, _scale: float) -> Tensor: + _, _, _h, _w = img.shape + while _h * _scale % 64 != 0: + _h += 1 + while _w * _scale % 64 != 0: + _w += 1 + return F.interpolate(img, size=(int(_h), int(_w)), mode="bilinear", align_corners=False) + + +def de_resize(img: Any, ori_h: int, ori_w: int) -> Tensor: + return F.interpolate(img, size=(int(ori_h), int(ori_w)), mode="bilinear", align_corners=False) + + +def distance_calculator(_x: Tensor) -> Tensor: + dtype = _x.dtype + u, v = _x[:, 0:1].float(), _x[:, 1:].float() + return torch.sqrt(u**2 + v**2).to(dtype) + + +class TMapper: + def __init__(self, src: float = -1.0, dst: float = 0.0, times: float = -1): + self.times = dst / src if times == -1 else times + self.now_step = -1 + self.src = src + self.dst = dst + + def get_range_timestamps( + self, _min: float, _max: float, lclose: bool = True, rclose: bool = False, normalize: bool = True + ) -> list: + _min_step = math.ceil(_min * self.times) + _max_step = math.ceil(_max * self.times) + _start = _min_step if lclose else _min_step + 1 + _end = _max_step if not rclose else _max_step + 1 + if _start >= _end: + return [] + if normalize: + return [((_i / self.times) - _min) / (_max - _min) for _i in range(_start, _end)] + return [_i / self.times for _i in range(_start, _end)] + + +def gaussian(window_size: int, sigma: float) -> Tensor: + gauss = Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window_3d(window_size: int, channel: int = 1) -> Tensor: + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous() + return window + + +def ssim_matlab( + img1: Tensor, + img2: Tensor, + window_size: int = 11, + window: Tensor = None, + size_average: bool = True, +) -> Tensor: + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + # Channel is set to 1 since we consider color images as volumetric images + window = create_window_3d(real_size, channel=1).to(img1.device).to(img1.dtype) + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode="replicate"), window, padding=padd, groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_sq + sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu2_sq + sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), "replicate"), window, padding=padd, groups=1) - mu1_mu2 + + C1 = (0.01 * L) ** 2 + C2 = (0.03 * L) ** 2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + + ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + return ret + + +def check_scene(x1: Tensor, x2: Tensor, enable_scdet: bool, scdet_threshold: float) -> bool: + """ + Check if the scene is different, based on the SSIM value of the two input tensors. + + Input Tensor can be 3D, 4D, or 5D. + + :param x1: The first input tensor. + :param x2: The second input tensor. + :param enable_scdet: Whether to enable the scene change detection. + :param scdet_threshold: The threshold of the SSIM value. + """ + + if not enable_scdet: + return False + if x1.dim() != x2.dim(): + raise ValueError("The dimensions of the two input tensors must be the same.") + if x1.dim() not in [3, 4, 5]: + raise ValueError("The input tensor must be 3D, 4D, or 5D.") + + _x1 = x1.clone() + _x2 = x2.clone() + + if _x1.dim() == 3: + _x1 = _x1.unsqueeze(0) + _x2 = _x2.unsqueeze(0) + + if _x1.dim() == 5: + _x1 = _x1.squeeze(0) + _x2 = _x2.squeeze(0) + + _x1 = F.interpolate(_x1, (32, 32), mode="bilinear", align_corners=False) + _x2 = F.interpolate(_x2, (32, 32), mode="bilinear", align_corners=False) + + return ssim_matlab(_x1, _x2).item() < scdet_threshold diff --git a/cccv/vs/__init__.py b/cccv/vs/__init__.py index 1c926c2..8d1f29b 100644 --- a/cccv/vs/__init__.py +++ b/cccv/vs/__init__.py @@ -34,4 +34,5 @@ """ from cccv.vs.sr import inference_sr +from cccv.vs.vfi import inference_vfi from cccv.vs.vsr import inference_vsr diff --git a/cccv/vs/vfi.py b/cccv/vs/vfi.py new file mode 100644 index 0000000..d3789b1 --- /dev/null +++ b/cccv/vs/vfi.py @@ -0,0 +1,266 @@ +import math +from typing import Callable, Dict + +import numpy as np +import torch +import vapoursynth as vs +from vapoursynth import core + +from cccv.util.misc import TMapper, check_scene +from cccv.vs.convert import frame_to_tensor, tensor_to_frame + + +def inference_vfi( + inference: Callable, + clip: vs.VideoNode, + scale: float, + tar_fps: float, + device: torch.device, + num_frame: int = 2, + scdet: bool = True, + scdet_threshold: float = 0.3, +) -> vs.VideoNode: + """ + Inference the video with the model, the clip should be a vapoursynth clip + + :param inference: The inference function + :param clip: vs.VideoNode + :param scale: The flow scale factor + :param tar_fps: The fps of the interpolated video + :param device: The device + :param num_frame: The input frame count of vfi method once infer + :param scdet: Enable SSIM scene change detection + :param scdet_threshold: SSIM scene change detection threshold (greater is sensitive) + :return: + """ + + if core.num_threads != 1: + raise ValueError("[CCCV] The number of threads must be 1 when enable frame interpolation") + + if clip.format.id not in [vs.RGBH, vs.RGBS]: + raise vs.Error("[CCCV] Only vs.RGBH and vs.RGBS formats are supported") + + if num_frame > clip.num_frames: + raise ValueError("[CCCV] Input frames should be less than the number of frames in the clip") + elif num_frame <= 1: + raise ValueError("[CCCV] Input frames should be greater than 1") + + src_fps = clip.fps.numerator / clip.fps.denominator + if src_fps > tar_fps: + raise ValueError("[CCCV] The target fps should be greater than the clip fps") + + if scale < 0 or not math.log2(scale).is_integer(): + raise ValueError("[CCCV] The scale should be greater than 0 and is power of two") + + vfi_methods = { + 2: inference_vfi_two_frame_in, + 3: inference_vfi_three_frame_in, + } + + if num_frame not in vfi_methods: + raise ValueError(f"[CCCV] The vfi method with {num_frame} frame input is not supported") + + mapper = TMapper(src_fps, tar_fps) + + return vfi_methods[num_frame](inference, clip, mapper, scale, scdet, scdet_threshold, device) + + +def inference_vfi_two_frame_in( + inference: Callable, + clip: vs.VideoNode, + mapper: TMapper, + scale: float, + scdet: bool, + scdet_threshold: float, + device: torch.device, +) -> vs.VideoNode: + """ + VFI for two frame input models + + f1, f2 -> f1?, f1t?, f2? + + For the two frame input model, the inference function should accept a tensor with shape (b, 2, c, h, w) + And return a tensor with shape (b, c, h, w) + + :param inference: The inference function + :param clip: vs.VideoNode + :param scale: The flow scale factor + :param mapper: The framerate mapper + :param scdet: Enable SSIM scene change detection + :param scdet_threshold: SSIM scene change detection threshold (greater is sensitive) + :param device: The device + :return: + """ + + in_idx: int = 0 + out_idx: int = 0 + in_frames: Dict[int, torch.Tensor] = {} + out_frames: Dict[int, torch.Tensor] = {} + flag_end: bool = False + reuse: tuple[torch.Tensor, ...] + + def to_input_tensor(x: vs.VideoFrame) -> torch.Tensor: + return frame_to_tensor(x, device=device).unsqueeze(0).unsqueeze(0) + + new_clip = clip.std.AssumeFPS(fpsnum=mapper.dst, fpsden=1) + less_num_frames = math.ceil(clip.num_frames * mapper.dst / mapper.src) - clip.num_frames + for _ in range(less_num_frames): + new_clip = new_clip.std.DuplicateFrames(clip.num_frames - 1) + + def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: + nonlocal in_idx, out_idx, in_frames, out_frames, flag_end, reuse + if n >= out_idx and not flag_end: + if in_idx not in in_frames.keys(): + in_frames[in_idx] = to_input_tensor(clip.get_frame(in_idx)) + I0 = in_frames[in_idx] + + if in_idx + 1 >= clip.num_frames - 1: + flag_end = True + return tensor_to_frame(out_frames[list(out_frames.keys())[-1]], f[1].copy()) + + if in_idx + 1 not in in_frames.keys(): + in_frames[in_idx + 1] = to_input_tensor(clip.get_frame(in_idx + 1)) + I1 = in_frames[in_idx + 1] + + ts = mapper.get_range_timestamps(in_idx, in_idx + 1, lclose=True, rclose=flag_end, normalize=True) + + scene = check_scene(I0, I1, scdet, scdet_threshold) + + for t in ts: + if scene: + out = I0.squeeze(0) + else: + if t == 0: + out = I0.squeeze(0) + elif t == 1: + out = I1.squeeze(0) + else: + out = inference(torch.cat([I0, I1], dim=1), timestep=t, scale=scale) + out_frames[out_idx] = out + out_idx += 1 + + # clear input cache + if in_idx - 1 in in_frames.keys(): + in_frames.pop(in_idx - 1) + + in_idx += 1 + + # clear output cache + if n - 1 in out_frames.keys() and len(out_frames.keys()) > 2: + out_frames.pop(n - 1) + + if n not in out_frames.keys(): + return tensor_to_frame(out_frames[list(out_frames.keys())[-1]], f[1].copy()) + + return tensor_to_frame(out_frames[n], f[1].copy()) + + return new_clip.std.ModifyFrame([new_clip, new_clip], _inference) + + +def inference_vfi_three_frame_in( + inference: Callable, + clip: vs.VideoNode, + mapper: TMapper, + scale: float, + scdet: bool, + scdet_threshold: float, + device: torch.device, +) -> vs.VideoNode: + """ + VFI for three frame input models + + f1, f2, f3 -> f1?, f1t?, f2?, f2t?, f3? + + For the three frame input model, the inference function should accept a tensor with shape (b, 3, c, h, w) + And return a tensor with shape (b, c, h, w) + + :param inference: The inference function + :param clip: vs.VideoNode + :param scale: The flow scale factor + :param mapper: The framerate mapper + :param scdet: Enable SSIM scene change detection + :param scdet_threshold: SSIM scene change detection threshold (greater is sensitive) + :param device: The device + :return: + """ + + in_idx: int = 0 + out_idx: int = 0 + in_frames: Dict[int, torch.Tensor] = {} + out_frames: Dict[int, torch.Tensor] = {} + flag_end: bool = False + reuse: tuple[torch.Tensor, ...] + + def calc_t(_mapper: TMapper, _idx: float, _flag_end: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + ts = _mapper.get_range_timestamps(_idx - 0.5, _idx + 0.5, lclose=True, rclose=_flag_end, normalize=False) + timestamp = np.asarray(ts, dtype=float) - _idx + vfi_timestamp = np.round(timestamp, 4) + + minus_t = vfi_timestamp[vfi_timestamp < 0] + zero_t = vfi_timestamp[vfi_timestamp == 0] + plus_t = vfi_timestamp[vfi_timestamp > 0] + return minus_t, zero_t, plus_t + + def to_input_tensor(x: vs.VideoFrame) -> torch.Tensor: + return frame_to_tensor(x, device=device).unsqueeze(0).unsqueeze(0) + + new_clip = clip.std.AssumeFPS(fpsnum=mapper.dst, fpsden=1) + less_num_frames = math.ceil(clip.num_frames * mapper.dst / mapper.src) - clip.num_frames + for _ in range(less_num_frames): + new_clip = new_clip.std.DuplicateFrames(clip.num_frames - 1) + + def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: + nonlocal in_idx, out_idx, in_frames, out_frames, flag_end, reuse + if n >= out_idx and not flag_end: + if in_idx not in in_frames.keys(): + in_frames[in_idx] = to_input_tensor(clip.get_frame(in_idx)) + I0 = in_frames[in_idx] + + if in_idx + 1 >= clip.num_frames - 1: + flag_end = True + return tensor_to_frame(out_frames[list(out_frames.keys())[-1]], f[1].copy()) + + if in_idx + 1 not in in_frames.keys(): + in_frames[in_idx + 1] = to_input_tensor(clip.get_frame(in_idx + 1)) + I1 = in_frames[in_idx + 1] + + if in_idx + 2 >= clip.num_frames - 1: + flag_end = True + else: + if in_idx + 2 not in in_frames.keys(): + in_frames[in_idx + 2] = to_input_tensor(clip.get_frame(in_idx + 2)) + I2 = in_frames[in_idx + 2] + + mt, zt, pt = calc_t(mapper, in_idx, flag_end) + left_scene = check_scene(I0, I1, scdet, scdet_threshold) + if in_idx == 0: # head + right_scene = left_scene + output, reuse = inference(torch.cat([I0, I0, I1], dim=1), mt, zt, pt, False, right_scene, scale, None) + elif flag_end: # tail + output, _ = inference(torch.cat([I0, I1, I1], dim=1), mt, zt, pt, left_scene, False, scale, reuse) + else: + right_scene = check_scene(I1, I2, scdet, scdet_threshold) + output, reuse = inference( + torch.cat([I0, I1, I2], dim=1), mt, zt, pt, left_scene, right_scene, scale, reuse + ) + + for i in range(output.shape[1]): + out_frames[out_idx] = output[0, i : i + 1] + out_idx += 1 + + # clear input cache + if in_idx - 1 in in_frames.keys(): + in_frames.pop(in_idx - 1) + + in_idx += 1 + + # clear output cache + if n - 1 in out_frames.keys() and len(out_frames.keys()) > 2: + out_frames.pop(n - 1) + + if n not in out_frames.keys(): + return tensor_to_frame(out_frames[list(out_frames.keys())[-1]], f[1].copy()) + + return tensor_to_frame(out_frames[n], f[1].copy()) + + return new_clip.std.ModifyFrame([new_clip, new_clip], _inference) diff --git a/cccv/vs/vsr.py b/cccv/vs/vsr.py index 437ffac..fad36c7 100644 --- a/cccv/vs/vsr.py +++ b/cccv/vs/vsr.py @@ -11,7 +11,7 @@ def inference_vsr( inference: Callable[[torch.Tensor], torch.Tensor], clip: vs.VideoNode, scale: Union[float, int, Any], - length: int, + num_frame: int, device: torch.device, one_frame_out: bool = False, ) -> vs.VideoNode: @@ -21,7 +21,7 @@ def inference_vsr( :param inference: The inference function :param clip: vs.VideoNode :param scale: The scale factor - :param length: The length of the input frames + :param num_frame: Number of input frames :param device: The device :param one_frame_out: Whether the model is one frame output model :return: @@ -29,23 +29,22 @@ def inference_vsr( if clip.format.id not in [vs.RGBH, vs.RGBS]: raise vs.Error("[CCCV] Only vs.RGBH and vs.RGBS formats are supported") - if length > clip.num_frames: - raise ValueError("[CCCV] The length of the input frames should be less than the number of frames in the clip") - - if length < 2: - raise ValueError("[CCCV] The length of the input frames should be greater than 1") + if num_frame > clip.num_frames: + raise ValueError("[CCCV] Input frames should be less than the number of frames in the clip") + elif num_frame <= 1: + raise ValueError("[CCCV] Input frames should be greater than 1") if not one_frame_out: - return inference_vsr_multi_frame_out(inference, clip, scale, length, device) + return inference_vsr_multi_frame_out(inference, clip, scale, num_frame, device) else: - return inference_vsr_one_frame_out(inference, clip, scale, length, device) + return inference_vsr_one_frame_out(inference, clip, scale, num_frame, device) def inference_vsr_multi_frame_out( inference: Callable[[torch.Tensor], torch.Tensor], clip: vs.VideoNode, scale: Union[float, int, Any], - length: int, + num_frame: int, device: torch.device, ) -> vs.VideoNode: """ @@ -59,7 +58,7 @@ def inference_vsr_multi_frame_out( :param inference: The inference function :param clip: vs.VideoNode :param scale: The scale factor - :param length: The length of the input frames + :param num_frame: Number of input frames :param device: The device :return: """ @@ -74,7 +73,7 @@ def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: cache.clear() img = [] - for i in range(length): + for i in range(num_frame): index = n + i if index >= clip.num_frames: img.append(frame_to_tensor(clip.get_frame(clip.num_frames - 1), device=device).unsqueeze(0)) @@ -101,7 +100,7 @@ def inference_vsr_one_frame_out( inference: Callable[[torch.Tensor], torch.Tensor], clip: vs.VideoNode, scale: Union[float, int, Any], - length: int, + num_frame: int, device: torch.device, ) -> vs.VideoNode: """ @@ -115,12 +114,12 @@ def inference_vsr_one_frame_out( :param inference: The inference function :param clip: vs.VideoNode :param scale: The scale factor - :param length: The length of the input frames, should be odd + :param num_frame: Number of input frames, should be odd :param device: The device :return: """ - if length % 2 == 0: + if num_frame % 2 == 0: raise ValueError("[CCCV] The length of the input frames should be odd") lock = Lock() @@ -128,8 +127,8 @@ def inference_vsr_one_frame_out( def _inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame: with lock: img = [] - for i in range(length): - index = i - length // 2 + n + for i in range(num_frame): + index = i - num_frame // 2 + n if index < 0: img.append(frame_to_tensor(clip.get_frame(0), device=device).unsqueeze(0)) diff --git a/example/auto.py b/example/auto.py index 54fc8d0..bca274f 100644 --- a/example/auto.py +++ b/example/auto.py @@ -2,26 +2,24 @@ # --- sisr, use fp16 to inference -model: BaseModelInterface = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x -) +model: BaseModelInterface = AutoModel.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x) # --- use fp32 to inference model = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, fp16=False, ) # --- vsr -model = AutoModel.from_pretrained(pretrained_model_name=ConfigType.AnimeSR_v2_4x) +model = AutoModel.from_pretrained(ConfigType.AnimeSR_v2_4x) # --- torch.compile model = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, compile=True, # compile_backend="inductor", ) @@ -29,6 +27,6 @@ # --- disable tile processing model = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, tile=None, ) diff --git a/example/register.py b/example/register.py index d525184..7cb8245 100644 --- a/example/register.py +++ b/example/register.py @@ -1,7 +1,6 @@ from typing import Any -from cccv import AutoConfig, AutoModel, SRBaseModel -from cccv.arch import RRDBNet +from cccv import ArchType, AutoConfig, AutoModel, SRBaseModel from cccv.config import RealESRGANConfig # define your own config name and model name @@ -12,47 +11,22 @@ # extend from cccv.BaseConfig then implement your own config parameters cfg = RealESRGANConfig( name=cfg_name, - url="https://github.com/EutropicAI/cccv/releases/download/model_zoo/RealESRGAN_RealESRGAN_x4plus_anime_6B_4x.pth", - arch="RRDB", + url="https://github.com/EutropicAI/cccv/releases/download/model_zoo/RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth", + arch=ArchType.SRVGG, model=model_name, - scale=4, - num_block=6, + scale=2, ) AutoConfig.register(cfg) -# this should be your own model -# extend from cccv.SRBaseModel or cccv.VSRBaseModel then implement your own model -# self.one_frame_out: bool = False for this kind of vsr model: f1, f2, f3, f4 -> f1', f2', f3', f4' -# self.one_frame_out: bool = True for this kind of vsr model: f-2, f-1, f0, f1, f2 -> f0' -# override self.one_frame_out in self.load_model() if you want +# extend from cccv.SRBaseModel then implement your own model @AutoModel.register(name=model_name) class TESTMODEL(SRBaseModel): def load_model(self) -> Any: - cfg: RealESRGANConfig = self.config - state_dict = self.get_state_dict() - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - elif "model_state_dict" in state_dict: - # For APISR's model - state_dict = state_dict["model_state_dict"] - - model = RRDBNet( - num_in_ch=cfg.num_in_ch, - num_out_ch=cfg.num_out_ch, - scale=cfg.scale, - num_feat=cfg.num_feat, - num_block=cfg.num_block, - num_grow_ch=cfg.num_grow_ch, - ) - - model.load_state_dict(state_dict) - model.eval().to(self.device) - return model + print("Override load_model function here") + print("We use default load_model function to load the model") + return super().load_model() model: TESTMODEL = AutoModel.from_pretrained(cfg_name) diff --git a/example/sisr.py b/example/sisr.py index fdf6215..12eb129 100644 --- a/example/sisr.py +++ b/example/sisr.py @@ -11,9 +11,7 @@ model: SRBaseModel = AutoModel.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x) elif example == 1: # edit the configuration - config: BaseConfig = AutoConfig.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x - ) + config: BaseConfig = AutoConfig.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x) print(config) config.scale = 2 model: SRBaseModel = AutoModel.from_config(config=config) @@ -30,7 +28,7 @@ elif example == 4: # use custom model dir and gh proxy model: SRBaseModel = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, model_dir="./", gh_proxy="https://github.abskoop.workers.dev/", ) diff --git a/example/vapoursynth.py b/example/sr_vs.py similarity index 70% rename from example/vapoursynth.py rename to example/sr_vs.py index bb8885f..3ebd105 100644 --- a/example/vapoursynth.py +++ b/example/sr_vs.py @@ -13,11 +13,9 @@ if example == 0: # --- sisr, use fp16 to inference (vs.RGBH) - model: CCBaseModel = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, tile=None - ) + model: CCBaseModel = AutoModel.from_pretrained(ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, tile=None) - clip = core.bs.VideoSource(source="s.mp4") + clip = core.bs.VideoSource(source="s.mkv") clip = core.resize.Bicubic(clip=clip, matrix_in_s="709", format=vs.RGBH) clip = model.inference_video(clip) clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) @@ -27,10 +25,10 @@ # --- use fp32 to inference (vs.RGBS) model: CCBaseModel = AutoModel.from_pretrained( - pretrained_model_name=ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, fp16=False, tile=None + ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x, fp16=False, tile=None ) - clip = core.bs.VideoSource(source="s.mp4") + clip = core.bs.VideoSource(source="s.mkv") clip = core.resize.Bicubic(clip=clip, matrix_in_s="709", format=vs.RGBS) clip = model.inference_video(clip) clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) diff --git a/example/vfi_vs.py b/example/vfi_vs.py new file mode 100644 index 0000000..8faef00 --- /dev/null +++ b/example/vfi_vs.py @@ -0,0 +1,20 @@ +import sys + +sys.path.append(".") +sys.path.append("..") + +import vapoursynth as vs +from vapoursynth import core + +from cccv import AutoModel, ConfigType, VFIBaseModel + +# --- IFNet, use fp16 to inference (vs.RGBH) + +model: VFIBaseModel = AutoModel.from_pretrained(ConfigType.RIFE_IFNet_v426_heavy, fp16=True, tile=None) + +core.num_threads = 1 # should be set to single thread now, TODO: fix it +clip = core.bs.VideoSource(source="s.mkv") +clip = core.resize.Bicubic(clip=clip, matrix_in_s="709", format=vs.RGBH) +clip = model.inference_video(clip, scale=1.0, tar_fps=60, scdet=True, scdet_threshold=0.3) +clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) +clip.set_output() diff --git a/example/vsr.py b/example/vsr.py index 01854d4..b5f9881 100644 --- a/example/vsr.py +++ b/example/vsr.py @@ -7,7 +7,7 @@ img = cv2.imdecode(np.fromfile("../assets/test.jpg", dtype=np.uint8), cv2.IMREAD_COLOR) imgList = [img, img, img] -model: VSRBaseModel = AutoModel.from_pretrained(ConfigType.AnimeSR_v2_4x, fp16=False) +model: VSRBaseModel = AutoModel.from_pretrained(ConfigType.AnimeSR_v2_4x) imgOutList = model.inference_image_list(imgList) diff --git a/example/vsr_vs.py b/example/vsr_vs.py new file mode 100644 index 0000000..f13e5c9 --- /dev/null +++ b/example/vsr_vs.py @@ -0,0 +1,45 @@ +import sys + +sys.path.append(".") +sys.path.append("..") + +import vapoursynth as vs +from vapoursynth import core + +from cccv import AutoModel, CCBaseModel, ConfigType + +example = 0 + +if example == 0: + # VSR for multi frame output models + # + # f1, f2, f3, f4 -> f1', f2', f3', f4' + model: CCBaseModel = AutoModel.from_pretrained(ConfigType.AnimeSR_v2_4x, tile=None) + + clip = core.bs.VideoSource(source="s.mkv") + clip = core.resize.Bicubic(clip=clip, matrix_in_s="709", format=vs.RGBH) + clip = model.inference_video(clip) + clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) + clip.set_output() + +elif example == 1: + # VSR for one frame output models + # + # f-2, f-1, f0, f1, f2 -> f0' + # + # Should enable self.one_frame_out = True + # @MODEL_REGISTRY.register(name=ModelType.EDVR) + # class EDVRModel(VSRBaseModel): + # def post_init_hook(self) -> None: + # self.one_frame_out = True + + model: CCBaseModel = AutoModel.from_pretrained(ConfigType.EDVR_M_SR_REDS_official_4x, tile=(256, 256)) + + clip = core.bs.VideoSource(source="s.mkv") + clip = core.resize.Bicubic(clip=clip, matrix_in_s="709", format=vs.RGBH) + clip = model.inference_video(clip) + clip = core.resize.Bicubic(clip=clip, matrix_s="709", format=vs.YUV420P16) + clip.set_output() + +else: + raise NotImplementedError diff --git a/tests/sr/test_sr.py b/tests/sr/test_sr.py index ea2b36b..fcf836f 100644 --- a/tests/sr/test_sr.py +++ b/tests/sr/test_sr.py @@ -23,9 +23,7 @@ def test_inference() -> None: k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x - model: SRBaseModel = AutoModel.from_pretrained( - pretrained_model_name=k, device=CCCV_DEVICE, fp16=False, tile=CCCV_TILE - ) + model: SRBaseModel = AutoModel.from_pretrained(k, device=CCCV_DEVICE, fp16=False, tile=CCCV_TILE) t2 = model(tensor1) t3 = model.inference(tensor1) @@ -55,9 +53,7 @@ def test_sr_compile() -> None: img1 = load_image() k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x - model: SRBaseModel = AutoModel.from_pretrained( - pretrained_model_name=k, device=CCCV_DEVICE, fp16=CCCV_FP16, compile=True, tile=CCCV_TILE - ) + model: SRBaseModel = AutoModel.from_pretrained(k, device=CCCV_DEVICE, fp16=CCCV_FP16, compile=True, tile=CCCV_TILE) img2 = model.inference_image(img1) diff --git a/tests/test_auto_class.py b/tests/test_auto_class.py index 520174c..326c255 100644 --- a/tests/test_auto_class.py +++ b/tests/test_auto_class.py @@ -1,6 +1,6 @@ from typing import Any -from cccv import AutoConfig, AutoModel +from cccv import ArchType, AutoConfig, AutoModel from cccv.config import RealESRGANConfig from cccv.model import SRBaseModel @@ -11,11 +11,10 @@ def test_auto_class_register() -> None: cfg = RealESRGANConfig( name=cfg_name, - url="https://github.com/EutropicAI/cccv/releases/download/model_zoo/RealESRGAN_RealESRGAN_x4plus_anime_6B_4x.pth", - arch="RRDB", + url="https://github.com/EutropicAI/cccv/releases/download/model_zoo/RealESRGAN_AnimeJaNai_HD_V3_Compact_2x.pth", + arch=ArchType.SRVGG, model=model_name, - scale=4, - num_block=6, + scale=2, ) AutoConfig.register(cfg) diff --git a/tests/test_tile.py b/tests/test_tile.py index 04f4037..e0e4b6f 100644 --- a/tests/test_tile.py +++ b/tests/test_tile.py @@ -39,7 +39,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_auto_model() -> None: k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x - model: SRBaseModel = AutoModel.from_pretrained(pretrained_model_name=k, fp16=False, device=CCCV_DEVICE) + model: SRBaseModel = AutoModel.from_pretrained(k, fp16=False, device=CCCV_DEVICE) assert model.tile == (128, 128) assert model.tile_pad == 8 assert model.pad_img is None @@ -47,5 +47,5 @@ def test_auto_model() -> None: def test_auto_model_no_tile() -> None: k = ConfigType.RealESRGAN_AnimeJaNai_HD_V3_Compact_2x - model: SRBaseModel = AutoModel.from_pretrained(pretrained_model_name=k, fp16=False, device=CCCV_DEVICE, tile=None) + model: SRBaseModel = AutoModel.from_pretrained(k, fp16=False, device=CCCV_DEVICE, tile=None) assert model.tile is None diff --git a/tests/test_util.py b/tests/test_util.py index d7a553a..a767cef 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -5,6 +5,16 @@ from cccv.util.color import rgb_to_yuv, yuv_to_rgb from cccv.util.device import DEFAULT_DEVICE +from cccv.util.misc import ( + TMapper, + check_scene, + create_window_3d, + de_resize, + distance_calculator, + gaussian, + resize, + ssim_matlab, +) from .util import calculate_image_similarity, load_image @@ -38,3 +48,84 @@ def test_color() -> None: img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) assert calculate_image_similarity(img, load_image()) + + +def test_resize() -> None: + img = torch.randn(1, 3, 64, 64) # 创建一个随机的 4D 张量 + scale = 0.5 + resized_img = resize(img, scale) + assert resized_img.shape[2] % 64 == 0 # 检查高度是否能被 64 整除 + assert resized_img.shape[3] % 64 == 0 # 检查宽度是否能被 64 整除 + + +def test_de_resize() -> None: + img = torch.randn(1, 3, 128, 128) + ori_h, ori_w = 64, 64 + de_resized_img = de_resize(img, ori_h, ori_w) + assert de_resized_img.shape[2] == ori_h + assert de_resized_img.shape[3] == ori_w + + +def test_distance_calculator() -> None: + x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]]) # 创建一个 4D 张量 + distance = distance_calculator(x) + expected_distance = torch.sqrt(torch.tensor([1.0**2 + 2.0**2, 3.0**2 + 4.0**2])) + assert torch.allclose(distance, expected_distance) + + +def test_TMapper() -> None: + mapper = TMapper(src=1.0, dst=2.0) + timestamps = mapper.get_range_timestamps(0.0, 1.0, normalize=True) + assert len(timestamps) > 0 + assert all(0.0 <= t <= 1.0 for t in timestamps) + + +def test_gaussian() -> None: + window_size = 5 + sigma = 1.5 + gauss = gaussian(window_size, sigma) + assert gauss.shape == (window_size,) + assert torch.allclose(gauss.sum(), torch.tensor(1.0)) + + +def test_create_window_3d() -> None: + window_size = 5 + channel = 1 + window = create_window_3d(window_size, channel) + assert window.shape == (1, channel, window_size, window_size, window_size) + + +def test_ssim_matlab() -> None: + img1 = torch.randn(1, 3, 64, 64) + img2 = torch.randn(1, 3, 64, 64) + ssim_value = ssim_matlab(img1, img2) + assert isinstance(ssim_value, torch.Tensor) + assert 0.0 <= ssim_value.item() <= 1.0 + + +class Test_Check_Scene: + def test_5d(self) -> None: + x1 = torch.randn(1, 1, 3, 64, 64) + x2 = torch.randn(1, 1, 3, 64, 64) + + # 测试 enable_scdet 为 False 的情况 + result = check_scene(x1, x2, enable_scdet=False, scdet_threshold=0.5) + assert result is False # 当 enable_scdet 为 False 时,应返回 False + + # 测试 enable_scdet 为 True 的情况 + result = check_scene(x1, x2, enable_scdet=True, scdet_threshold=0.5) + assert isinstance(result, bool) + + def test_4d(self) -> None: + x1 = torch.randn(1, 3, 64, 64) + x2 = torch.randn(1, 3, 64, 64) + + result = check_scene(x1, x2, enable_scdet=True, scdet_threshold=0.5) + assert isinstance(result, bool) + + def test_3d(self) -> None: + x1 = torch.randn(3, 64, 64) + x2 = torch.randn(3, 64, 64) + + result = check_scene(x1, x2, enable_scdet=True, scdet_threshold=0.5) + assert isinstance(result, bool) diff --git a/tests/util.py b/tests/util.py index cda6619..fd8a4b1 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,6 +1,7 @@ import math import os from pathlib import Path +from typing import List import cv2 import numpy as np @@ -13,19 +14,57 @@ torch_2_4: bool = torch.__version__.startswith("2.4") ASSETS_PATH = Path(__file__).resolve().parent.parent.absolute() / "assets" + +# normal test image TEST_IMG_PATH = ASSETS_PATH / "test.jpg" +# vfi test image +TEST_IMG_PATH_0 = ASSETS_PATH / "vfi" / "test_i0.jpg" +TEST_IMG_PATH_1 = ASSETS_PATH / "vfi" / "test_i1.jpg" +TEST_IMG_PATH_2 = ASSETS_PATH / "vfi" / "test_i2.jpg" + +EVAL_IMG_PATH_RIFE = ASSETS_PATH / "vfi" / "test_out_rife.jpg" + +EVAL_IMG_PATH_DRBA_0 = ASSETS_PATH / "vfi" / "test_out_drba_0.jpg" +EVAL_IMG_PATH_DRBA_1 = ASSETS_PATH / "vfi" / "test_out_drba_1.jpg" +EVAL_IMG_PATH_DRBA_2 = ASSETS_PATH / "vfi" / "test_out_drba_2.jpg" +EVAL_IMG_PATH_DRBA_3 = ASSETS_PATH / "vfi" / "test_out_drba_3.jpg" +EVAL_IMG_PATH_DRBA_4 = ASSETS_PATH / "vfi" / "test_out_drba_4.jpg" + CI_ENV = os.environ.get("GITHUB_ACTIONS") == "true" CCCV_FP16 = True if not CI_ENV else False CCCV_TILE = None if not CI_ENV else (64, 64) CCCV_DEVICE = DEFAULT_DEVICE if not CI_ENV else torch.device("cpu") -def load_image() -> np.ndarray: - img = cv2.imdecode(np.fromfile(str(TEST_IMG_PATH), dtype=np.uint8), cv2.IMREAD_COLOR) +# load normal test image +def load_image(img_path: Path = TEST_IMG_PATH) -> np.ndarray: + img = cv2.imdecode(np.fromfile(str(img_path), dtype=np.uint8), cv2.IMREAD_COLOR) return img +# load vfi test images +def load_images() -> List[np.ndarray]: + return [load_image(k) for k in [TEST_IMG_PATH_0, TEST_IMG_PATH_1, TEST_IMG_PATH_2]] + + +def load_eval_images() -> List[np.ndarray]: + return [ + load_image(k) + for k in [ + EVAL_IMG_PATH_DRBA_0, + EVAL_IMG_PATH_DRBA_1, + EVAL_IMG_PATH_DRBA_2, + EVAL_IMG_PATH_DRBA_3, + EVAL_IMG_PATH_DRBA_4, + ] + ] + + +def load_eval_image() -> np.ndarray: + return load_image(EVAL_IMG_PATH_RIFE) + + def calculate_image_similarity(image1: np.ndarray, image2: np.ndarray, similarity: float = 0.85) -> bool: """ calculate image similarity, check SR is correct diff --git a/tests/vfi/__init__.py b/tests/vfi/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/vfi/test_drba.py b/tests/vfi/test_drba.py new file mode 100644 index 0000000..6be45a2 --- /dev/null +++ b/tests/vfi/test_drba.py @@ -0,0 +1,32 @@ +import cv2 + +from cccv import AutoConfig, AutoModel, BaseConfig, ConfigType +from cccv.model import VFIBaseModel +from tests.util import ( + ASSETS_PATH, + CCCV_DEVICE, + CCCV_FP16, + CCCV_TILE, + calculate_image_similarity, + load_eval_images, + load_images, +) + + +class Test_DRBA: + def test_official(self) -> None: + img0, img1, img2 = load_images() + eval_imgs = load_eval_images() + + for k in [ConfigType.DRBA_IFNet]: + print(f"Testing {k}") + cfg: BaseConfig = AutoConfig.from_pretrained(k) + model: VFIBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=CCCV_FP16, tile=CCCV_TILE) + print(model.device) + + out = model.inference_image_list(img_list=[img0, img1, img2]) + + assert len(out) == 5 + for i in range(len(out)): + cv2.imwrite(str(ASSETS_PATH / f"test_{k}_{i}_out.jpg"), out[i]) + assert calculate_image_similarity(eval_imgs[i], out[i]) diff --git a/tests/vfi/test_rife.py b/tests/vfi/test_rife.py new file mode 100644 index 0000000..1942876 --- /dev/null +++ b/tests/vfi/test_rife.py @@ -0,0 +1,32 @@ +import cv2 + +from cccv import AutoConfig, AutoModel, BaseConfig, ConfigType +from cccv.model import VFIBaseModel +from tests.util import ( + ASSETS_PATH, + CCCV_DEVICE, + CCCV_FP16, + CCCV_TILE, + calculate_image_similarity, + load_eval_image, + load_images, +) + + +class Test_RIFE: + def test_official(self) -> None: + img0, img1, _ = load_images() + eval_img = load_eval_image() + + for k in [ConfigType.RIFE_IFNet_v426_heavy]: + print(f"Testing {k}") + cfg: BaseConfig = AutoConfig.from_pretrained(k) + model: VFIBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=CCCV_FP16, tile=CCCV_TILE) + print(model.device) + + out = model.inference_image_list(img_list=[img0, img1]) + + assert len(out) == 1 + for i in range(len(out)): + cv2.imwrite(str(ASSETS_PATH / f"test_{k}_{i}_out.jpg"), out[i]) + assert calculate_image_similarity(eval_img, out[i]) diff --git a/tests/vsr/test_animesr.py b/tests/vsr/test_animesr.py index d3f5793..8e0c722 100644 --- a/tests/vsr/test_animesr.py +++ b/tests/vsr/test_animesr.py @@ -5,7 +5,7 @@ from tests.util import ( ASSETS_PATH, CCCV_DEVICE, - CCCV_TILE, + CCCV_FP16, calculate_image_similarity, compare_image_size, load_image, @@ -20,7 +20,7 @@ def test_official(self) -> None: for k in [ConfigType.AnimeSR_v1_PaperModel_4x, ConfigType.AnimeSR_v2_4x]: print(f"Testing {k}") cfg: BaseConfig = AutoConfig.from_pretrained(k) - model: VSRBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=False, tile=CCCV_TILE) + model: VSRBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=CCCV_FP16) print(model.device) imgOutList = model.inference_image_list(imgList) diff --git a/tests/vsr/test_edvr.py b/tests/vsr/test_edvr.py index 000f339..20ff2cc 100644 --- a/tests/vsr/test_edvr.py +++ b/tests/vsr/test_edvr.py @@ -2,13 +2,7 @@ from cccv import AutoConfig, AutoModel, BaseConfig, ConfigType from cccv.model import VSRBaseModel -from tests.util import ( - ASSETS_PATH, - CCCV_DEVICE, - calculate_image_similarity, - compare_image_size, - load_image, -) +from tests.util import ASSETS_PATH, CCCV_DEVICE, CCCV_FP16, calculate_image_similarity, compare_image_size, load_image class Test_EDVR: @@ -19,7 +13,7 @@ def test_official_M(self) -> None: for k in [ConfigType.EDVR_M_SR_REDS_official_4x, ConfigType.EDVR_M_woTSA_SR_REDS_official_4x]: print(f"Testing {k}") cfg: BaseConfig = AutoConfig.from_pretrained(k) - model: VSRBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=False) + model: VSRBaseModel = AutoModel.from_config(config=cfg, device=CCCV_DEVICE, fp16=CCCV_FP16) print(model.device) img2 = model.inference_image_list(imgList)