From 6a4465720a119ef89a34a949de1ddb2a246b342e Mon Sep 17 00:00:00 2001 From: Ethan Blackwood Date: Mon, 27 Oct 2025 00:06:29 -0400 Subject: [PATCH] Take indices into account when applying piecewise shifts --- caiman/motion_correction.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/caiman/motion_correction.py b/caiman/motion_correction.py index 9036e2e4a..4f3977817 100644 --- a/caiman/motion_correction.py +++ b/caiman/motion_correction.py @@ -456,16 +456,28 @@ def apply_shifts_movie(self, fname, rigid_shifts: Optional[bool] = None, save_me sh[0], sh[1]), 0, is_freq=False, border_nan=self.border_nan) for img, sh in zip( Y, self.shifts_rig)] else: - # take potential upsampling into account when recreating patch grid - dims = Y.shape[1:] - patch_centers = get_patch_centers(dims, overlaps=self.overlaps, strides=self.strides, - shifts_opencv=self.shifts_opencv, upsample_factor_grid=self.upsample_factor_grid) + # take indices and potential upsampling into account when recreating patch grid + dims = Y[0][self.indices].shape + patch_centers_orig = get_patch_centers( + dims, overlaps=self.overlaps, strides=self.strides, + shifts_opencv=self.shifts_opencv, upsample_factor_grid=self.upsample_factor_grid) + + # if only a portion of the original image was used, offset/multiply patch centers to now apply to the whole movie + patch_centers = tuple([ + list(dim_inds.start + dim_inds.step * np.array(dim_centers_orig)) + for dim_inds, dim_centers_orig in zip(self.indices, patch_centers_orig) + ]) + + # force shifts_interpolate if there was any cropping - easier than making a special-case path that + # resizes the shifts and extrapolates to the border but doesn't fully take patch centers into account + shifts_interpolate = True if any(dim_inds != slice(None) for dim_inds in self.indices) else self.shifts_interpolate + if self.is3D: # x_shifts_els and y_shifts_els are switched intentionally m_reg = [ apply_pw_shifts_remap_3d(img, shifts_y=-x_shifts, shifts_x=-y_shifts, shifts_z=-z_shifts, patch_centers=patch_centers, border_nan=self.border_nan, - shifts_interpolate=self.shifts_interpolate) + shifts_interpolate=shifts_interpolate) for img, x_shifts, y_shifts, z_shifts in zip(Y, self.x_shifts_els, self.y_shifts_els, self.z_shifts_els) ] @@ -473,7 +485,7 @@ def apply_shifts_movie(self, fname, rigid_shifts: Optional[bool] = None, save_me # x_shifts_els and y_shifts_els are switched intentionally m_reg = [ apply_pw_shifts_remap_2d(img, shifts_y=-x_shifts, shifts_x=-y_shifts, patch_centers=patch_centers, - border_nan=self.border_nan, shifts_interpolate=self.shifts_interpolate) + border_nan=self.border_nan, shifts_interpolate=shifts_interpolate) for img, x_shifts, y_shifts in zip(Y, self.x_shifts_els, self.y_shifts_els) ]