diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fbdcae..6b2a84c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ Keep it human-readable, your future self will thank you! ## [Unreleased] ### Added +- Add support for models with unconnected nodes dropped from input [#95](https://github.com/ecmwf/anemoi-inference/pull/95). +- Change trigger for boundary forcings [#95](https://github.com/ecmwf/anemoi-inference/pull/95). +- Add support for automatic loading of anemoi-datasets of more general type [#95](https://github.com/ecmwf/anemoi-inference/pull/95). - Add initial state output in netcdf format - Fix: Enable inference when no constant forcings are used - Add anemoi-transform link to documentation diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 2d62f8d..de09388 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -312,7 +312,7 @@ def mars_requests(self, *, variables, dates, use_grib_paramid=False, **kwargs): @cached_property def _supporting_arrays(self): - return self._metadata.supporting_arrays + return self._metadata._supporting_arrays @property def name(self): diff --git a/src/anemoi/inference/forcings.py b/src/anemoi/inference/forcings.py index 984d90a..6d3f1c8 100644 --- a/src/anemoi/inference/forcings.py +++ b/src/anemoi/inference/forcings.py @@ -136,8 +136,10 @@ def __init__(self, context, input, variables, variables_mask): self.variables_mask = variables_mask assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported." 
self.input = input - num_lam, num_other = input.ds.grids - self.spatial_mask = np.array([False] * num_lam + [True] * num_other, dtype=bool) + if "output_mask" in context.checkpoint._supporting_arrays: + self.spatial_mask = ~context.checkpoint.load_supporting_array("output_mask") + else: + self.spatial_mask = np.array([False] * len(input["latitudes"]), dtype=bool) self.kinds = dict(retrieved=True) # Used for debugging def __repr__(self): diff --git a/src/anemoi/inference/inputs/dataset.py b/src/anemoi/inference/inputs/dataset.py index c7122e3..68bdea3 100644 --- a/src/anemoi/inference/inputs/dataset.py +++ b/src/anemoi/inference/inputs/dataset.py @@ -28,12 +28,25 @@ class DatasetInput(Input): def __init__(self, context, args, kwargs): super().__init__(context) + + grid_indices = kwargs.pop("grid_indices", None) + + self.args, self.kwargs = args, kwargs if context.verbosity > 0: LOG.info( "Opening dataset with\nargs=%s\nkwargs=%s", json.dumps(args, indent=4), json.dumps(kwargs, indent=4) ) + if grid_indices is None and "grid_indices" in context.checkpoint._supporting_arrays: + grid_indices = context.checkpoint.load_supporting_array("grid_indices") + if context.verbosity > 0: + LOG.info( + "Loading supporting array `grid_indices` from checkpoint, " + "the input grid will be reduced accordingly."
+ ) + + self.grid_indices = slice(None) if grid_indices is None else grid_indices + @cached_property def ds(self): from anemoi.datasets import open_dataset @@ -48,11 +61,13 @@ def create_input_state(self, *, date=None): raise ValueError("`date` must be provided") date = to_datetime(date) + latitudes = self.ds.latitudes + longitudes = self.ds.longitudes input_state = dict( date=date, - latitudes=self.ds.latitudes, - longitudes=self.ds.longitudes, + latitudes=latitudes[self.grid_indices], + longitudes=longitudes[self.grid_indices], fields=dict(), ) @@ -69,7 +84,8 @@ def create_input_state(self, *, date=None): if variable not in requested_variables: continue # Squeeze the data to remove the ensemble dimension - fields[variable] = np.squeeze(data[:, i], axis=1) + values = np.squeeze(data[:, i], axis=1) + fields[variable] = values[:, self.grid_indices] return input_state @@ -82,6 +98,8 @@ def load_forcings(self, *, variables, dates): data = np.squeeze(data, axis=2) # Reorder the dimensions to (variable, date, values) data = np.swapaxes(data, 0, 1) + # apply reduction to `grid_indices` + data = data[..., self.grid_indices] return data def _load_dates(self, dates): diff --git a/src/anemoi/inference/metadata.py b/src/anemoi/inference/metadata.py index 29a5791..c12b833 100644 --- a/src/anemoi/inference/metadata.py +++ b/src/anemoi/inference/metadata.py @@ -139,6 +139,8 @@ def output_tensor_index_to_variable(self): @cached_property def number_of_grid_points(self): """Return the number of grid points per fields""" + if "grid_indices" in self._supporting_arrays: + return len(self.load_supporting_array("grid_indices")) try: return self._metadata.dataset.shape[-1] except AttributeError: @@ -510,14 +512,13 @@ def _find(x): _find(y) if isinstance(x, dict): - if "dataset" in x: + if "dataset" in x and isinstance(x["dataset"], str): result.append(x["dataset"]) for k, v in x.items(): _find(v) _find(self._config.dataloader.training.dataset) - return result def 
open_dataset_args_kwargs(self, *, use_original_paths, from_dataloader=None): @@ -717,9 +718,7 @@ def boundary_forcings_inputs(self, context, input_state): result = [] - output_mask = self._config_model.get("output_mask", None) - if output_mask is not None: - assert output_mask == "cutout", "Currently only cutout as output mask supported." + if "output_mask" in self._supporting_arrays: result.append( context.create_boundary_forcings( self.prognostic_variables,