diff --git a/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py b/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py index 3e7a5477f9..d4c6b9e3e0 100644 --- a/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py +++ b/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py @@ -7,6 +7,7 @@ SECONDS_PER_DAY = 86400 TOLERANCE = 1.0e-12 +ML_STEPPER_NAMES = ["machine_learning", "reservoir_predictor"] logger = logging.getLogger(__name__) @@ -92,7 +93,16 @@ def _column_pq2(ds: xr.Dataset) -> xr.DataArray: def _column_dq1(ds: xr.Dataset) -> xr.DataArray: - if "net_heating_due_to_machine_learning" in ds: + + ml_col_heating_names = { + f"column_heating_due_to_{stepper}" for stepper in ML_STEPPER_NAMES + } + if len(ml_col_heating_names.intersection(set(ds.variables))) > 0: + column_dq1 = xr.zeros_like(ds.PRATEsfc) + for var in ml_col_heating_names: + if var in ds: + column_dq1 = column_dq1 + ds[var] + elif "net_heating_due_to_machine_learning" in ds: warnings.warn( "'net_heating_due_to_machine_learning' is a deprecated variable name. " "It will not be supported in future versions of fv3net. Use " @@ -110,8 +120,6 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray: ) # fix isochoric vs isobaric transition issue column_dq1 = 716.95 / 1004 * ds.net_heating - elif "column_heating_due_to_machine_learning" in ds: - column_dq1 = ds.column_heating_due_to_machine_learning elif "storage_of_internal_energy_path_due_to_machine_learning" in ds: column_dq1 = ds.storage_of_internal_energy_path_due_to_machine_learning else: @@ -125,8 +133,15 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray: def _column_dq2(ds: xr.Dataset) -> xr.DataArray: - if "net_moistening_due_to_machine_learning" in ds: - column_dq2 = SECONDS_PER_DAY * ds.net_moistening_due_to_machine_learning + + ml_col_moistening_names = { + f"net_moistening_due_to_{stepper}" for stepper in ML_STEPPER_NAMES + } + if len(ml_col_moistening_names.intersection(set(ds.variables))) > 0: + column_dq2 = xr.zeros_like(ds.PRATEsfc) + for var in ml_col_moistening_names: + if var in ds: + column_dq2 = column_dq2 + ds[var] elif "storage_of_specific_humidity_path_due_to_machine_learning" in ds: column_dq2 = ( SECONDS_PER_DAY diff --git a/workflows/prognostic_c48_run/runtime/diagnostics/compute.py b/workflows/prognostic_c48_run/runtime/diagnostics/compute.py index fd22e6d08c..358ac1ad30 100644 --- a/workflows/prognostic_c48_run/runtime/diagnostics/compute.py +++ b/workflows/prognostic_c48_run/runtime/diagnostics/compute.py @@ -24,8 +24,9 @@ def enforce_heating_and_moistening_tendency_constraints( timestep: float, hydrostatic: bool, mse_conserving: bool, - temperature_tendency_name="dQ1", - humidity_tendency_name="dQ2", + temperature_tendency_name: str = "dQ1", + humidity_tendency_name: str = "dQ2", + zero_fill_missing_tendencies: bool = False, ): temperature_tendency_initial = tendency.get( temperature_tendency_name, xr.zeros_like(state[SPHUM]) @@ -70,19 +71,34 @@ def enforce_heating_and_moistening_tendency_constraints( delp, "z", ) + elif zero_fill_missing_tendencies is True: + # Still need to output zeros if no tendency is predicted so that reservoir + # diagnostics are available on the updated timesteps + heating = xr.zeros_like(state[SPHUM]).isel(z=0).squeeze() + + try: heating = heating.assign_attrs( long_name="Change in ML column heating due to non-negative specific " - "humidity limiter" + "humidity limiter", + units="W/m**2", ) diagnostics_updates[ "column_integrated_dQ1_change_non_neg_sphum_constraint" ] = heating tendency_updates[temperature_tendency_name] = temperature_tendency_updated + except NameError: + logger.info( + "No heating tendency found, skipping heating due to limiter diagnostics" + ) - if "dQ2" in tendency: + if humidity_tendency_name in tendency: moistening = vcm.mass_integrate( humidity_tendency_updated - tendency[humidity_tendency_name], delp, dim="z", ) + + elif zero_fill_missing_tendencies is True: + moistening = xr.zeros_like(state[SPHUM]).isel(z=0).squeeze() + try: moistening = moistening.assign_attrs( units="kg/m^2/s", long_name="Change in ML column moistening due to non-negative specific " @@ -91,13 +107,16 @@ def enforce_heating_and_moistening_tendency_constraints( diagnostics_updates[ "column_integrated_dQ2_change_non_neg_sphum_constraint" ] = moistening - tendency_updates[humidity_tendency_name] = humidity_tendency_updated + except NameError: + logger.info( + "No moistening tendency found, skipping moistening due to " + "limiter diagnostics" + ) diagnostics_updates["specific_humidity_limiter_active"] = xr.where( humidity_tendency_initial != humidity_tendency_updated, 1, 0 ) - return tendency_updates, diagnostics_updates @@ -203,7 +222,7 @@ def compute_diagnostics( f"{TENDENCY_TO_STATE_NAME[k]}_tendency_due_to_nudging": v for k, v in tendency.items() } - elif label in {"machine_learning", "reservoir"}: + elif label in {"machine_learning", "reservoir_predictor"}: diags_3d = { "dQ1": temperature_tendency.assign_attrs(units="K/s").assign_attrs( description=f"air temperature tendency due to {label}" @@ -245,12 +264,12 @@ def compute_ml_momentum_diagnostics(state: State, tendency: State) -> Diagnostic ) -def rename_diagnostics(diags: Diagnostics): +def rename_diagnostics(diags: Diagnostics, label: str = "machine_learning"): """Postfix ML output names with _diagnostic and create zero-valued outputs in their stead. Function operates in place.""" ml_tendencies = { - "net_moistening_due_to_machine_learning", - "net_heating_due_to_machine_learning", + f"net_moistening_due_to_{label}", + f"net_heating_due_to_{label}", "column_integrated_dQu", "column_integrated_dQv", "override_for_time_adjusted_total_sky_downward_shortwave_flux_at_surface", diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 6eb3171b9f..fe046b8082 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -338,8 +338,12 @@ def _get_reservoir_stepper( ) -> Tuple[Optional[Stepper], Optional[Stepper]]: if config.reservoir_corrector is not None: res_config = config.reservoir_corrector + self._log_info("Getting reservoir steppers") incrementer, predictor = get_reservoir_steppers( - res_config, MPI.COMM_WORLD.Get_rank(), init_time=init_time + res_config, + MPI.COMM_WORLD.Get_rank(), + init_time=init_time, + model_timestep=self._timestep, ) else: incrementer, predictor = None, None @@ -574,12 +578,43 @@ def _increment_reservoir(self) -> Diagnostics: return {} def _apply_reservoir_update_to_state(self) -> Diagnostics: - # TODO: handle tendencies + # TODO: handle tendencies. Currently the returned tendencies + # are only used for diagnostics and are not used in updating state if self._reservoir_predict_stepper is not None: - [_, diags, state] = self._reservoir_predict_stepper( - self._state.time, self._state + [ + tendencies_from_state_prediction, + diags, + state_updates, + ] = self._reservoir_predict_stepper(self._state.time, self._state) + ( + stepper_diags, + net_moistening, + ) = self._reservoir_predict_stepper.get_diagnostics( + self._state, tendencies_from_state_prediction + ) + diags.update(stepper_diags) + if self._reservoir_predict_stepper.diagnostic is True: # type: ignore + rename_diagnostics(diags, label="reservoir_predictor") + + state_updates[TOTAL_PRECIP] = precipitation_sum( + self._state[TOTAL_PRECIP], net_moistening, self._timestep, ) - self._state.update_mass_conserving(state) + + self._state.update_mass_conserving(state_updates) + + diags.update({name: self._state[name] for name in self._states_to_output}) + diags.update( + { + "area": self._state[AREA], + "cnvprcp_after_python": self._fv3gfs.get_diagnostic_by_name( + "cnvprcp" + ).data_array, + TOTAL_PRECIP_RATE: precipitation_rate( + self._state[TOTAL_PRECIP], self._timestep + ), + } + ) + return diags else: return {} diff --git a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py index cae62868c3..a41abb27ff 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py +++ b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py @@ -17,7 +17,12 @@ import fv3fit from fv3fit._shared.halos import append_halos_using_mpi from fv3fit.reservoir.adapters import ReservoirDatasetAdapter -from runtime.names import SST +from runtime.names import SST, SPHUM, TEMP +from runtime.tendency import add_tendency, tendencies_from_state_updates +from runtime.diagnostics import ( + enforce_heating_and_moistening_tendency_constraints, + compute_diagnostics, +) from .prescriber import sst_update_from_reference from .machine_learning import rename_dataset_members, NameDict @@ -43,6 +48,10 @@ class ReservoirConfig: warm_start: Whether to use the saved state from a pre-synced reservoir rename_mapping: mapping from field names used in the underlying reservoir model to names used in fv3gfs wrapper + hydrostatic (optional): whether simulation is hydrostatic. + For net heating diagnostic. Defaults to false. + mse_conserving_limiter (optional): whether to use MSE-conserving humidity + limiter. Defaults to false. """ models: Mapping[int, str] @@ -52,6 +61,8 @@ class ReservoirConfig: diagnostic_only: bool = False warm_start: bool = False rename_mapping: NameDict = dataclasses.field(default_factory=dict) + hydrostatic: bool = False + mse_conserving_limiter: bool = False class _FiniteStateMachine: @@ -150,20 +161,26 @@ def __init__( model: ReservoirDatasetAdapter, init_time: cftime.DatetimeJulian, reservoir_timestep: timedelta, + model_timestep: float, synchronize_steps: int, state_machine: Optional[_FiniteStateMachine] = None, diagnostic_only: bool = False, input_averager: Optional[TimeAverageInputs] = None, rename_mapping: Optional[NameDict] = None, warm_start: bool = False, + hydrostatic: bool = False, + mse_conserving_limiter: bool = False, ): self.model = model self.synchronize_steps = synchronize_steps self.initial_time = init_time self.timestep = reservoir_timestep + self.model_timestep = model_timestep self.diagnostic = diagnostic_only self.input_averager = input_averager self.warm_start = warm_start + self.hydrostatic = hydrostatic + self.mse_conserving_limiter = mse_conserving_limiter if state_machine is None: state_machine = _FiniteStateMachine() @@ -264,7 +281,6 @@ def __call__(self, time, state): logger.info(f"Incrementing rc at time {time}") self.increment_reservoir(inputs) - diags = rename_dataset_members( inputs, {k: f"{self.rename_mapping.get(k, k)}_rc_in" for k in inputs} ) @@ -290,6 +306,7 @@ class ReservoirPredictStepper(_ReservoirStepper): """ label = "reservoir_predictor" + DIAGS_OUTPUT_SUFFIX = "rc_out" def predict(self, inputs, state): """Called at the end of timeloop after time has ticked from t -> t+1""" @@ -299,7 +316,7 @@ def predict(self, inputs, state): output_state = rename_dataset_members(result, self.rename_mapping) diags = rename_dataset_members( - output_state, {k: f"{k}_rc_out" for k in output_state} + output_state, {k: f"{k}_{self.DIAGS_OUTPUT_SUFFIX}" for k in output_state} ) for k, v in output_state.items(): @@ -347,15 +364,59 @@ def __call__(self, time, state): logger.info(f"Reservoir model predict at time {time}") if self.input_averager is not None: inputs.update(self.input_averager.get_averages()) - tendencies, diags, state = self.predict(inputs, state) + + tendencies, diags, updated_state = self.predict(inputs, state) + hybrid_diags = rename_dataset_members( inputs, {k: f"{self.rename_mapping.get(k, k)}_hyb_in" for k in inputs} ) diags.update(hybrid_diags) + + # This check is done on the _rc_out diags since those are always available. + # This allows zero field diags to be returned on timesteps where the + # reservoir is not updating the state. + diags_Tq_vars = {f"{v}_{self.DIAGS_OUTPUT_SUFFIX}" for v in [TEMP, SPHUM]} + + if diags_Tq_vars.issubset(list(diags.keys())): + # TODO: Currently the reservoir only predicts updated states and returns + # empty tendencies. If tendency predictions are implemented in the + # prognostic run, the limiter/conservation updates should be updated to + # take this option into account and use predicted tendencies directly. + tendencies_from_state_prediction = tendencies_from_state_updates( + initial_state=state, + updated_state=updated_state, + dt=self.model_timestep, + ) + ( + tendency_updates_from_constraints, + diagnostics_updates_from_constraints, + ) = enforce_heating_and_moistening_tendency_constraints( + state=state, + tendency=tendencies_from_state_prediction, + timestep=self.model_timestep, + mse_conserving=self.mse_conserving_limiter, + hydrostatic=self.hydrostatic, + temperature_tendency_name="dQ1", + humidity_tendency_name="dQ2", + zero_fill_missing_tendencies=True, + ) + + diags.update(diagnostics_updates_from_constraints) + updated_state = add_tendency( + state=state, + tendencies=tendency_updates_from_constraints, + dt=self.model_timestep, + ) + tendencies.update(tendency_updates_from_constraints) + else: - tendencies, diags, state = {}, {}, {} + tendencies, diags, updated_state = {}, {}, {} + + return tendencies, diags, updated_state - return tendencies, diags, state + def get_diagnostics(self, state, tendency): + diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic) + return diags, diags[f"net_moistening_due_to_{self.label}"] def open_rc_model(path: str) -> ReservoirDatasetAdapter: @@ -379,7 +440,10 @@ def _get_time_averagers(model, do_time_average): def get_reservoir_steppers( - config: ReservoirConfig, rank: int, init_time: cftime.DatetimeJulian, + config: ReservoirConfig, + rank: int, + init_time: cftime.DatetimeJulian, + model_timestep: float, ): """ Gets both steppers needed by the time loop to increment the state using @@ -402,22 +466,26 @@ def get_reservoir_steppers( incrementer = ReservoirIncrementOnlyStepper( model, init_time, - rc_tdelta, - config.synchronize_steps, + reservoir_timestep=rc_tdelta, + synchronize_steps=config.synchronize_steps, state_machine=state_machine, input_averager=increment_averager, rename_mapping=config.rename_mapping, warm_start=config.warm_start, + model_timestep=model_timestep, ) predictor = ReservoirPredictStepper( model, init_time, - rc_tdelta, - config.synchronize_steps, + reservoir_timestep=rc_tdelta, + synchronize_steps=config.synchronize_steps, state_machine=state_machine, diagnostic_only=config.diagnostic_only, input_averager=predict_averager, rename_mapping=config.rename_mapping, warm_start=config.warm_start, + model_timestep=model_timestep, + hydrostatic=config.hydrostatic, + mse_conserving_limiter=config.mse_conserving_limiter, ) return incrementer, predictor diff --git a/workflows/prognostic_c48_run/runtime/tendency.py b/workflows/prognostic_c48_run/runtime/tendency.py index 2bf80a8511..90b32ea4ec 100644 --- a/workflows/prognostic_c48_run/runtime/tendency.py +++ b/workflows/prognostic_c48_run/runtime/tendency.py @@ -11,11 +11,34 @@ NORTHWARD_WIND_TENDENCY, X_WIND_TENDENCY, Y_WIND_TENDENCY, + STATE_NAME_TO_TENDENCY, ) -from runtime.types import State +from runtime.types import State, Tendencies from toolz import dissoc +def tendencies_from_state_updates( + initial_state: State, updated_state: State, dt: float +) -> Tendencies: + """Compute tendencies given intial and updated states + + Args: + initial_state: initial state + updated_state: updated state + variables: variables to compute tendencies for + + Returns: + tendencies: tendencies computed from state updates + """ + tendencies = {} + for variable in updated_state: + tendency_var = STATE_NAME_TO_TENDENCY[variable] + tendencies[tendency_var] = ( + updated_state[variable] - initial_state[variable] + ) / dt + return tendencies + + def state_updates_from_tendency(tendency_updates): # Prescriber can overwrite the state updates predicted by ML tendencies # Sometimes this is desired and we want to save both the overwritten updated state diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out index 936f3213d5..65af8316d4 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out @@ -435,6 +435,7 @@ prephysics: null radiation_scheme: null reservoir_corrector: diagnostic_only: false + hydrostatic: false models: 0: gs://vcm-ml-scratch/rc-model-tile-0 1: gs://vcm-ml-scratch/rc-model-tile-1 @@ -442,6 +443,7 @@ reservoir_corrector: 3: gs://vcm-ml-scratch/rc-model-tile-3 4: gs://vcm-ml-scratch/rc-model-tile-4 5: gs://vcm-ml-scratch/rc-model-tile-5 + mse_conserving_limiter: false rename_mapping: {} reservoir_timestep: 900s synchronize_steps: 12 diff --git a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py index 1dc7666825..900e50f02b 100644 --- a/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py +++ b/workflows/prognostic_c48_run/tests/test_reservoir_stepper.py @@ -16,6 +16,8 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock +MODEL_TIMESTEP = 900 + def test_reservoir_stepper_state(): fsm = _FiniteStateMachine() @@ -129,6 +131,7 @@ def get_mock_ReservoirSteppers(): model, datetime(2020, 1, 1, 0, 0, 0), timedelta(minutes=10), + MODEL_TIMESTEP, 2, state_machine=state_machine, ) @@ -137,6 +140,7 @@ def get_mock_ReservoirSteppers(): model, datetime(2020, 1, 1, 0, 0, 0), timedelta(minutes=10), + MODEL_TIMESTEP, 2, state_machine=state_machine, ) @@ -212,7 +216,9 @@ def test_get_reservoir_steppers(patched_reservoir_module): config = ReservoirConfig({0: "model"}, 0, reservoir_timestep="10m") time = datetime(2020, 1, 1, 0, 0, 0) - incrementer, predictor = reservoir.get_reservoir_steppers(config, 0, time) + incrementer, predictor = reservoir.get_reservoir_steppers( + config, 0, time, MODEL_TIMESTEP + ) # Check that both steppers share model and state machine objects assert incrementer.model is predictor.model @@ -230,7 +236,9 @@ def test_reservoir_steppers_state_machine_constraint(patched_reservoir_module): config = ReservoirConfig({0: "model"}, 0, reservoir_timestep="10m") time = datetime(2020, 1, 1, 0, 0, 0) - incrementer, predictor = reservoir.get_reservoir_steppers(config, 0, time) + incrementer, predictor = reservoir.get_reservoir_steppers( + config, 0, time, MODEL_TIMESTEP + ) # check that steppers respect state machine limit state = MockState(a=xr.DataArray(np.ones(1), dims=["x"])) @@ -247,7 +255,9 @@ def test_reservoir_steppers_with_interval_averaging(patched_reservoir_module): {0: "model"}, 0, reservoir_timestep="30m", time_average_inputs=True ) init_time = datetime(2020, 1, 1, 0, 0, 0) - incrementer, predictor = reservoir.get_reservoir_steppers(config, 0, init_time) + incrementer, predictor = reservoir.get_reservoir_steppers( + config, 0, init_time, MODEL_TIMESTEP + ) state = MockState(a=xr.DataArray(np.ones(1), dims=["x"])) incrementer(init_time, state) @@ -262,7 +272,9 @@ def test_reservoir_steppers_diagnostic_only(patched_reservoir_module): {0: "model"}, 0, reservoir_timestep="10m", diagnostic_only=True ) init_time = datetime(2020, 1, 1, 0, 0, 0) - incrementer, predictor = reservoir.get_reservoir_steppers(config, 0, init_time) + incrementer, predictor = reservoir.get_reservoir_steppers( + config, 0, init_time, MODEL_TIMESTEP + ) state = MockState(a=xr.DataArray(np.ones(1), dims=["x"])) incrementer(init_time, state) @@ -276,7 +288,9 @@ def test_reservoir_steppers_renaming(patched_reservoir_module): {0: "model"}, 0, reservoir_timestep="10m", rename_mapping={"a": "b"} ) init_time = datetime(2020, 1, 1, 0, 0, 0) - incrementer, predictor = reservoir.get_reservoir_steppers(config, 0, init_time) + incrementer, predictor = reservoir.get_reservoir_steppers( + config, 0, init_time, MODEL_TIMESTEP + ) res_input = MockState(b=xr.DataArray(np.ones(3), dims=["x"])) # different dimension to test diagnostics dims renaming @@ -290,4 +304,6 @@ def test_reservoir_steppers_renaming(patched_reservoir_module): def test_model_paths_and_rank_index_mismatch_on_load(): config = ReservoirConfig({1: "model"}, 0, reservoir_timestep="10m") with pytest.raises(KeyError): - reservoir.get_reservoir_steppers(config, 1, datetime(2020, 1, 1)) + reservoir.get_reservoir_steppers( + config, 1, datetime(2020, 1, 1), MODEL_TIMESTEP + )