Changes from all commits
35 commits
19bdf2b
Add basic scatter function
frodre Sep 8, 2023
c9e553e
Implementation of gather scatter steppers
frodre Sep 11, 2023
d974b3d
Minor fixes
frodre Sep 11, 2023
6a29a8d
More fixes
frodre Sep 12, 2023
75c4196
Merge branch 'master' into feature/reservoir-optimize
frodre Sep 12, 2023
fb351de
Ignore model None for non root tiles
frodre Sep 12, 2023
5002d7e
Working gather scatter
frodre Sep 14, 2023
17fc919
Reduce test logging for gather/scatter
frodre Sep 14, 2023
2650a8d
Synchronize required gather variable definition
frodre Sep 14, 2023
0bc82fd
Fix use of required variables for gathering state items
frodre Sep 14, 2023
503e033
Fix state output bug
frodre Sep 19, 2023
840b96d
Fix prescriber scatter
frodre Sep 29, 2023
82026d5
Merge branch 'master' into test-master-merge
frodre Nov 26, 2023
3a6b319
Fix tests
frodre Nov 26, 2023
2601fdf
Latest fv3fit reservoir model classes
frodre Nov 26, 2023
86a95b3
Add intermediate state saving to reservoir
frodre Nov 28, 2023
818aced
Fix training output trim before encode
frodre Nov 29, 2023
af31be6
Train 8x8 subdomain model
frodre Nov 29, 2023
17d781a
Fix gather scatter increment offset
frodre Nov 30, 2023
9c44f93
Fix missing reference variables for SST update in 6 rank mode
frodre Nov 30, 2023
27669fc
Fix output dim ordering for state application
frodre Dec 2, 2023
f7de8da
Omit SST time averaging during runtime (temporary)
frodre Dec 2, 2023
fcff59a
Fix 24 rank operation
frodre Dec 3, 2023
5e8cbb2
Tmp fix for merge updates to reservoir update
frodre Dec 3, 2023
1613ea8
Fix output variable hyperparameters in model train
frodre Dec 5, 2023
652c635
Fix prognostic run w/ gather scatter and overlap
frodre Dec 5, 2023
9ff4711
Fix validation data ordering to match runtime for adapter
frodre Dec 6, 2023
21e6cab
Fix use_variable usage on model adapter
frodre Dec 6, 2023
5c8bb06
Try halo 6 training for pure reservoir
frodre Dec 6, 2023
c76ac96
Fix missing use of halo appended reservoir inputs
frodre Dec 6, 2023
d437888
Move training configs to separate directories
frodre Dec 6, 2023
133285b
Switch pure back to better performing halo2 config
frodre Dec 6, 2023
22a1c9a
Fix hybrid training
frodre Dec 6, 2023
7ecb113
Hybrid 8x8 training
frodre Dec 6, 2023
13b0e30
Add next segment syncing script
frodre Jan 4, 2024
110 changes: 75 additions & 35 deletions external/fv3fit/fv3fit/data/netcdf/load.py
@@ -1,8 +1,10 @@
import logging
from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Mapping, Optional, Sequence

from pathlib import Path
import fsspec
import numpy as np
import re
import tensorflow as tf
@@ -116,9 +118,77 @@ def to_tensor(
return {key: tf.convert_to_tensor(ds[key], dtype=dtype) for key in variable_names}


def open_netcdf_file(path: str) -> xr.Dataset:
"""Open a netcdf from a local/remote path"""
with fsspec.open(path) as f:
ds = xr.open_dataset(f, engine="h5netcdf")
return ds.load()


class _BaseNCLoader(TFDatasetLoader):
@property
@abstractmethod
def dim_order(self) -> Optional[Sequence[str]]:
pass

@property
def dtype(self):
return tf.float32

def convert(
self, ds: xr.Dataset, variables: Sequence[str]
) -> Mapping[str, tf.Tensor]:
tensors = {}
for key in variables:
data_array = self._ensure_consistent_dims(ds[key])
tensors[key] = tf.convert_to_tensor(data_array, dtype=self.dtype)
return tensors

@classmethod
def from_dict(cls, d: dict) -> "TFDatasetLoader":
return dacite.from_dict(cls, d, config=dacite.Config(strict=True))

def _ensure_consistent_dims(self, data_array: xr.DataArray):
if self.dim_order:
extra_dims_in_data_array = set(data_array.dims) - set(self.dim_order)
missing_dims_in_data_array = set(self.dim_order) - set(data_array.dims)
if len(extra_dims_in_data_array) > 0:
raise ValueError(
f"Extra dimensions {extra_dims_in_data_array} in data that are not "
f"included in configured dimension order {self.dim_order}."
"Make sure these are included in the configuration dim_order."
)
for missing_dim in missing_dims_in_data_array:
data_array = data_array.expand_dims(dim=missing_dim)
data_array = data_array.transpose(*self.dim_order)
return data_array


@register_tfdataset_loader
@dataclass
class NCFileLoader(_BaseNCLoader):
"""
Loads a single remote/local netCDF file into a dataset
"""

filepath: str = ""
dim_order: Optional[Sequence[str]] = None

def open_tfdataset(
self, local_download_path: Optional[str], variable_names: Sequence[str],
) -> tf.data.Dataset:
def convert(x):
return self.convert(x, variable_names)

transform = compose_left(open_netcdf_file, convert)
return iterable_to_tfdataset(
[self.filepath], transform, varying_first_dim=False
).prefetch(tf.data.AUTOTUNE)


@register_tfdataset_loader
@dataclass
class NCDirLoader(TFDatasetLoader):
class NCDirLoader(_BaseNCLoader):
"""Loads a folder of netCDF files at given path

Each file must have an identical CDL schema, as returned by ``ncdump -h``.
@@ -151,26 +221,14 @@ class NCDirLoader(TFDatasetLoader):

"""

url: str
url: str = ""
dim_order: Optional[Sequence[str]] = None
nfiles: Optional[int] = None
shuffle: bool = True
seed: int = 0
dim_order: Optional[Sequence[str]] = None
varying_first_dim: bool = False
sort_files: bool = False

@property
def dtype(self):
return tf.float32

def convert(
self, ds: xr.Dataset, variables: Sequence[str]
) -> Mapping[str, tf.Tensor]:
tensors = {}
for key in variables:
data_array = self._ensure_consistent_dims(ds[key])
tensors[key] = tf.convert_to_tensor(data_array, dtype=self.dtype)
return tensors
match: Optional[str] = None

def open_tfdataset(
self, local_download_path: Optional[str], variable_names: Sequence[str],
@@ -187,23 +245,5 @@ def convert(x):
cache=local_download_path,
varying_first_dim=self.varying_first_dim,
sort_files=self.sort_files,
match=self.match,
)

@classmethod
def from_dict(cls, d: dict) -> "TFDatasetLoader":
return dacite.from_dict(cls, d, config=dacite.Config(strict=True))

def _ensure_consistent_dims(self, data_array: xr.DataArray):
if self.dim_order:
extra_dims_in_data_array = set(data_array.dims) - set(self.dim_order)
missing_dims_in_data_array = set(self.dim_order) - set(data_array.dims)
if len(extra_dims_in_data_array) > 0:
raise ValueError(
f"Extra dimensions {extra_dims_in_data_array} in data that are not "
f"included in configured dimension order {self.dim_order}."
"Make sure these are included in the configuration dim_order."
)
for missing_dim in missing_dims_in_data_array:
data_array = data_array.expand_dims(dim=missing_dim)
data_array = data_array.transpose(*self.dim_order)
return data_array
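
The dimension-normalization logic now shared by both loaders via `_BaseNCLoader` is small enough to sketch standalone. A minimal re-creation of what `_ensure_consistent_dims` does, assuming only the behavior visible above (the free function and the sample dims are illustrative):

```python
import numpy as np
import xarray as xr


def ensure_consistent_dims(data_array: xr.DataArray, dim_order):
    """Standalone sketch of _BaseNCLoader._ensure_consistent_dims."""
    extra = set(data_array.dims) - set(dim_order)
    if extra:
        # Unknown dims cannot be reordered away; fail loudly, as the loader does.
        raise ValueError(
            f"Extra dimensions {extra} in data that are not included in "
            f"configured dimension order {dim_order}."
        )
    # Insert configured-but-missing dims as length-1 axes, then reorder.
    for missing in set(dim_order) - set(data_array.dims):
        data_array = data_array.expand_dims(dim=missing)
    return data_array.transpose(*dim_order)


da = xr.DataArray(np.zeros((4, 5)), dims=["x", "y"])
print(ensure_consistent_dims(da, ["time", "x", "y"]).dims)  # ('time', 'x', 'y')
```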
17 changes: 10 additions & 7 deletions external/fv3fit/fv3fit/reservoir/adapters.py
@@ -2,7 +2,7 @@
import numpy as np
import os
import typing
from typing import Iterable, Hashable, Sequence, Union, Mapping
from typing import Iterable, Hashable, Sequence, Union, Mapping, Optional
import xarray as xr

import fv3fit
@@ -55,15 +55,19 @@ def _ndarray_to_dataarray(self, arr: np.ndarray) -> xr.DataArray:
return xr.DataArray(data=arr, dims=dims)

def output_array_to_ds(
self, outputs: Sequence[np.ndarray], output_dims: Sequence[str]
self, outputs: Sequence[np.ndarray], output_dims: Optional[Sequence[str]] = None
) -> xr.Dataset:

return xr.Dataset(
ds = xr.Dataset(
{
var: self._ndarray_to_dataarray(output)
for var, output in zip(self.output_variables, outputs)
}
).transpose(*output_dims)
)
if output_dims is None:
output_dims = ["y", "x", "z"] # default ordering for wrapper

return ds.transpose(*[dim for dim in output_dims if dim in ds.dims])

def input_dataset_to_arrays(
self, inputs: xr.Dataset, variables: Iterable[Hashable]
@@ -121,9 +125,8 @@ def is_hybrid(self):
def predict(self, inputs: xr.Dataset) -> xr.Dataset:
# inputs arg is not used, but is required by Predictor signature and prog run
prediction_arr = self.model.predict()
return self.model_adapter.output_array_to_ds(
prediction_arr, output_dims=list(inputs.dims)
)
dims = list(inputs.dims) if inputs else None
return self.model_adapter.output_array_to_ds(prediction_arr, output_dims=dims)

def increment_state(self, inputs: xr.Dataset):
xy_input_arrs = self.model_adapter.input_dataset_to_arrays(
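
`output_array_to_ds` now treats `output_dims` as optional, falling back to the wrapper's `("y", "x", "z")` ordering and transposing only over dimensions that actually occur in the dataset, so 2D and 3D outputs can coexist. A self-contained sketch of that behavior, with illustrative variable names, shapes, and per-variable dims:

```python
import numpy as np
import xarray as xr


def output_array_to_ds(outputs, output_variables, dims_per_var, output_dims=None):
    # Build a dataset from raw prediction arrays, one per output variable.
    ds = xr.Dataset(
        {
            var: xr.DataArray(arr, dims=dims)
            for var, arr, dims in zip(output_variables, outputs, dims_per_var)
        }
    )
    if output_dims is None:
        output_dims = ["y", "x", "z"]  # default ordering for the wrapper
    # Transpose only over dims present, so a 2D (x, y) variable can sit
    # alongside a 3D (x, y, z) one without raising.
    return ds.transpose(*[dim for dim in output_dims if dim in ds.dims])


pred = [np.zeros((8, 8)), np.zeros((8, 8, 5))]
ds = output_array_to_ds(
    pred, ["sst", "air_temperature"], [("x", "y"), ("x", "y", "z")]
)
print(ds["air_temperature"].dims)  # ('y', 'x', 'z')
```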
6 changes: 5 additions & 1 deletion external/fv3fit/fv3fit/reservoir/config.py
@@ -1,6 +1,6 @@
import dacite
from dataclasses import dataclass, asdict
from typing import Sequence, Optional, Set, Tuple
from typing import Sequence, Tuple, Optional, Set
import fsspec
import yaml
from .._shared.training_config import Hyperparameters
@@ -91,6 +91,8 @@ class ReservoirTrainingConfig(Hyperparameters):
mask_variable: if specified, save mask array that is multiplied to the input array
before multiplication with W_in. This applies a mask using the
mask_variable field.
mask_readout: if mask_variable is specified, apply that mask to the hybrid inputs
as well
"""

input_variables: Sequence[str]
@@ -106,6 +108,8 @@ class ReservoirTrainingConfig(Hyperparameters):
square_half_hidden_state: bool = False
hybrid_variables: Optional[Sequence[str]] = None
mask_variable: Optional[str] = None
mask_readout: bool = True
validate_sst_only: bool = False
_METADATA_NAME = "reservoir_training_config.yaml"

def __post_init__(self):
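
Per the docstring, `mask_variable` names an array that is multiplied into the input before the `W_in` projection, and the new `mask_readout` flag extends that masking to the hybrid inputs. A rough numpy sketch of the data flow under those stated semantics; all shapes, and the assumption that the hybrid inputs share the mask's shape, are illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=16)       # encoded reservoir input features
hybrid = rng.normal(size=16)  # hybrid (forecast-model) readout inputs
mask = (rng.uniform(size=16) > 0.2).astype(float)  # from mask_variable

# mask_variable: applied to the input before multiplication with W_in.
W_in = rng.normal(size=(32, 16))
preactivation = W_in @ (x * mask)

# mask_readout=True: additionally mask the hybrid readout inputs.
mask_readout = True
if mask_readout:
    hybrid = hybrid * mask
```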
15 changes: 13 additions & 2 deletions external/fv3fit/fv3fit/reservoir/domain2.py
@@ -54,10 +54,14 @@ def __init__(
if overlap < 0:
raise ValueError("Overlap must be non-negative")

self._rank_dims = ["x", "y"]

self.overlap = overlap
self.subdomain_layout = subdomain_layout
self.n_subdomains = subdomain_layout[0] * subdomain_layout[1]
self._partitioner = pace.util.TilePartitioner(subdomain_layout)
self._partitioner = pace.util.TilePartitioner(
self._subdomain_layout_for_partitioner
)

self._init_rank_extent(rank_extent, overlap_rank_extent)
self._x_rank_extent = self.rank_extent[0]
@@ -81,6 +85,13 @@ def _rank_extent_for_partitioner(self):
# Fed into partitioner for slicing, no overlap should ever be given
return self._maybe_append_feature_value(self.rank_extent, self._z_feature_size)

@property
def _subdomain_layout_for_partitioner(self):
# partitioner expects layout in y, x order:
x_ind = self._rank_dims.index("x")
y_ind = self._rank_dims.index("y")
return self.subdomain_layout[y_ind], self.subdomain_layout[x_ind]

@property
def _rank_extent_all_features(self):
# used for data consistency checks
@@ -90,7 +101,7 @@

@property
def _rank_dims_all_features(self):
return self._maybe_append_feature_value(["x", "y"], "z")
return self._maybe_append_feature_value(self._rank_dims, "z")

@property
def _subdomain_shape(self):
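
The substantive fix in `domain2.py` is ordering: `pace.util.TilePartitioner` expects its layout in `(y, x)` order, while the divider stores `subdomain_layout` aligned with `rank_dims = ["x", "y"]`. The swap reduces to an index lookup, sketched here with an illustrative layout:

```python
rank_dims = ["x", "y"]
subdomain_layout = (4, 2)  # illustrative: 4 subdomains in x, 2 in y

x_ind = rank_dims.index("x")
y_ind = rank_dims.index("y")

# Reorder to (y, x) before constructing TilePartitioner.
layout_for_partitioner = (subdomain_layout[y_ind], subdomain_layout[x_ind])
print(layout_for_partitioner)  # (2, 4)
```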
22 changes: 8 additions & 14 deletions external/fv3fit/fv3fit/reservoir/model.py
@@ -12,7 +12,7 @@
from .domain2 import RankXYDivider
from fv3fit._shared import io
from .utils import square_even_terms
from .transformers import encode_columns, decode_columns, TransformerGroup
from .transformers import TransformerGroup

DIMENSION_ORDER = ("x", "y")

@@ -94,9 +94,7 @@ def from_reservoir_model(
def predict(self, hybrid_input: Sequence[np.ndarray]):
# hybrid input is assumed to be in original spatial xy dims
# (x, y, feature) and does not include overlaps.
encoded_hybrid_input = encode_columns(
input_arrs=hybrid_input, transformer=self.transformers.hybrid
)
encoded_hybrid_input = self.transformers.hybrid.encode_txyz(hybrid_input)

flat_hybrid_in = self._hybrid_rank_divider.get_all_subdomains_with_flat_feature(
encoded_hybrid_input
@@ -113,9 +111,7 @@ def predict(self, hybrid_input: Sequence[np.ndarray]):
prediction = self._output_rank_divider.merge_all_flat_feature_subdomains(
flat_prediction
)
decoded_prediction = decode_columns(
encoded_output=prediction, transformer=self.transformers.output,
)
decoded_prediction = self.transformers.output.decode_txyz(prediction)
return decoded_prediction

def _concatenate_readout_inputs(self, hidden_state_input, flat_hybrid_input):
@@ -235,9 +231,7 @@ def predict(self):
prediction = self._output_rank_divider.merge_all_flat_feature_subdomains(
flat_prediction
)
decoded_prediction = decode_columns(
encoded_output=prediction, transformer=self.transformers.output,
)
decoded_prediction = self.transformers.output.decode_txyz(prediction)
return decoded_prediction

def reset_state(self):
@@ -249,8 +243,8 @@

def increment_state(self, prediction_with_overlap: Sequence[np.ndarray]) -> None:
# input array is in native x, y, z_feature coordinates
encoded_xy_input_arrs = encode_columns(
prediction_with_overlap, self.transformers.input
encoded_xy_input_arrs = self.transformers.input.encode_txyz(
prediction_with_overlap
)
encoded_flat_sub = self.rank_divider.get_all_subdomains_with_flat_feature(
encoded_xy_input_arrs
@@ -259,8 +253,8 @@ def increment_state(self, prediction_with_overlap: Sequence[np.ndarray]) -> None

def synchronize(self, synchronization_time_series):
# input arrays in native x, y, z_feature coordinates
encoded_timeseries = encode_columns(
synchronization_time_series, self.transformers.input
encoded_timeseries = self.transformers.input.encode_txyz(
synchronization_time_series
)
encoded_flat = self.rank_divider.get_all_subdomains_with_flat_feature(
encoded_timeseries
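
The model now calls `encode_txyz`/`decode_txyz` on the transformers directly rather than going through the module-level `encode_columns`/`decode_columns` helpers. A toy transformer showing the round-trip contract those calls appear to rely on; this class is a stand-in for illustration, not fv3fit's actual transformer implementation:

```python
from typing import List, Optional, Sequence

import numpy as np


class IdentityTransformer:
    """Stand-in exposing the encode_txyz/decode_txyz interface."""

    def __init__(self) -> None:
        self._split_points: Optional[np.ndarray] = None

    def encode_txyz(self, input_arrs: Sequence[np.ndarray]) -> np.ndarray:
        # Concatenate per-variable (x, y, z_feature) arrays along the feature
        # axis, remembering block boundaries so decode can invert the stack.
        self._split_points = np.cumsum([a.shape[-1] for a in input_arrs])[:-1]
        return np.concatenate(list(input_arrs), axis=-1)

    def decode_txyz(self, encoded: np.ndarray) -> List[np.ndarray]:
        # Invert encode_txyz by splitting the feature axis per variable.
        return np.split(encoded, self._split_points, axis=-1)


t = IdentityTransformer()
encoded = t.encode_txyz([np.zeros((8, 8, 2)), np.ones((8, 8, 3))])
print(encoded.shape)                              # (8, 8, 5)
print([a.shape for a in t.decode_txyz(encoded)])  # [(8, 8, 2), (8, 8, 3)]
```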
14 changes: 11 additions & 3 deletions external/fv3fit/fv3fit/reservoir/reservoir.py
@@ -5,7 +5,7 @@
import numpy as np
import os
import scipy
from typing import Optional
from typing import cast, Optional
import yaml

from .config import ReservoirHyperparameters
@@ -75,7 +75,9 @@ def increment_state(self, input):
masked_input = input * self.input_mask_array
else:
masked_input = input
self.state = np.tanh(masked_input @ self.W_in.T + self.state @ self.W_res.T)
self.state: np.ndarray = np.tanh(
masked_input @ self.W_in.T + self.state @ self.W_res.T
)

def reset_state(self, input_shape: tuple):
logger.info("Resetting reservoir state.")
@@ -92,6 +94,12 @@
raise ValueError("Input shape tuple must describe either a 1D or 2D array.")
self.state = state_after_reset

def set_state(self, new_state: np.ndarray):
if self.state is not None:
if self.state.shape != new_state.shape:
raise ValueError("Provided state does not match reservoir state shape")
self.state = new_state

def synchronize(self, synchronization_time_series):
self.reset_state(input_shape=synchronization_time_series[0].shape)
for input in synchronization_time_series:
@@ -178,7 +186,7 @@ def load(cls, path: str) -> "Reservoir":

try:
with fsspec.open(os.path.join(path, cls._STATE_NAME), "rb") as f:
state = np.load(f)
state = cast(np.ndarray, np.load(f))
except (FileNotFoundError):
state = None

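
The new `set_state` only validates shape against an already-initialized state, so it can also seed a reservoir whose state is still `None`. A toy reproduction of just that guard (the class below carries none of `Reservoir`'s other machinery):

```python
import numpy as np


class MiniReservoir:
    """Holds only the state-handling piece touched by this diff."""

    def __init__(self, state=None):
        self.state = state

    def set_state(self, new_state: np.ndarray) -> None:
        # Only an existing state constrains the shape of its replacement.
        if self.state is not None and self.state.shape != new_state.shape:
            raise ValueError("Provided state does not match reservoir state shape")
        self.state = new_state


r = MiniReservoir(state=np.zeros(10))
r.set_state(np.ones(10))  # ok: shapes match
try:
    r.set_state(np.ones(5))  # shape mismatch
except ValueError as err:
    print(err)
```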