diff --git a/external/fv3fit/fv3fit/reservoir/config.py b/external/fv3fit/fv3fit/reservoir/config.py
index 869ea263df..091e9438b3 100644
--- a/external/fv3fit/fv3fit/reservoir/config.py
+++ b/external/fv3fit/fv3fit/reservoir/config.py
@@ -65,7 +65,8 @@ class ReservoirTrainingConfig(Hyperparameters):
-        output_variables: time series variables, must be subset of input_variables
+        output_variables: time series variables; must be a subset of
+            input_variables or mutually exclusive with them
         reservoir_hyperparameters: hyperparameters for reservoir
         readout_hyperparameters: hyperparameters for readout
-        n_batches_burn: number of training batches at start of time series to use
+        n_timesteps_synchronize: number of timesteps at start of time series to use
             for synchronization. This data is used to update the reservoir state but
             is not included in training.
         input_noise: stddev of normal distribution which is sampled to add input
@@ -88,21 +88,23 @@ class ReservoirTrainingConfig(Hyperparameters):
     subdomain: CubedsphereSubdomainConfig
     reservoir_hyperparameters: ReservoirHyperparameters
     readout_hyperparameters: BatchLinearRegressorHyperparameters
-    n_batches_burn: int
+    n_timesteps_synchronize: int
     input_noise: float
     seed: int = 0
     n_jobs: Optional[int] = 1
     square_half_hidden_state: bool = False
     autoencoder_path: Optional[str] = None
+    hybrid_autoencoder_path: Optional[str] = None
     hybrid_variables: Optional[Sequence[str]] = None
 
     _METADATA_NAME = "reservoir_training_config.yaml"
 
     def __post_init__(self):
         if set(self.output_variables).issubset(self.input_variables) is False:
-            raise ValueError(
-                f"Output variables {self.output_variables} must be a subset of "
-                f"input variables {self.input_variables}."
-            )
+            if len(set(self.output_variables).intersection(self.input_variables)) > 0:
+                raise ValueError(
+                    f"Output variables {self.output_variables} must either be a subset "
+                    f"of input variables {self.input_variables} or mutually exclusive."
+                )
         if self.hybrid_variables is not None:
             hybrid_and_input_vars_intersection = set(
                 self.hybrid_variables
@@ -119,7 +121,9 @@ def variables(self) -> Set[str]:
             hybrid_vars = list(self.hybrid_variables)  # type: ignore
         else:
             hybrid_vars = []
-        return set(list(self.input_variables) + hybrid_vars)
+        return set(
+            list(self.input_variables) + list(self.output_variables) + hybrid_vars
+        )
 
     @classmethod
     def from_dict(cls, kwargs) -> "ReservoirTrainingConfig":
@@ -148,7 +152,7 @@ def from_dict(cls, kwargs) -> "ReservoirTrainingConfig":
 
     def dump(self, path: str):
         metadata = {
-            "n_batches_burn": self.n_batches_burn,
+            "n_timesteps_synchronize": self.n_timesteps_synchronize,
             "input_noise": self.input_noise,
             "seed": self.seed,
             "n_jobs": self.n_jobs,
diff --git a/external/fv3fit/fv3fit/reservoir/domain.py b/external/fv3fit/fv3fit/reservoir/domain.py
index 267c096300..1771534f2d 100644
--- a/external/fv3fit/fv3fit/reservoir/domain.py
+++ b/external/fv3fit/fv3fit/reservoir/domain.py
@@ -203,7 +203,7 @@ def dump(self, path):
         metadata = {
             "subdomain_layout": self.subdomain_layout,
             "rank_dims": self.rank_dims,
-            "rank_extent": self.rank_extent,
+            "rank_extent": list(self.rank_extent),
             "overlap": self.overlap,
         }
         with fsspec.open(path, "w") as f:
diff --git a/external/fv3fit/fv3fit/reservoir/model.py b/external/fv3fit/fv3fit/reservoir/model.py
index 3ea6887b4f..7cfc0ce88a 100644
--- a/external/fv3fit/fv3fit/reservoir/model.py
+++ b/external/fv3fit/fv3fit/reservoir/model.py
@@ -1,7 +1,7 @@
 import fsspec
 import numpy as np
 import os
-from typing import Iterable, Hashable, Sequence, cast
+from typing import Iterable, Hashable, Sequence, cast, Optional
 import xarray as xr
 import yaml
@@ -35,6 +35,7 @@ def _transpose_xy_dims(ds: xr.Dataset, rank_dims: Sequence[str]):
 @io.register("hybrid-reservoir")
 class HybridReservoirComputingModel(Predictor):
     _HYBRID_VARIABLES_NAME = "hybrid_variables.yaml"
+    _AUTOENCODER_SUBDIR = "autoencoder"
 
     def __init__(
         self,
@@ -45,6 +46,8 @@ def __init__(
         readout: ReservoirComputingReadout,
         rank_divider: RankDivider,
         autoencoder: ReloadableTransfomer,
+        output_autoencoder: Optional[ReloadableTransfomer] = None,
+        hybrid_autoencoder: Optional[ReloadableTransfomer] = None,
         square_half_hidden_state: bool = False,
     ):
         self.reservoir_model = ReservoirComputingModel(
@@ -55,6 +58,7 @@ def __init__(
             square_half_hidden_state=square_half_hidden_state,
             rank_divider=rank_divider,
             autoencoder=autoencoder,
+            output_autoencoder=output_autoencoder,
         )
         self.input_variables = input_variables
         self.hybrid_variables = hybrid_variables
@@ -62,13 +66,18 @@ def __init__(
         self.readout = readout
         self.square_half_hidden_state = square_half_hidden_state
         self.rank_divider = rank_divider
-        self.autoencoder = autoencoder
+        self.input_autoencoder = autoencoder
+        self.output_autoencoder = output_autoencoder
+        self.hybrid_autoencoder = hybrid_autoencoder
 
     def predict(self, hybrid_input: Sequence[np.ndarray]):
         # hybrid input is assumed to be in original spatial xy dims
         # (x, y, feature) and does not include overlaps.
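+        # Fall back to the input autoencoder when no separate hybrid
+        # autoencoder was provided.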
+        hybrid_autoencoder = self.hybrid_autoencoder or self.input_autoencoder
         encoded_hybrid_input = encode_columns(
-            input_arrs=hybrid_input, transformer=self.autoencoder
+            input_arrs=hybrid_input, transformer=hybrid_autoencoder
         )
         if encoded_hybrid_input.shape[:2] != tuple(
             self.rank_divider.rank_extent_without_overlap
@@ -93,9 +100,11 @@ def predict(self, hybrid_input: Sequence[np.ndarray]):
         )
         flat_prediction = self.readout.predict(flattened_readout_input).reshape(-1)
         prediction = self.rank_divider.merge_subdomains(flat_prediction)
+
+        output_autoencoder = self.output_autoencoder or self.input_autoencoder
         decoded_prediction = decode_columns(
             encoded_output=prediction,
-            transformer=self.autoencoder,
+            transformer=output_autoencoder,
             xy_shape=self.rank_divider.rank_extent_without_overlap,
         )
         return decoded_prediction
@@ -122,12 +131,26 @@ def dump(self, path: str) -> None:
         self.reservoir_model.dump(path)
         with fsspec.open(os.path.join(path, self._HYBRID_VARIABLES_NAME), "w") as f:
             f.write(yaml.dump({"hybrid_variables": self.hybrid_variables}))
+        if self.hybrid_autoencoder is not None:
+            fv3fit.dump(
+                self.hybrid_autoencoder,
+                os.path.join(path, self._AUTOENCODER_SUBDIR, "hybrid"),
+            )
 
     @classmethod
     def load(cls, path: str) -> "HybridReservoirComputingModel":
         pure_reservoir_model = ReservoirComputingModel.load(path)
         with fsspec.open(os.path.join(path, cls._HYBRID_VARIABLES_NAME), "r") as f:
             hybrid_variables = yaml.safe_load(f)["hybrid_variables"]
+
+        try:
+            hybrid_autoencoder = cast(
+                ReloadableTransfomer,
+                fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "hybrid")),
+            )
+        except KeyError:
+            hybrid_autoencoder = None  # type: ignore
+
         return cls(
             input_variables=pure_reservoir_model.input_variables,
             output_variables=pure_reservoir_model.output_variables,
@@ -135,7 +158,9 @@ def load(cls, path: str) -> "HybridReservoirComputingModel":
             readout=pure_reservoir_model.readout,
             square_half_hidden_state=pure_reservoir_model.square_half_hidden_state,
             rank_divider=pure_reservoir_model.rank_divider,
-            autoencoder=pure_reservoir_model.autoencoder,
+            autoencoder=pure_reservoir_model.input_autoencoder,
+            output_autoencoder=pure_reservoir_model.output_autoencoder,
+            hybrid_autoencoder=hybrid_autoencoder,
             hybrid_variables=hybrid_variables,
         )
@@ -200,6 +225,7 @@ def __init__(
         readout: ReservoirComputingReadout,
         rank_divider: RankDivider,
         autoencoder: ReloadableTransfomer,
+        output_autoencoder: Optional[ReloadableTransfomer] = None,
         square_half_hidden_state: bool = False,
     ):
         """_summary_
@@ -219,7 +245,8 @@ def __init__(
         self.readout = readout
         self.square_half_hidden_state = square_half_hidden_state
         self.rank_divider = rank_divider
-        self.autoencoder = autoencoder
+        self.input_autoencoder = autoencoder
+        self.output_autoencoder = output_autoencoder
 
     def process_state_to_readout_input(self):
         if self.square_half_hidden_state is True:
@@ -236,9 +263,10 @@ def predict(self):
         readout_input = self.process_state_to_readout_input()
         flat_prediction = self.readout.predict(readout_input).reshape(-1)
         prediction = self.rank_divider.merge_subdomains(flat_prediction)
+        output_autoencoder = self.output_autoencoder or self.input_autoencoder
         decoded_prediction = decode_columns(
             encoded_output=prediction,
-            transformer=self.autoencoder,
+            transformer=output_autoencoder,
             xy_shape=self.rank_divider.rank_extent_without_overlap,
         )
         return decoded_prediction
@@ -253,7 +281,7 @@ def reset_state(self):
 
     def increment_state(self, prediction_with_overlap: Sequence[np.ndarray]) -> None:
         # input array is in native x, y, z_feature coordinates
         encoded_xy_input_arrs = encode_columns(
-            prediction_with_overlap, self.autoencoder
+            prediction_with_overlap, self.input_autoencoder
        )
         encoded_flattened_subdomains = self.rank_divider.flatten_subdomains_to_columns(
             encoded_xy_input_arrs, with_overlap=True
@@ -281,8 +309,15 @@ def dump(self, path: str) -> None:
             f.write(yaml.dump(metadata))
         self.rank_divider.dump(os.path.join(path, self._RANK_DIVIDER_NAME))
-        if self.autoencoder is not None:
-            fv3fit.dump(self.autoencoder, os.path.join(path, self._AUTOENCODER_SUBDIR))
+        fv3fit.dump(
+            self.input_autoencoder,
+            os.path.join(path, self._AUTOENCODER_SUBDIR, "input"),
+        )
+        if self.output_autoencoder is not None:
+            fv3fit.dump(
+                self.output_autoencoder,
+                os.path.join(path, self._AUTOENCODER_SUBDIR, "output"),
+            )
 
     @classmethod
     def load(cls, path: str) -> "ReservoirComputingModel":
@@ -298,8 +333,16 @@ def load(cls, path: str) -> "ReservoirComputingModel":
         autoencoder = cast(
             ReloadableTransfomer,
-            fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR)),
+            fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "input")),
         )
+        try:
+            output_autoencoder = cast(
+                ReloadableTransfomer,
+                fv3fit.load(os.path.join(path, cls._AUTOENCODER_SUBDIR, "output")),
+            )
+        except KeyError:
+            output_autoencoder = None  # type: ignore
+
         return cls(
             input_variables=metadata["input_variables"],
             output_variables=metadata["output_variables"],
@@ -308,4 +351,5 @@ def load(cls, path: str) -> "ReservoirComputingModel":
             square_half_hidden_state=metadata["square_half_hidden_state"],
             rank_divider=rank_divider,
             autoencoder=autoencoder,
+            output_autoencoder=output_autoencoder,
         )
diff --git a/external/fv3fit/fv3fit/reservoir/train.py b/external/fv3fit/fv3fit/reservoir/train.py
index e12f46eac3..59c22455dc 100644
--- a/external/fv3fit/fv3fit/reservoir/train.py
+++ b/external/fv3fit/fv3fit/reservoir/train.py
@@ -4,9 +4,14 @@
 from fv3fit.reservoir.readout import BatchLinearRegressor
 import numpy as np
 import tensorflow as tf
-from typing import Optional, List, Union
+from typing import Optional, List, Union, cast
 from .. import Predictor
-from .utils import square_even_terms, process_batch_Xy_data, get_ordered_X
+from .utils import (
+    square_even_terms,
+    process_batch_Xy_data,
+    get_ordered_X,
+    SynchronizationTracker,
+)
 from .transformers.autoencoder import build_concat_and_scale_only_autoencoder
 from .._shared import register_training_function
 from ._reshaping import concat_inputs_along_subdomain_features
@@ -31,6 +36,16 @@ def _add_input_noise(arr: np.ndarray, stddev: float) -> np.ndarray:
     return arr + np.random.normal(loc=0, scale=stddev, size=arr.shape)
 
 
+def _get_standard_normalizing_transformer(variables, sample_batch):
+    variable_data = get_ordered_X(sample_batch, variables)
+    variable_data_stacked = [
+        stack_array_preserving_last_dim(arr).numpy() for arr in variable_data
+    ]
+    return build_concat_and_scale_only_autoencoder(
+        variables=variables, X=variable_data_stacked
+    )
+
+
 @register_training_function("reservoir", ReservoirTrainingConfig)
 def train_reservoir_model(
     hyperparameters: ReservoirTrainingConfig,
@@ -42,17 +57,39 @@ def train_reservoir_model(
     sample_X = get_ordered_X(sample_batch, hyperparameters.input_variables)
 
     if hyperparameters.autoencoder_path is not None:
-        autoencoder: ReloadableTransfomer = fv3fit.load(
-            hyperparameters.autoencoder_path
-        )  # type: ignore
+        input_autoencoder = cast(
+            ReloadableTransfomer, fv3fit.load(hyperparameters.autoencoder_path)
+        )
     else:
-        sample_X_stacked = [
-            stack_array_preserving_last_dim(arr).numpy() for arr in sample_X
-        ]
-        autoencoder = build_concat_and_scale_only_autoencoder(
-            variables=hyperparameters.input_variables, X=sample_X_stacked
+        input_autoencoder = _get_standard_normalizing_transformer(
+            hyperparameters.input_variables, sample_batch
         )
+
+    if hyperparameters.hybrid_variables is not None:
+        if hyperparameters.hybrid_autoencoder_path is not None:
+            hybrid_autoencoder = cast(
+                ReloadableTransfomer,
+                fv3fit.load(hyperparameters.hybrid_autoencoder_path),
+            )
+        else:
+            hybrid_autoencoder = _get_standard_normalizing_transformer(
+                hyperparameters.hybrid_variables, sample_batch
+            )
+
+    # If output variables are different from inputs, a separate autoencoder
+    # is needed for decoding outputs
+    if set(hyperparameters.output_variables) != set(hyperparameters.input_variables):
+        if hyperparameters.autoencoder_path is not None:
+            raise ValueError(
+                "Output variables != input variables, cannot use the same "
+                f"autoencoder {hyperparameters.autoencoder_path} for both. "
+                "This feature will be added in the future."
+            )
+        output_autoencoder = _get_standard_normalizing_transformer(
+            hyperparameters.output_variables, sample_batch
+        )
+    else:
+        output_autoencoder = input_autoencoder
 
     subdomain_config = hyperparameters.subdomain
     # sample_X[0] is the first data variable, shape elements 1:-1 are the x,y shape
@@ -67,7 +104,8 @@ def train_reservoir_model(
     # subdomain+halo are flattened into feature dimension
     reservoir = Reservoir(
         hyperparameters=hyperparameters.reservoir_hyperparameters,
-        input_size=rank_divider.subdomain_size_with_overlap * autoencoder.n_latent_dims,
+        input_size=rank_divider.subdomain_size_with_overlap
+        * input_autoencoder.n_latent_dims,
     )
 
     # One readout is trained per subdomain when iterating over batches,
@@ -76,29 +114,30 @@ def train_reservoir_model(
         BatchLinearRegressor(hyperparameters.readout_hyperparameters)
         for r in range(rank_divider.n_subdomains)
     ]
+    sync_tracker = SynchronizationTracker(
+        n_synchronize=hyperparameters.n_timesteps_synchronize
+    )
     for b, batch_data in enumerate(train_batches):
         time_series_with_overlap, time_series_without_overlap = process_batch_Xy_data(
             variables=hyperparameters.input_variables,
             batch_data=batch_data,
             rank_divider=rank_divider,
-            autoencoder=autoencoder,
+            autoencoder=input_autoencoder,
         )
-        if b < hyperparameters.n_batches_burn:
-            logger.info(f"Synchronizing on batch {b+1}")
-
         # reservoir increment occurs in this call, so always call this
         # function even if X, Y are not used for readout training.
         reservoir_state_time_series = _get_reservoir_state_time_series(
             time_series_with_overlap, hyperparameters.input_noise, reservoir
         )
+        sync_tracker.count(len(reservoir_state_time_series))
         hybrid_time_series: Optional[np.ndarray]
         if hyperparameters.hybrid_variables is not None:
             _, hybrid_time_series = process_batch_Xy_data(
                 variables=hyperparameters.hybrid_variables,
                 batch_data=batch_data,
                 rank_divider=rank_divider,
-                autoencoder=autoencoder,
+                autoencoder=hybrid_autoencoder,
             )
         else:
             hybrid_time_series = None
@@ -110,7 +149,26 @@ def train_reservoir_model(
             hybrid_time_series=hybrid_time_series,
         )
 
-        if b >= hyperparameters.n_batches_burn:
+        if set(hyperparameters.input_variables) != set(hyperparameters.output_variables):
+            _, output_time_series_without_overlap = process_batch_Xy_data(
+                variables=hyperparameters.output_variables,
+                batch_data=batch_data,
+                rank_divider=rank_divider,
+                autoencoder=output_autoencoder,
+            )
+            logger.info(
+                f"Using output_variables {hyperparameters.output_variables}, "
+                f"which differ from input_variables {hyperparameters.input_variables}"
+            )
+            readout_output = output_time_series_without_overlap[:-1]
+
+        if sync_tracker.completed_synchronization:
+            readout_input = sync_tracker.trim_synchronization_samples_if_needed(
+                readout_input
+            )
+            readout_output = sync_tracker.trim_synchronization_samples_if_needed(
+                readout_output
+            )
             logger.info(f"Fitting on batch {b+1}")
             _fit_batch_over_subdomains(
                 X_batch=readout_input,
@@ -137,23 +195,26 @@ def train_reservoir_model(
     if hyperparameters.hybrid_variables is None:
         model = ReservoirComputingModel(
             input_variables=hyperparameters.input_variables,
-            output_variables=hyperparameters.input_variables,
+            output_variables=hyperparameters.output_variables,
             reservoir=reservoir,
             readout=readout,
             square_half_hidden_state=hyperparameters.square_half_hidden_state,
             rank_divider=rank_divider,
-            autoencoder=autoencoder,
+            autoencoder=input_autoencoder,
+            output_autoencoder=output_autoencoder,
         )
     else:
         model = HybridReservoirComputingModel(
             input_variables=hyperparameters.input_variables,
-            output_variables=hyperparameters.input_variables,
+            output_variables=hyperparameters.output_variables,
             hybrid_variables=hyperparameters.hybrid_variables,
             reservoir=reservoir,
             readout=readout,
             square_half_hidden_state=hyperparameters.square_half_hidden_state,
             rank_divider=rank_divider,
-            autoencoder=autoencoder,
+            autoencoder=input_autoencoder,
+            output_autoencoder=output_autoencoder,
+            hybrid_autoencoder=hybrid_autoencoder,
         )
     return model
diff --git a/external/fv3fit/fv3fit/reservoir/transformers/autoencoder.py b/external/fv3fit/fv3fit/reservoir/transformers/autoencoder.py
index 9a9a693372..fdbccabf99 100644
--- a/external/fv3fit/fv3fit/reservoir/transformers/autoencoder.py
+++ b/external/fv3fit/fv3fit/reservoir/transformers/autoencoder.py
@@ -4,6 +4,8 @@
 import tensorflow as tf
 from toolz.functoolz import curry
 from typing import Union, Sequence, Optional, List, Set, Tuple
+from tensorflow.python.keras.utils.generic_utils import to_list
+
 from fv3fit.reservoir.transformers.transformer import Transformer
 from fv3fit._shared import (
     get_dir,
@@ -68,8 +70,8 @@ def encode(self, x: Sequence[ArrayLike]) -> ArrayLike:
         x = _ensure_all_items_have_sample_dim(x)
         return self.encoder.predict(x)
 
-    def decode(self, latent_x: ArrayLike) -> ArrayLike:
-        return self.decoder.predict(latent_x)
+    def decode(self, latent_x: ArrayLike) -> Sequence[ArrayLike]:
+        return to_list(self.decoder.predict(latent_x))
 
     def dump(self, path: str) -> None:
         with put_dir(path) as path:
diff --git a/external/fv3fit/fv3fit/reservoir/utils.py b/external/fv3fit/fv3fit/reservoir/utils.py
index e07735bc77..1317dbaa1a 100644
--- a/external/fv3fit/fv3fit/reservoir/utils.py
+++ b/external/fv3fit/fv3fit/reservoir/utils.py
@@ -5,6 +5,35 @@
 from fv3fit.reservoir.domain import RankDivider, assure_txyz_dims
 
 
+class SynchronizationTracker:
+    def __init__(self, n_synchronize: int):
+        self.n_synchronize = n_synchronize
+        self.n_steps_synchronized = 0
+
+    @property
+    def completed_synchronization(self):
+        return self.n_steps_synchronized > self.n_synchronize
+
+    def count(self, n_samples: int):
+        self.n_steps_synchronized += n_samples
+
+    def trim_synchronization_samples_if_needed(self, arr: np.ndarray) -> np.ndarray:
+        """Removes samples from the input array if they fall within the
+        synchronization range.
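+
+        For example, with n_synchronize=6 and batches of 5 samples, successive
+        batches trim to: empty, the last 4 samples, then the full batch.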
+ """ + if self.completed_synchronization is True: + steps_past_sync = self.n_steps_synchronized - self.n_synchronize + if steps_past_sync > len(arr): + return arr + else: + return arr[-steps_past_sync:] + else: + return np.array([]) + + def _square_evens(v: np.ndarray) -> np.ndarray: evens = v[::2] odds = v[1::2] diff --git a/external/fv3fit/tests/reservoir/test_utils.py b/external/fv3fit/tests/reservoir/test_utils.py index 85807385b9..3d00da1721 100644 --- a/external/fv3fit/tests/reservoir/test_utils.py +++ b/external/fv3fit/tests/reservoir/test_utils.py @@ -1,10 +1,25 @@ import numpy as np import pytest -from fv3fit.reservoir.utils import square_even_terms, process_batch_Xy_data +from fv3fit.reservoir.utils import ( + square_even_terms, + process_batch_Xy_data, + SynchronziationTracker, +) from fv3fit.reservoir.transformers import DoNothingAutoencoder from fv3fit.reservoir.domain import RankDivider +def test_SynchronziationTracker(): + sync_tracker = SynchronziationTracker(n_synchronize=6) + batches = np.arange(15).reshape(3, 5) + expected = [np.array([]), np.array([6, 7, 8, 9]), np.array([10, 11, 12, 13, 14])] + for expected_trimmed, batch in zip(expected, batches): + sync_tracker.count(len(batch)) + np.testing.assert_array_equal( + sync_tracker.trim_synchronization_samples_if_needed(batch), expected_trimmed + ) + + @pytest.mark.parametrize( "arr, axis, expected", [ diff --git a/external/fv3fit/tests/training/test_reservoir.py b/external/fv3fit/tests/training/test_reservoir.py index 23a6c4427e..43b68f5283 100644 --- a/external/fv3fit/tests/training/test_reservoir.py +++ b/external/fv3fit/tests/training/test_reservoir.py @@ -55,7 +55,7 @@ def test_train_reservoir(): subdomain=subdomain_config, reservoir_hyperparameters=reservoir_config, readout_hyperparameters=reg_config, - n_batches_burn=2, + n_timesteps_synchronize=5, input_noise=0.01, ) model = train_reservoir_model(hyperparameters, train_tfdataset, val_tfdataset) diff --git a/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py b/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py index 5896e23fdd..b5327a7aec 100644 --- a/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py +++ b/workflows/diagnostics/fv3net/diagnostics/reservoir/compute.py @@ -59,7 +59,7 @@ def _load_batches(path, variables, nfiles): def _get_variables_to_load(model: ReservoirModel): - variables = list(model.input_variables) + variables = list(set(model.input_variables).union(model.output_variables)) if isinstance(model, HybridReservoirComputingModel): return variables + list(model.hybrid_variables) else: @@ -104,29 +104,27 @@ def _get_states_without_overlap( states_without_overlap_time_series = [] for var_time_series in states_with_overlap_time_series: # dims in array var_time_series are (t, x, y, z) - states_without_overlap_time_series.append( - var_time_series[:, overlap:-overlap, overlap:-overlap, :] - ) + if overlap > 0: + var_time_series = var_time_series[:, overlap:-overlap, overlap:-overlap, :] + states_without_overlap_time_series.append(var_time_series) # dims (t, var, x, y, z) return np.stack(states_without_overlap_time_series, axis=1) -def main(args): - model: ReservoirModel = fv3fit.load(args.reservoir_model_path) - with fsspec.open(args.validation_config_path, "r") as f: - val_data_config = yaml.safe_load(f) - val_batches = _load_batches( - path=val_data_config["url"], - variables=_get_variables_to_load(model), - nfiles=val_data_config.get("nfiles", None), - ) + +def generate_time_series( + model, + val_batches, + 
+    n_synchronize,
+):
+    # Initialize hidden state
     model.reset_state()
 
     one_step_prediction_time_series = []
     target_time_series = []
     for batch_data in val_batches:
-        states_with_overlap_time_series = get_ordered_X(
+        input_states_with_overlap_time_series = get_ordered_X(
             batch_data, model.input_variables
         )
 
@@ -134,23 +132,38 @@ def main(args):
         if isinstance(model, HybridReservoirComputingModel):
             hybrid_inputs_time_series = get_ordered_X(
                 batch_data, model.hybrid_variables
             )
+
             hybrid_inputs_time_series = _get_states_without_overlap(
                 hybrid_inputs_time_series, overlap=model.rank_divider.overlap
             )
         else:
             hybrid_inputs_time_series = None
         batch_predictions = _get_predictions_over_batch(
-            model, states_with_overlap_time_series, hybrid_inputs_time_series
+            model, input_states_with_overlap_time_series, hybrid_inputs_time_series
         )
 
         one_step_prediction_time_series += batch_predictions
-        target_time_series.append(
-            _get_states_without_overlap(
-                states_with_overlap_time_series, overlap=model.rank_divider.overlap
-            )
-        )
+
+        if set(model.input_variables) == set(model.output_variables):
+            target_time_series.append(
+                _get_states_without_overlap(
+                    input_states_with_overlap_time_series,
+                    overlap=model.rank_divider.overlap,
+                )
+            )
+        else:
+            output_states_with_overlap_time_series = get_ordered_X(
+                batch_data, model.output_variables
+            )
+            target_time_series.append(
+                _get_states_without_overlap(
+                    output_states_with_overlap_time_series,
+                    overlap=model.rank_divider.overlap,
+                )
+            )
+
     target_time_series = np.concatenate(target_time_series, axis=0)[
-        args.n_synchronize :
+        n_synchronize :
     ]
 
     persistence = target_time_series[:-1]
@@ -158,8 +171,25 @@ def main(args):
     # _get_predictions_over_batch predicts up to n_timesteps-1
     one_step_predictions = np.array(one_step_prediction_time_series)[
-        args.n_synchronize : -1
+        n_synchronize : -1
     ]
+    return one_step_predictions, persistence, target
+
+
+def main(args):
+    model: ReservoirModel = fv3fit.load(args.reservoir_model_path)
+    with fsspec.open(args.validation_config_path, "r") as f:
+        val_data_config = yaml.safe_load(f)
+    val_batches = _load_batches(
+        path=val_data_config["url"],
+        variables=_get_variables_to_load(model),
+        nfiles=val_data_config.get("nfiles", None),
+    )
+    one_step_predictions, persistence, target = generate_time_series(
+        model=model,
+        val_batches=val_batches,
+        n_synchronize=args.n_synchronize,
+    )
     time_means_to_calculate = {
         "time_mean_prediction": one_step_predictions,
         "time_mean_error": one_step_predictions - target,
@@ -169,7 +199,7 @@ def main(args):
     }
     diags_ = []
     for key, data in time_means_to_calculate.items():
-        diags_.append(_time_mean_dataset(model.input_variables, data, key))
+        diags_.append(_time_mean_dataset(model.output_variables, data, key))
 
     ds = xr.merge(diags_)
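Editor's note (illustration, not part of the patch): the trimming arithmetic in
SynchronizationTracker can be sketched standalone. The helper name trim_sync and
the loop below are hypothetical; they simply mirror
trim_synchronization_samples_if_needed and reproduce the expectations in
test_SynchronizationTracker (n_synchronize=6, three batches of 5 timesteps).

    import numpy as np

    def trim_sync(arr, n_steps_synchronized, n_synchronize):
        # Keep only the samples that fall past the synchronization window.
        if n_steps_synchronized > n_synchronize:
            steps_past_sync = n_steps_synchronized - n_synchronize
            if steps_past_sync > len(arr):
                return arr  # the whole batch lies past the window
            return arr[-steps_past_sync:]
        return np.array([])  # still synchronizing: drop the whole batch

    counted = 0
    for batch in np.arange(15).reshape(3, 5):  # three batches of 5 timesteps
        counted += len(batch)
        print(trim_sync(batch, counted, 6))
    # prints: [], [6 7 8 9], [10 11 12 13 14]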