diff --git a/external/fv3fit/fv3fit/reservoir/_reshaping.py b/external/fv3fit/fv3fit/reservoir/_reshaping.py
new file mode 100644
index 0000000000..d6e2e3c55c
--- /dev/null
+++ b/external/fv3fit/fv3fit/reservoir/_reshaping.py
@@ -0,0 +1,6 @@
+import numpy as np
+
+
+def flatten_2d_keeping_columns_contiguous(arr: np.ndarray):
+    # ex. [[1,2],[3,4],[5,6]] -> [1,3,5,2,4,6]
+    return np.reshape(arr, -1, "F")
diff --git a/external/fv3fit/fv3fit/reservoir/domain.py b/external/fv3fit/fv3fit/reservoir/domain.py
index a6e070587a..3ec635efcf 100644
--- a/external/fv3fit/fv3fit/reservoir/domain.py
+++ b/external/fv3fit/fv3fit/reservoir/domain.py
@@ -221,3 +221,27 @@ def assure_same_dims(variable_tensors: Iterable[tf.Tensor]) -> Iterable[tf.Tenso
             f"have either {max_dims} or {max_dims-1}."
         )
     return reshaped_tensors
+
+
+def merge_subdomains(flat_prediction, rank_divider, latent_dims):
+    subdomain_columns = flat_prediction.reshape(-1, rank_divider.n_subdomains)
+    d_ = []
+    for s in range(rank_divider.n_subdomains):
+        subdomain_prediction = rank_divider.unstack_subdomain(
+            np.array([subdomain_columns[:, s]]), with_overlap=False
+        )
+        d_.append(subdomain_prediction[0])
+
+    domain = []
+    subdomain_without_overlap_shape = (
+        rank_divider.subdomain_xy_size_without_overlap,
+        rank_divider.subdomain_xy_size_without_overlap,
+    )
+
+    for z in range(latent_dims):
+        domain_z_blocks = np.array(d_)[:, :, :, z].reshape(
+            *rank_divider.subdomain_layout, *subdomain_without_overlap_shape
+        )
+        domain_z = np.concatenate(np.concatenate(domain_z_blocks, axis=1), axis=-1)
+        domain.append(domain_z)
+    return np.stack(np.array(domain), axis=0).transpose(1, 2, 0)
diff --git a/external/fv3fit/fv3fit/reservoir/model.py b/external/fv3fit/fv3fit/reservoir/model.py
index 51405c25e3..9fd1112e03 100644
--- a/external/fv3fit/fv3fit/reservoir/model.py
+++ b/external/fv3fit/fv3fit/reservoir/model.py
@@ -1,15 +1,16 @@
 import fsspec
-from fv3fit.reservoir.readout import ReservoirComputingReadout
 import os
 from typing import Optional, Iterable, Hashable
 import yaml
 
 from fv3fit import Predictor
+from .readout import ReservoirComputingReadout
 from .reservoir import Reservoir
 from .domain import RankDivider
 from fv3fit._shared import io
 from .utils import square_even_terms
 from .autoencoder import Autoencoder
+from ._reshaping import flatten_2d_keeping_columns_contiguous
 
 
 @io.register("pure-reservoir")
@@ -60,10 +61,8 @@ def predict(self):
         readout_input = self.reservoir.state
         # For prediction over multiple subdomains (>1 column in reservoir state
         # array), flatten state into 1D vector before predicting
-        readout_input = readout_input.reshape(-1)
-
+        readout_input = flatten_2d_keeping_columns_contiguous(readout_input)
         prediction = self.readout.predict(readout_input).reshape(-1)
-
         return prediction
 
     def reset_state(self):
diff --git a/external/fv3fit/fv3fit/reservoir/readout.py b/external/fv3fit/fv3fit/reservoir/readout.py
index 35ac9cf2b0..fd265a32de 100644
--- a/external/fv3fit/fv3fit/reservoir/readout.py
+++ b/external/fv3fit/fv3fit/reservoir/readout.py
@@ -63,6 +63,10 @@ def get_weights(self):
         coefficients, intercepts = W[:-1, :], W[-1, :]
         return coefficients, intercepts
 
+    def predict(self, X):
+        coefficients, intercepts = self.get_weights()
+        return np.dot(X, coefficients) + intercepts
+
 
 class ReservoirComputingReadout:
     """Readout layer of the reservoir computing model
@@ -113,7 +117,6 @@ def combine_readouts(readouts: Sequence[ReservoirComputingReadout]):
 
     # Concatenate the intercepts of individual readouts into single array
     combined_intercepts = np.concatenate(intercepts)
-
     return ReservoirComputingReadout(
         coefficients=combined_coefficients, intercepts=combined_intercepts,
     )
diff --git a/external/fv3fit/fv3fit/reservoir/train.py b/external/fv3fit/fv3fit/reservoir/train.py
index f0c8e8ca1a..66263d8eb6 100644
--- a/external/fv3fit/fv3fit/reservoir/train.py
+++ b/external/fv3fit/fv3fit/reservoir/train.py
@@ -15,7 +15,14 @@
     ReservoirComputingReadout,
 )
 from .readout import combine_readouts
-from .domain import RankDivider, stack_time_series_samples, assure_same_dims
+from .domain import (
+    RankDivider,
+    stack_time_series_samples,
+    assure_same_dims,
+    merge_subdomains,
+)
+import wandb
+
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -47,6 +54,64 @@ def _get_ordered_X(X_mapping, variables):
     return assure_same_dims(ordered_tensors)
+
+
+def _decode_columns(data, decoder):
+    # differs from encode_columns as the decoder can predict multiple outputs
+    # rather than a single latent vector
+    # expands a sequence of N x M x L dim arrays (one per variable)
+    # into one or more N x M x Vi dim arrays, where Vi is the number of
+    # features (usually vertical levels) of variable i and L << Vi is a
+    # smaller number of latent dimensions
+    reshaped = [_stack_array_preserving_last_dim(var) for var in data]
+    decoded_reshaped = decoder.predict(reshaped)
+    original_2d_shape = data[0].shape[:-1]
+    decoded_data = []
+    for i, var_data in enumerate(decoded_reshaped):
+        decoded_data.append(decoded_reshaped[i].reshape(*original_2d_shape, -1))
+    return decoded_data
+
+
+def validation_single_timestep(validation_batches, model, n_batches_burn):
+    for b, batch_data in enumerate(validation_batches):
+        if b < n_batches_burn:
+            logger.info(f"Synchronizing on batch {b+1}")
+            time_series_with_overlap, time_series_without_overlap = _process_batch_data(
+                variables=model.input_variables,
+                batch_data=batch_data,
+                rank_divider=model.rank_divider,
+                autoencoder=model.autoencoder,
+            )
+        else:
+            X = _get_ordered_X(batch_data, model.input_variables)
+            truth = []
+            overlap = model.rank_divider.overlap
+            for var_data in X:
+                last_timestep_in_batch = var_data[0]
+                truth.append(
+                    last_timestep_in_batch[overlap:-overlap, overlap:-overlap, :]
+                )
+
+            flat_prediction = model.predict()
+            subdomain_predictions_latent_space = merge_subdomains(
+                flat_prediction, model.rank_divider, model.autoencoder.n_latent_dims
+            )
+            prediction = _decode_columns(
+                [subdomain_predictions_latent_space], model.autoencoder.decoder
+            )
+            truth = np.array(truth)
+            prediction = np.array(prediction)
+            val_log = {
+                "truth": truth,
+                "prediction": prediction,
+            }
+            wandb.log(
+                {
+                    "validation_single_timestep": val_log,
+                    "val_loss": ((truth - prediction) ** 2).mean(),
+                }
+            )
+            return
+
+
 @register_training_function("pure-reservoir", ReservoirTrainingConfig)
 def train_reservoir_model(
     hyperparameters: ReservoirTrainingConfig,
@@ -136,7 +201,15 @@ def train_reservoir_model(
         readout=readout,
         square_half_hidden_state=hyperparameters.square_half_hidden_state,
         rank_divider=rank_divider,
+        autoencoder=autoencoder,
     )
+
+    if validation_batches is not None and wandb.run is not None:
+        logger.info("Single timestep validation")
+        validation_single_timestep(
+            validation_batches, model, hyperparameters.n_batches_burn
+        )
+
     return model
diff --git a/external/fv3fit/tests/reservoir/test__reshaping.py b/external/fv3fit/tests/reservoir/test__reshaping.py
new file mode 100644
index 0000000000..22f8d4cdf9
--- /dev/null
+++ b/external/fv3fit/tests/reservoir/test__reshaping.py
@@ -0,0 +1,9 @@
+import numpy as np
+from fv3fit.reservoir._reshaping import flatten_2d_keeping_columns_contiguous
+
+
+def test_flatten_2d_keeping_columns_contiguous():
+    x = np.array([[1, 2], [3, 4], [5, 6]])
+    np.testing.assert_array_equal(
+        flatten_2d_keeping_columns_contiguous(x), np.array([1, 3, 5, 2, 4, 6])
+    )
diff --git a/projects/reservoir/.envrc b/projects/reservoir/.envrc
new file mode 100644
index 0000000000..60cf334f99
--- /dev/null
+++ b/projects/reservoir/.envrc
@@ -0,0 +1 @@
+export WANDB_PROJECT='reservoir-training'
\ No newline at end of file
diff --git a/projects/reservoir/fv3/save_ranks.py b/projects/reservoir/fv3/save_ranks.py
index 4b2720c602..f92ee5f089 100644
--- a/projects/reservoir/fv3/save_ranks.py
+++ b/projects/reservoir/fv3/save_ranks.py
@@ -5,7 +5,7 @@
 from tempfile import NamedTemporaryFile
 
 import toolz
-from .cubed_sphere import CubedSphereDivider
+from cubed_sphere import CubedSphereDivider
 import vcm
 
 logging.basicConfig()
@@ -64,6 +64,7 @@ def _get_parser() -> argparse.ArgumentParser:
         help=("Number of timesteps to save per rank netcdf."),
     )
     parser.add_argument("--variables", type=str, nargs="+", default=[])
+    parser.add_argument("--ranks", type=int, nargs="+", default=None)
     return parser
 
 
@@ -118,11 +119,13 @@ def get_ordered_dims_extent(dims: dict):
     else:
         time_chunks = [list(data_times)]
 
+    ranks = args.ranks or range(cubedsphere_divider.total_ranks)
     for t, time_chunk in enumerate(time_chunks):
         data_time_slice = data.sel(time=time_chunk).load()
-        for r in range(cubedsphere_divider.total_ranks):
+        for r in ranks:
             output_dir = os.path.join(args.output_path, f"rank_{r}")
-            rank_output_path = os.path.join(output_dir, f"{t}.nc")
+            file_str = f"0{t}" if t < 10 else f"{t}"
+            rank_output_path = os.path.join(output_dir, f"{file_str}.nc")
             rank_data = cubedsphere_divider.get_rank_data(
                 data_time_slice, rank=r, overlap=args.overlap
             )
diff --git a/projects/reservoir/fv3/test_save_ranks.sh b/projects/reservoir/fv3/test_save_ranks.sh
new file mode 100755
index 0000000000..c904addd56
--- /dev/null
+++ b/projects/reservoir/fv3/test_save_ranks.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+
+python -m save_ranks \
+    gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
+    gs://vcm-ml-scratch/annak/2023-02-27/rank_data/ \
+    2 \
+    2 \
+    --stop-time 20180815.000000 \
+    --variables air_temperature specific_humidity \
+    --time-chunks 40
+
+
+python -m save_ranks \
+    gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
+    gs://vcm-ml-experiments/reservoir-computing-offline/data/n2f-25km/train/start_20190215_end_20190615 \
+    2 \
+    2 \
+    --start-time 20190215.000000 \
+    --stop-time 20190615.000000 \
+    --variables air_temperature specific_humidity \
+    --time-chunks 40 \
+    --ranks 0 1
+
+
+python -m save_ranks \
+    gs://vcm-ml-experiments/spencerc/2022-01-19/n2f-25km-unperturbed-snoalb/fv3gfs_run/state_after_timestep.zarr \
+    gs://vcm-ml-experiments/reservoir-computing-offline/data/n2f-25km/val/start_20190615_end_2019_0715 \
+    2 \
+    2 \
+    --start-time 20190615.000000 \
+    --stop-time 20190715.000000 \
+    --variables air_temperature specific_humidity \
+    --time-chunks 40 \
+    --ranks 0 1
diff --git a/projects/reservoir/fv3/train.sh b/projects/reservoir/fv3/train.sh
new file mode 100755
index 0000000000..c89c88d4cf
--- /dev/null
+++ b/projects/reservoir/fv3/train.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+python -m fv3fit.train \
+    /home/AnnaK/fv3net/projects/reservoir/fv3/train_config.yaml \
+    /home/AnnaK/fv3net/projects/reservoir/fv3/train_data.yaml \
+    gs://vcm-ml-scratch/annak/2023-04-19/persistence_rc_no_encoder_T \
+    --validation-data-config /home/AnnaK/fv3net/projects/reservoir/fv3/train_data.yaml \
+    --no-wandb
\ No newline at end of file
diff --git a/projects/reservoir/fv3/train_config.yaml b/projects/reservoir/fv3/train_config.yaml
new file mode 100644
index 0000000000..67319cda00
--- /dev/null
+++ b/projects/reservoir/fv3/train_config.yaml
@@ -0,0 +1,35 @@
+
+model_type: pure-reservoir
+hyperparameters:
+  n_jobs: 4
+  # autoencoder_path: gs://vcm-ml-experiments/reservoir-computing-offline/2023-03-29/dense-autoencoder-train-full-year/trained_models/dense_autoencoder/autoencoder
+  # autoencoder_path: gs://vcm-ml-scratch/annak/2023-03-03/trained_autoencoder
+  seed: 0
+  input_variables:
+    - air_temperature
+    # - specific_humidity
+  output_variables:
+    - air_temperature
+    # - specific_humidity
+  subdomain:
+    layout:
+      - 6
+      - 6
+    overlap: 2
+    rank_dims:
+      - time
+      - x
+      - y
+      - z
+  reservoir_hyperparameters:
+    state_size: 6000
+    adjacency_matrix_sparsity: 0.9
+    spectral_radius: 0.7
+    seed: 0
+    input_coupling_sparsity: 0
+    input_coupling_scaling: 0.0001
+  readout_hyperparameters:
+    l2: 0.05
+  n_batches_burn: 1
+  input_noise: 0.01
+  square_half_hidden_state: True
diff --git a/projects/reservoir/fv3/train_data.yaml b/projects/reservoir/fv3/train_data.yaml
new file mode 100644
index 0000000000..e7c745a7c0
--- /dev/null
+++ b/projects/reservoir/fv3/train_data.yaml
@@ -0,0 +1,11 @@
+
+# url: gs://vcm-ml-scratch/annak/2023-02-27/rank_data/rank_0
+url: gs://vcm-ml-scratch/annak/2023-04-19/persistence_netcdfs  # gs://vcm-ml-scratch/annak/2023-04-19/train  # gs://vcm-ml-scratch/annak/2023-02-22/rank_data/rank_1
+dim_order:
+  - time
+  - x
+  - y
+  - z
+varying_first_dim: True
+sort_files: True
+shuffle: False
\ No newline at end of file
diff --git a/projects/reservoir/fv3/val_data.yaml b/projects/reservoir/fv3/val_data.yaml
new file mode 100644
index 0000000000..19b513cdb2
--- /dev/null
+++ b/projects/reservoir/fv3/val_data.yaml
@@ -0,0 +1,10 @@
+
+# url: gs://vcm-ml-scratch/annak/2023-02-27/rank_data/rank_0
+url: gs://vcm-ml-scratch/annak/2023-04-19/persistence_netcdfs
+dim_order:
+  - time
+  - x
+  - y
+  - z
+varying_first_dim: True
+sort_files: True
\ No newline at end of file
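
Note (not part of the patch): a minimal sketch of why the model's predict() now flattens the reservoir state with flatten_2d_keeping_columns_contiguous instead of reshape(-1). The state array shape here (hidden units x subdomains) and the values are illustrative only; the point, consistent with the new unit test, is that column-major flattening keeps each subdomain's column contiguous in the flattened vector, while the previous row-major reshape(-1) interleaved the subdomains.

import numpy as np

from fv3fit.reservoir._reshaping import flatten_2d_keeping_columns_contiguous

# hypothetical reservoir state: 3 hidden units (rows) x 2 subdomains (columns)
state = np.array([[1, 2], [3, 4], [5, 6]])

# column-major flatten: each subdomain's states stay contiguous
print(flatten_2d_keeping_columns_contiguous(state))  # [1 3 5 2 4 6]

# the previous row-major reshape(-1) interleaves the subdomains
print(state.reshape(-1))  # [1 2 3 4 5 6]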
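
Note (also outside the patch): the zero-padded chunk file names written by save_ranks.py (file_str = f"0{t}" ...) pair with sort_files: True in the data configs, assuming the loader sorts file names as strings; the quick check below uses made-up file names to show that unpadded indices would otherwise sort out of temporal order once t reaches 10.

# quick check with made-up file names; assumes sort_files sorts names lexicographically
unpadded = [f"{t}.nc" for t in range(12)]
padded = [f"0{t}.nc" if t < 10 else f"{t}.nc" for t in range(12)]

print(sorted(unpadded)[:5])  # ['0.nc', '1.nc', '10.nc', '11.nc', '2.nc'] -- out of order
print(sorted(padded)[:5])    # ['00.nc', '01.nc', '02.nc', '03.nc', '04.nc'] -- in order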