diff --git a/viz/renderer.py b/viz/renderer.py index 26e4f11..cf854e0 100644 --- a/viz/renderer.py +++ b/viz/renderer.py @@ -353,11 +353,11 @@ def _render_drag_impl(self, res, distance = ((xx.to(self._device) - point[0])**2 + (yy.to(self._device) - point[1])**2)**0.5 relis, reljs = torch.where(distance < round(r1 / 512 * h)) direction = direction / (torch.linalg.norm(direction) + 1e-7) - gridh = (relis-direction[1]) / (h-1) * 2 - 1 - gridw = (reljs-direction[0]) / (w-1) * 2 - 1 + gridh = (relis+direction[1]) / (h-1) * 2 - 1 + gridw = (reljs+direction[0]) / (w-1) * 2 - 1 grid = torch.stack([gridw,gridh], dim=-1).unsqueeze(0).unsqueeze(0) target = F.grid_sample(feat_resize.float(), grid, align_corners=True).squeeze(2) - loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs], target.detach()) + loss_motion += F.l1_loss(feat_resize[:,:,relis,reljs].detach(), target) loss = loss_motion if mask is not None: