Support models with unconnected nodes removed from input (LAM) #95

Merged (8 commits, Jan 8, 2025)
Changes from all commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@ Keep it human-readable, your future self will thank you!
 ## [Unreleased]
 
 ### Added
+- Add support for models with unconnected nodes dropped from input [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
+- Change trigger for boundary forcings [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
+- Add support for automatic loading of anemoi-datasets of more general type [#95](https://github.com/ecmwf/anemoi-inference/pull/95).
 - Add initial state output in netcdf format
 - Fix: Enable inference when no constant forcings are used
 - Add anemoi-transform link to documentation
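For readers new to the feature: the idea behind the first entry is that a checkpoint can now ship a `grid_indices` supporting array listing the dataset grid points that are actually connected to the model graph, and inference reduces the input grid to those points. A minimal illustrative sketch with toy arrays (not code from this PR):

```python
import numpy as np

# Toy full dataset grid of 6 points; suppose only 4 of them are connected
# to the model graph, so the checkpoint stores their indices.
latitudes = np.array([50.0, 51.0, 52.0, 53.0, 54.0, 55.0])
grid_indices = np.array([0, 1, 2, 4])  # hypothetical supporting array

# The input grid is reduced before it is handed to the model.
reduced_latitudes = latitudes[grid_indices]
print(reduced_latitudes)  # [50. 51. 52. 54.]
```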
2 changes: 1 addition & 1 deletion src/anemoi/inference/checkpoint.py
@@ -312,7 +312,7 @@ def mars_requests(self, *, variables, dates, use_grib_paramid=False, **kwargs):

     @cached_property
     def _supporting_arrays(self):
-        return self._metadata.supporting_arrays
+        return self._metadata._supporting_arrays
 
     @property
     def name(self):
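The property now forwards to the metadata's private mapping, which is the shape the new code paths in `forcings.py` and `dataset.py` rely on: check for a named array, then load it. A hedged, standalone sketch of that pattern (simplified stand-in classes, not the real ones):

```python
import numpy as np

class _Metadata:
    # Hypothetical stand-in: name -> array mapping carried by the checkpoint metadata.
    _supporting_arrays = {"output_mask": np.array([True, True, False])}

class _Checkpoint:
    def __init__(self, metadata):
        self._metadata = metadata

    @property
    def _supporting_arrays(self):
        # Forward to the metadata, mirroring the one-line change above.
        return self._metadata._supporting_arrays

    def load_supporting_array(self, name):
        return self._supporting_arrays[name]

ckpt = _Checkpoint(_Metadata())
assert "output_mask" in ckpt._supporting_arrays      # membership check by name
print(ckpt.load_supporting_array("output_mask"))     # [ True  True False]
```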
6 changes: 4 additions & 2 deletions src/anemoi/inference/forcings.py
@@ -136,8 +136,10 @@ def __init__(self, context, input, variables, variables_mask):
         self.variables_mask = variables_mask
         assert isinstance(input, DatasetInput), "Currently only boundary forcings from dataset supported."
         self.input = input
-        num_lam, num_other = input.ds.grids
-        self.spatial_mask = np.array([False] * num_lam + [True] * num_other, dtype=bool)
+        if "output_mask" in context.checkpoint._supporting_arrays:
+            self.spatial_mask = ~context.checkpoint.load_supporting_array("output_mask")
+        else:
+            self.spatial_mask = np.array([False] * len(input["latitudes"]), dtype=bool)
         self.kinds = dict(retrieved=True)  # Used for debugging
 
     def __repr__(self):
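In other words, when the checkpoint provides an `output_mask` supporting array, the boundary-forcing points are simply its complement; without one, the mask is all False and no points are forced. A small sketch with toy data (not the real class):

```python
import numpy as np

# Toy output mask: True where the model writes its own output (e.g. the LAM
# part of a cutout dataset), False on the boundary/global part.
output_mask = np.array([True, True, True, False, False])

spatial_mask = ~output_mask   # points that receive boundary forcings
print(spatial_mask)           # [False False False  True  True]

# Fallback when no output_mask is present: an all-False mask over the input
# grid, i.e. no boundary points are forced.
latitudes = np.zeros(5)
fallback = np.array([False] * len(latitudes), dtype=bool)
```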
24 changes: 21 additions & 3 deletions src/anemoi/inference/inputs/dataset.py
@@ -28,12 +28,25 @@ class DatasetInput(Input):

     def __init__(self, context, args, kwargs):
         super().__init__(context)
+
+        grid_indices = kwargs.pop("grid_indices", None)
+
         self.args, self.kwargs = args, kwargs
         if context.verbosity > 0:
             LOG.info(
                 "Opening dataset with\nargs=%s\nkwargs=%s", json.dumps(args, indent=4), json.dumps(kwargs, indent=4)
             )
+
+        if grid_indices is None and "grid_indices" in context.checkpoint._supporting_arrays:
+            grid_indices = context.checkpoint.load_supporting_array("grid_indices")
+            if context.verbosity > 0:
+                LOG.info(
+                    "Loading supporting array `grid_indices` from checkpoint, \
+                    the input grid will be reduced accordingly."
+                )
+
+        self.grid_indices = slice(None) if grid_indices is None else grid_indices
+
     @cached_property
     def ds(self):
         from anemoi.datasets import open_dataset
@@ -48,11 +61,13 @@ def create_input_state(self, *, date=None):
             raise ValueError("`date` must be provided")
 
         date = to_datetime(date)
+        latitudes = self.ds.latitudes
+        longitudes = self.ds.longitudes
 
         input_state = dict(
             date=date,
-            latitudes=self.ds.latitudes,
-            longitudes=self.ds.longitudes,
+            latitudes=latitudes[self.grid_indices],
+            longitudes=longitudes[self.grid_indices],
             fields=dict(),
         )

@@ -69,7 +84,8 @@ def create_input_state(self, *, date=None):
             if variable not in requested_variables:
                 continue
             # Squeeze the data to remove the ensemble dimension
-            fields[variable] = np.squeeze(data[:, i], axis=1)
+            values = np.squeeze(data[:, i], axis=1)
+            fields[variable] = values[:, self.grid_indices]
 
         return input_state

@@ -82,6 +98,8 @@ def load_forcings(self, *, variables, dates):
         data = np.squeeze(data, axis=2)
         # Reorder the dimensions to (variable, date, values)
         data = np.swapaxes(data, 0, 1)
+        # apply reduction to `grid_indices`
+        data = data[..., self.grid_indices]
         return data
 
     def _load_dates(self, dates):
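Taken together, the `DatasetInput` changes apply the same reduction in three places: the state coordinates, the per-variable fields, and the loaded forcings. A condensed, self-contained sketch of that slicing (toy shapes and names, not the real class):

```python
import numpy as np

grid_indices = np.array([0, 2, 3])   # from the checkpoint; slice(None) if absent

n_dates, n_points = 2, 5
latitudes = np.linspace(40.0, 44.0, n_points)
longitudes = np.linspace(5.0, 9.0, n_points)
data = np.arange(n_dates * n_points, dtype=float).reshape(n_dates, n_points)

state_latitudes = latitudes[grid_indices]    # create_input_state: coordinates
state_longitudes = longitudes[grid_indices]
fields_t2m = data[:, grid_indices]           # create_input_state: (dates, values) per variable

forcings = np.stack([data])                  # load_forcings: (variable, date, values)
forcings = forcings[..., grid_indices]       # same reduction along the last (grid) axis
```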
9 changes: 4 additions & 5 deletions src/anemoi/inference/metadata.py
@@ -139,6 +139,8 @@ def output_tensor_index_to_variable(self):
     @cached_property
     def number_of_grid_points(self):
         """Return the number of grid points per fields"""
+        if "grid_indices" in self._supporting_arrays:
+            return len(self.load_supporting_array("grid_indices"))
         try:
             return self._metadata.dataset.shape[-1]
         except AttributeError:
@@ -510,14 +512,13 @@ def _find(x):
                     _find(y)
 
             if isinstance(x, dict):
-                if "dataset" in x:
+                if "dataset" in x and isinstance(x["dataset"], str):
                     result.append(x["dataset"])
 
                 for k, v in x.items():
                     _find(v)
 
         _find(self._config.dataloader.training.dataset)
-
         return result
 
     def open_dataset_args_kwargs(self, *, use_original_paths, from_dataloader=None):
@@ -717,9 +718,7 @@ def boundary_forcings_inputs(self, context, input_state):

         result = []
 
-        output_mask = self._config_model.get("output_mask", None)
-        if output_mask is not None:
-            assert output_mask == "cutout", "Currently only cutout as output mask supported."
+        if "output_mask" in self._supporting_arrays:
             result.append(
                 context.create_boundary_forcings(
                     self.prognostic_variables,
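Two effects of the metadata changes are worth spelling out: the reported grid size follows `grid_indices` when the checkpoint provides it, and boundary forcings are now triggered by the presence of an `output_mask` supporting array rather than by an `output_mask: cutout` entry in the model config. A hedged sketch of the grid-size logic (simplified, standalone names, not the real property):

```python
import numpy as np

def number_of_grid_points(supporting_arrays, dataset_shape):
    """Simplified stand-in for the property in the diff above."""
    if "grid_indices" in supporting_arrays:
        # Reduced grid: count only the connected nodes.
        return len(supporting_arrays["grid_indices"])
    # Fall back to the full dataset grid.
    return dataset_shape[-1]

print(number_of_grid_points({"grid_indices": np.array([0, 1, 4])}, (10, 3, 1, 6)))  # 3
print(number_of_grid_points({}, (10, 3, 1, 6)))                                     # 6
```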