diff --git a/fme/downscaling/aggregators/main.py b/fme/downscaling/aggregators/main.py index 1d1b4b138..65479867f 100644 --- a/fme/downscaling/aggregators/main.py +++ b/fme/downscaling/aggregators/main.py @@ -534,6 +534,8 @@ def __init__( self._name = ensure_trailing_slash(name) self._mean_target = Mean(batch_mean) self._mean_prediction = Mean(batch_mean) + # corresponds to 40 x 80 deg at 3km resolution, in comparison CONUS is 28x65 deg + self.max_img_upload_size = 3276800 @torch.no_grad() def record_batch(self, target: TensorMapping, prediction: TensorMapping) -> None: @@ -583,19 +585,20 @@ def get_relative_mean(target, prediction): metrics = {} spectra = {} for var_name in target.keys(): - gap = torch.full( - (target[var_name].shape[-2], self.gap_width), - float(target[var_name].min()), - device=target[var_name].device, - ) - maps[f"maps/{self._name}full-field/{var_name}"] = torch.cat( - (prediction[var_name], gap, target[var_name]), dim=1 - ) - maps[f"maps/{self._name}log10_relative_mean/{var_name}"] = relative[ - var_name - ] error = prediction[var_name] - target[var_name] - maps[f"maps/{self._name}error/{var_name}"] = error + if target[var_name].nelement() < self.max_img_upload_size: + gap = torch.full( + (target[var_name].shape[-2], self.gap_width), + float(target[var_name].min()), + device=target[var_name].device, + ) + maps[f"maps/{self._name}full-field/{var_name}"] = torch.cat( + (prediction[var_name], gap, target[var_name]), dim=1 + ) + maps[f"maps/{self._name}log10_relative_mean/{var_name}"] = relative[ + var_name + ] + maps[f"maps/{self._name}error/{var_name}"] = error metrics[f"metrics/{self._name}bias/{var_name}"] = error.mean() spectra_prefix = ensure_trailing_slash(f"power_spectrum_of_{self._name}")