@@ -121,14 +121,15 @@ def extra_repr(self):
121121 return f"in_shape={ (self .nlat_in , self .nlon_in )} , out_shape={ (self .nlat_out , self .nlon_out )} "
122122
123123 def _upscale_longitudes (self , x : torch .Tensor ):
124- # do the interpolation
124+ # do the interpolation in precision of x
125+ lwgt = self .lon_weights .to (x .dtype )
125126 if self .mode == "bilinear" :
126- x = torch .lerp (x [..., self .lon_idx_left ], x [..., self .lon_idx_right ], self . lon_weights )
127+ x = torch .lerp (x [..., self .lon_idx_left ], x [..., self .lon_idx_right ], lwgt )
127128 else :
128129 omega = x [..., self .lon_idx_right ] - x [..., self .lon_idx_left ]
129130 somega = torch .sin (omega )
130- start_prefac = torch .where (somega > 1e-4 , torch .sin ((1.0 - self . lon_weights ) * omega ) / somega , (1.0 - self . lon_weights ))
131- end_prefac = torch .where (somega > 1e-4 , torch .sin (self . lon_weights * omega ) / somega , self . lon_weights )
131+ start_prefac = torch .where (somega > 1e-4 , torch .sin ((1.0 - lwgt ) * omega ) / somega , (1.0 - lwgt ))
132+ end_prefac = torch .where (somega > 1e-4 , torch .sin (lwgt * omega ) / somega , lwgt )
132133 x = start_prefac * x [..., self .lon_idx_left ] + end_prefac * x [..., self .lon_idx_right ]
133134
134135 return x
@@ -142,14 +143,15 @@ def _expand_poles(self, x: torch.Tensor):
142143 return x
143144
144145 def _upscale_latitudes (self , x : torch .Tensor ):
145- # do the interpolation
146+ # do the interpolation in precision of x
147+ lwgt = self .lat_weights .to (x .dtype )
146148 if self .mode == "bilinear" :
147- x = torch .lerp (x [..., self .lat_idx , :], x [..., self .lat_idx + 1 , :], self . lat_weights )
149+ x = torch .lerp (x [..., self .lat_idx , :], x [..., self .lat_idx + 1 , :], lwgt )
148150 else :
149151 omega = x [..., self .lat_idx + 1 , :] - x [..., self .lat_idx , :]
150152 somega = torch .sin (omega )
151- start_prefac = torch .where (somega > 1e-4 , torch .sin ((1.0 - self . lat_weights ) * omega ) / somega , (1.0 - self . lat_weights ))
152- end_prefac = torch .where (somega > 1e-4 , torch .sin (self . lat_weights * omega ) / somega , self . lat_weights )
153+ start_prefac = torch .where (somega > 1e-4 , torch .sin ((1.0 - lwgt ) * omega ) / somega , (1.0 - lwgt ))
154+ end_prefac = torch .where (somega > 1e-4 , torch .sin (lwgt * omega ) / somega , lwgt )
153155 x = start_prefac * x [..., self .lat_idx , :] + end_prefac * x [..., self .lat_idx + 1 , :]
154156
155157 return x
0 commit comments