Skip to content

Commit 86596a7

Browse files
azrael417bonevbs
authored andcommitted
Tkurth/resample fix (#56)
* fixing resample precision issues * fixing resample part 2 * adding comment
1 parent 8bc53a6 commit 86596a7

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

torch_harmonics/resample.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,15 @@ def extra_repr(self):
121121
return f"in_shape={(self.nlat_in, self.nlon_in)}, out_shape={(self.nlat_out, self.nlon_out)}"
122122

123123
def _upscale_longitudes(self, x: torch.Tensor):
124-
# do the interpolation
124+
# do the interpolation in precision of x
125+
lwgt = self.lon_weights.to(x.dtype)
125126
if self.mode == "bilinear":
126-
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], self.lon_weights)
127+
x = torch.lerp(x[..., self.lon_idx_left], x[..., self.lon_idx_right], lwgt)
127128
else:
128129
omega = x[..., self.lon_idx_right] - x[..., self.lon_idx_left]
129130
somega = torch.sin(omega)
130-
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lon_weights) * omega) / somega, (1.0 - self.lon_weights))
131-
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lon_weights * omega) / somega, self.lon_weights)
131+
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
132+
end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
132133
x = start_prefac * x[..., self.lon_idx_left] + end_prefac * x[..., self.lon_idx_right]
133134

134135
return x
@@ -142,14 +143,15 @@ def _expand_poles(self, x: torch.Tensor):
142143
return x
143144

144145
def _upscale_latitudes(self, x: torch.Tensor):
145-
# do the interpolation
146+
# do the interpolation in precision of x
147+
lwgt = self.lat_weights.to(x.dtype)
146148
if self.mode == "bilinear":
147-
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], self.lat_weights)
149+
x = torch.lerp(x[..., self.lat_idx, :], x[..., self.lat_idx + 1, :], lwgt)
148150
else:
149151
omega = x[..., self.lat_idx + 1, :] - x[..., self.lat_idx, :]
150152
somega = torch.sin(omega)
151-
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - self.lat_weights) * omega) / somega, (1.0 - self.lat_weights))
152-
end_prefac = torch.where(somega > 1e-4, torch.sin(self.lat_weights * omega) / somega, self.lat_weights)
153+
start_prefac = torch.where(somega > 1e-4, torch.sin((1.0 - lwgt) * omega) / somega, (1.0 - lwgt))
154+
end_prefac = torch.where(somega > 1e-4, torch.sin(lwgt * omega) / somega, lwgt)
153155
x = start_prefac * x[..., self.lat_idx, :] + end_prefac * x[..., self.lat_idx + 1, :]
154156

155157
return x

0 commit comments

Comments
 (0)