-
Notifications
You must be signed in to change notification settings - Fork 3
feat: support most of features in ccvfi #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
3cc9b66
9e6bd8a
af16e04
61682b6
c60d0ae
0a84324
cebc239
ed11116
df63d06
ac3ddb8
acbfdff
8bf7626
5869332
9a36828
04d0107
3487d13
4340d1f
4f218ba
5a60e22
7722d28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,312 @@ | ||||||||
| # type: ignore | ||||||||
| import numpy as np | ||||||||
| import torch | ||||||||
| import torch.nn as nn | ||||||||
| import torch.nn.functional as F | ||||||||
|
|
||||||||
| from cccv.arch import ARCH_REGISTRY | ||||||||
| from cccv.arch.vfi.vfi_utils.warplayer import warp | ||||||||
| from cccv.type import ArchType | ||||||||
| from cccv.util.misc import distance_calculator | ||||||||
|
|
||||||||
|
|
||||||||
| @ARCH_REGISTRY.register(name=ArchType.DRBA) | ||||||||
| class DRBA(nn.Module): | ||||||||
| def __init__(self): | ||||||||
| super(DRBA, self).__init__() | ||||||||
| self.block0 = IFBlock(7 + 32, c=192) | ||||||||
| self.block1 = IFBlock(8 + 4 + 8 + 32, c=128) | ||||||||
| self.block2 = IFBlock(8 + 4 + 8 + 32, c=96) | ||||||||
| self.block3 = IFBlock(8 + 4 + 8 + 32, c=64) | ||||||||
| self.block4 = IFBlock(8 + 4 + 8 + 32, c=32) | ||||||||
| self.encode = Head() | ||||||||
|
|
||||||||
| support_cupy = True | ||||||||
| try: | ||||||||
| import cupy | ||||||||
|
|
||||||||
| if cupy.cuda.get_cuda_path() is None: | ||||||||
| support_cupy = False | ||||||||
| except Exception: | ||||||||
| support_cupy = False | ||||||||
|
|
||||||||
| if support_cupy: | ||||||||
| from cccv.arch.vfi.vfi_utils.softsplat import softsplat as fwarp | ||||||||
| else: | ||||||||
| from cccv.arch.vfi.vfi_utils.softsplat_torch import softsplat as fwarp | ||||||||
|
|
||||||||
| self.fwarp = fwarp | ||||||||
|
|
||||||||
| def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=False, f0=None, f1=None): | ||||||||
| if scale_list is None: | ||||||||
| scale_list = [16, 8, 4, 2, 1] | ||||||||
| channel = x.shape[1] // 2 | ||||||||
| img0 = x[:, :channel] | ||||||||
| img1 = x[:, channel:] | ||||||||
| if not torch.is_tensor(timestep): | ||||||||
| timestep = (x[:, :1].clone() * 0 + 1) * timestep | ||||||||
| f0 = self.encode(img0[:, :3]) if f0 is None else f0 | ||||||||
| f1 = self.encode(img1[:, :3]) if f1 is None else f1 | ||||||||
| flow_list = [] | ||||||||
| merged = [] | ||||||||
| mask_list = [] | ||||||||
| warped_img0 = img0 | ||||||||
| warped_img1 = img1 | ||||||||
| flow = None | ||||||||
| mask = None | ||||||||
| block = [self.block0, self.block1, self.block2, self.block3, self.block4] | ||||||||
| for i in range(5): | ||||||||
| if flow is None: | ||||||||
| flow, mask, feat = block[i]( | ||||||||
| torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i] | ||||||||
| ) | ||||||||
| if ensemble: | ||||||||
| print("warning: ensemble is not supported since RIFEv4.21") | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||
| else: | ||||||||
| wf0 = warp(f0, flow[:, :2]) | ||||||||
| wf1 = warp(f1, flow[:, 2:4]) | ||||||||
| fd, m0, feat = block[i]( | ||||||||
| torch.cat((warped_img0[:, :3], warped_img1[:, :3], wf0, wf1, timestep, mask, feat), 1), | ||||||||
| flow, | ||||||||
| scale=scale_list[i], | ||||||||
| ) | ||||||||
| if ensemble: | ||||||||
| print("warning: ensemble is not supported since RIFEv4.21") | ||||||||
| else: | ||||||||
| mask = m0 | ||||||||
| flow = flow + fd | ||||||||
| mask_list.append(mask) | ||||||||
| flow_list.append(flow) | ||||||||
| warped_img0 = warp(img0, flow[:, :2]) | ||||||||
| warped_img1 = warp(img1, flow[:, 2:4]) | ||||||||
| merged.append((warped_img0, warped_img1)) | ||||||||
| mask = torch.sigmoid(mask) | ||||||||
| merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) | ||||||||
| if not fastmode: | ||||||||
| print("contextnet is removed") | ||||||||
| """ | ||||||||
| c0 = self.contextnet(img0, flow[:, :2]) | ||||||||
| c1 = self.contextnet(img1, flow[:, 2:4]) | ||||||||
| tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) | ||||||||
| res = tmp[:, :3] * 2 - 1 | ||||||||
| merged[4] = torch.clamp(merged[4] + res, 0, 1) | ||||||||
| """ | ||||||||
| return merged[4], flow_list | ||||||||
|
|
||||||||
| def calc_flow(self, a, b, scale, f0=None, f1=None): | ||||||||
| scale_list = [16 / scale, 8 / scale, 4 / scale, 2 / scale, 1 / scale] | ||||||||
| # calc flow at the lowest resolution (significantly faster with almost no quality loss). | ||||||||
| timestep = (a[:, :1].clone() * 0 + 1) * 0.5 | ||||||||
| f0 = self.encode(a[:, :3]) if f0 is None else f0 | ||||||||
| f1 = self.encode(b[:, :3]) if f1 is None else f1 | ||||||||
| flow, _, _ = self.block0(torch.cat((a[:, :3], b[:, :3], f0, f1, timestep), 1), None, scale=scale_list[0]) | ||||||||
|
|
||||||||
| # get flow flow0.5 -> 0/1 | ||||||||
| flow50, flow51 = flow[:, :2], flow[:, 2:] | ||||||||
|
|
||||||||
| warp_method = "avg" | ||||||||
|
|
||||||||
| # qvi | ||||||||
| # flow05, norm2 = fwarp(flow50, flow50) | ||||||||
| # flow05[norm2]... | ||||||||
| # flow05 = -flow05 | ||||||||
|
|
||||||||
| flow05 = -1 * self.fwarp(flow50, flow50, None, warp_method) | ||||||||
| flow15 = -1 * self.fwarp(flow51, flow51, None, warp_method) | ||||||||
|
|
||||||||
| ones_mask = flow05.clone() * 0 + 1 | ||||||||
| mask05 = self.fwarp(ones_mask, flow50, None, warp_method) | ||||||||
| mask15 = self.fwarp(ones_mask, flow51, None, warp_method) | ||||||||
|
|
||||||||
| gap05 = mask05 < 0.999 | ||||||||
| gap15 = mask15 < 0.999 | ||||||||
|
|
||||||||
| flow05[gap05] = (ones_mask * max(flow05.shape[2], flow05.shape[3]))[gap05] | ||||||||
| flow15[gap15] = (ones_mask * max(flow15.shape[2], flow15.shape[3]))[gap15] | ||||||||
|
|
||||||||
| flow01 = flow05 * 2 | ||||||||
| flow10 = flow15 * 2 | ||||||||
|
|
||||||||
| return flow01, flow10, f0, f1 | ||||||||
|
|
||||||||
| def forward(self, x, minus_t, zero_t, plus_t, _left_scene, _right_scene, _scale, _reuse=None): | ||||||||
| _I0, _I1, _I2 = x[:, 0], x[:, 1], x[:, 2] | ||||||||
| flow10, flow01, f1, f0 = self.calc_flow(_I1, _I0, _scale) if not _reuse else _reuse | ||||||||
| if _reuse is None: | ||||||||
| flow12, flow21, f1, f2 = self.calc_flow(_I1, _I2, _scale) | ||||||||
| else: | ||||||||
| flow12, flow21, f1, f2 = self.calc_flow(_I1, _I2, _scale, f0=_reuse[2]) | ||||||||
|
|
||||||||
| # Compute the distance using the optical flow and distance calculator | ||||||||
| d10 = distance_calculator(flow10) + 1e-4 | ||||||||
| d12 = distance_calculator(flow12) + 1e-4 | ||||||||
|
|
||||||||
| # Calculate the distance ratio map | ||||||||
| drm10 = d10 / (d10 + d12) | ||||||||
| drm12 = d12 / (d10 + d12) | ||||||||
|
|
||||||||
| ones_mask = torch.ones_like(drm10, device=drm10.device) | ||||||||
|
|
||||||||
| def calc_drm_rife(_t): | ||||||||
| # The distance ratio map (drm) is initially aligned with I1. | ||||||||
| # To align it with I0 and I2, we need to warp the drm maps. | ||||||||
| # Note: 1. To reverse the direction of the drm map, use 1 - drm and then warp it. | ||||||||
| # 2. For RIFE, drm should be aligned with the time corresponding to the intermediate frame. | ||||||||
| _drm01r = self.fwarp(1 - drm10, flow10 * ((1 - drm10) * 2) * _t, None, strMode="avg") | ||||||||
| _drm21r = self.fwarp(1 - drm12, flow12 * ((1 - drm12) * 2) * _t, None, strMode="avg") | ||||||||
|
|
||||||||
| self.warped_ones_mask01r = self.fwarp(ones_mask, flow10 * ((1 - drm10) * 2) * _t, None, strMode="avg") | ||||||||
| self.warped_ones_mask21r = self.fwarp(ones_mask, flow12 * ((1 - drm12) * 2) * _t, None, strMode="avg") | ||||||||
|
|
||||||||
| holes01r = self.warped_ones_mask01r < 0.999 | ||||||||
| holes21r = self.warped_ones_mask21r < 0.999 | ||||||||
|
|
||||||||
| _drm01r[holes01r] = _drm01r[holes01r] | ||||||||
| _drm21r[holes21r] = _drm21r[holes21r] | ||||||||
|
Comment on lines
+164
to
+165
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||
|
|
||||||||
| return _drm01r, _drm21r | ||||||||
|
|
||||||||
| output1, output2 = [], [] | ||||||||
|
|
||||||||
| if _left_scene: | ||||||||
| for i in range(len(minus_t)): | ||||||||
| minus_t[i] = -1 | ||||||||
|
|
||||||||
| if _right_scene: | ||||||||
| for _ in plus_t: | ||||||||
| zero_t = np.append(zero_t, 0) | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Comment on lines
+176
to
+177
|
||||||||
| for _ in plus_t: | |
| zero_t = np.append(zero_t, 0) | |
| zero_t = np.concatenate([zero_t, np.zeros(len(plus_t), dtype=zero_t.dtype)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line assigns a generator object to drm01r and drm21r instead of the tensor values. This will cause an error in subsequent operations. You should use a tuple or list to unpack the values correctly.
| drm01r, drm21r = (ones_mask.clone() * 0.5 for _ in range(2)) | |
| drm01r, drm21r = ones_mask.clone() * 0.5, ones_mask.clone() * 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
# type: ignoreat the file level disables all type checking for this file, which is generally discouraged. It would be better to add type hints and fix any type errors, or use more specific# type: ignore[error-code]comments on the lines that have issues. This comment also applies to other new files in this PR.