diff --git a/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py b/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py index 3e7a5477f9..d4c6b9e3e0 100644 --- a/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py +++ b/workflows/diagnostics/fv3net/diagnostics/prognostic_run/derived_variables.py @@ -7,6 +7,7 @@ SECONDS_PER_DAY = 86400 TOLERANCE = 1.0e-12 +ML_STEPPER_NAMES = ["machine_learning", "reservoir_predictor"] logger = logging.getLogger(__name__) @@ -92,7 +93,16 @@ def _column_pq2(ds: xr.Dataset) -> xr.DataArray: def _column_dq1(ds: xr.Dataset) -> xr.DataArray: - if "net_heating_due_to_machine_learning" in ds: + + ml_col_heating_names = { + f"column_heating_due_to_{stepper}" for stepper in ML_STEPPER_NAMES + } + if len(ml_col_heating_names.intersection(set(ds.variables))) > 0: + column_dq1 = xr.zeros_like(ds.PRATEsfc) + for var in ml_col_heating_names: + if var in ds: + column_dq1 = column_dq1 + ds[var] + elif "net_heating_due_to_machine_learning" in ds: warnings.warn( "'net_heating_due_to_machine_learning' is a deprecated variable name. " "It will not be supported in future versions of fv3net. Use " @@ -110,8 +120,6 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray: ) # fix isochoric vs isobaric transition issue column_dq1 = 716.95 / 1004 * ds.net_heating - elif "column_heating_due_to_machine_learning" in ds: - column_dq1 = ds.column_heating_due_to_machine_learning elif "storage_of_internal_energy_path_due_to_machine_learning" in ds: column_dq1 = ds.storage_of_internal_energy_path_due_to_machine_learning else: @@ -125,8 +133,15 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray: def _column_dq2(ds: xr.Dataset) -> xr.DataArray: - if "net_moistening_due_to_machine_learning" in ds: - column_dq2 = SECONDS_PER_DAY * ds.net_moistening_due_to_machine_learning + + ml_col_moistening_names = { + f"net_moistening_due_to_{stepper}" for stepper in ML_STEPPER_NAMES + } + if len(ml_col_moistening_names.intersection(set(ds.variables))) > 0: + column_dq2 = xr.zeros_like(ds.PRATEsfc) + for var in ml_col_moistening_names: + if var in ds: + column_dq2 = column_dq2 + ds[var] elif "storage_of_specific_humidity_path_due_to_machine_learning" in ds: column_dq2 = ( SECONDS_PER_DAY diff --git a/workflows/prognostic_c48_run/runtime/loop.py b/workflows/prognostic_c48_run/runtime/loop.py index 1c1962c298..1fa8414162 100644 --- a/workflows/prognostic_c48_run/runtime/loop.py +++ b/workflows/prognostic_c48_run/runtime/loop.py @@ -586,19 +586,32 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics: diags, state_updates, ) = self._reservoir_predict_stepper(self._state.time, self._state) + + logger.info(f"Reservoir stepper diagnostics: {list(diags.keys())}") + logger.info( + f"Reservoir stepper state updates: {list(state_updates.keys())}" + ) + + if self._reservoir_predict_stepper.is_diagnostic: # type: ignore + rename_diagnostics(diags, label="reservoir_predictor") + ( - stepper_diags, - net_moistening, + diags_from_tendencies, + _, ) = self._reservoir_predict_stepper.get_diagnostics( self._state, tendencies_from_state_prediction ) - diags.update(stepper_diags) - if self._reservoir_predict_stepper.is_diagnostic: # type: ignore - rename_diagnostics(diags, label="reservoir_predictor") + diags.update(diags_from_tendencies) - state_updates[TOTAL_PRECIP] = precipitation_sum( - self._state[TOTAL_PRECIP], net_moistening, self._timestep, + net_moistening_due_to_reservoir_adjustment = diags.get( + "net_moistening_due_to_reservoir_adjustment", + xr.zeros_like(self._state[TOTAL_PRECIP]), ) + precip = self._reservoir_predict_stepper.update_precip( # type: ignore + self._state[TOTAL_PRECIP], net_moistening_due_to_reservoir_adjustment, + ) + diags.update(precip) + state_updates[TOTAL_PRECIP] = precip[TOTAL_PRECIP] self._state.update_mass_conserving(state_updates) @@ -609,9 +622,7 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics: "cnvprcp_after_python": self._fv3gfs.get_diagnostic_by_name( "cnvprcp" ).data_array, - TOTAL_PRECIP_RATE: precipitation_rate( - self._state[TOTAL_PRECIP], self._timestep - ), + TOTAL_PRECIP_RATE: precip["total_precip_rate_res_interval_avg"], } ) diff --git a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py index 6daba607a3..ad02741d1d 100644 --- a/workflows/prognostic_c48_run/runtime/steppers/reservoir.py +++ b/workflows/prognostic_c48_run/runtime/steppers/reservoir.py @@ -17,7 +17,7 @@ import fv3fit from fv3fit._shared.halos import append_halos_using_mpi from fv3fit.reservoir.adapters import ReservoirDatasetAdapter -from runtime.names import SST, SPHUM, TEMP +from runtime.names import SST, SPHUM, TEMP, PHYSICS_PRECIP_RATE, TOTAL_PRECIP from runtime.tendency import add_tendency, tendencies_from_state_updates from runtime.diagnostics import ( enforce_heating_and_moistening_tendency_constraints, @@ -63,6 +63,7 @@ class ReservoirConfig: rename_mapping: NameDict = dataclasses.field(default_factory=dict) hydrostatic: bool = False mse_conserving_limiter: bool = False + interval_average_precipitation: bool = False class _FiniteStateMachine: @@ -104,6 +105,75 @@ def __call__(self, state: str): ) +class TendencyPrecipTracker: + def __init__(self, reservoir_timestep_seconds: float): + self.reservoir_timestep_seconds = reservoir_timestep_seconds + self.physics_precip_averager = TimeAverageInputs([PHYSICS_PRECIP_RATE]) + self._air_temperature_at_previous_interval = None + self._specific_humidity_at_previous_interval = None + + def increment_physics_precip_rate(self, physics_precip_rate): + self.physics_precip_averager.increment_running_average( + {PHYSICS_PRECIP_RATE: physics_precip_rate} + ) + + def average_physics_precip_rate(self): + return self.physics_precip_average.get_averages()[PHYSICS_PRECIP_RATE] + + def update_tracked_state(self, air_temperature, specific_humidity): + self._air_temperature_at_previous_interval = air_temperature + self._specific_humidity_at_previous_interval = specific_humidity + + def calculate_tendencies(self, air_temperature, specific_humidity): + if ( + self._specific_humidity_at_previous_interval is None + or self._air_temperature_at_previous_interval is None + ): + logger.info( + "Previous reservoir prediction of specific_humidity and " + "air_temperature not saved. Returning zero tendencies" + ) + dQ1, dQ2 = xr.zeros_like(air_temperature), xr.zeros_like(air_temperature) + else: + dQ1 = ( + air_temperature - self._air_temperature_at_previous_interval + ) / self.reservoir_timestep_seconds + dQ2 = ( + specific_humidity - self._specific_humidity_at_previous_interval + ) / self.reservoir_timestep_seconds + return {"dQ1": dQ1, "dQ2": dQ2} + + def interval_avg_precip_rates(self, net_moistening_due_to_reservoir): + physics_precip_rate = self.physics_precip_averager.get_averages()[ + PHYSICS_PRECIP_RATE + ] + total_precip_rate = physics_precip_rate - net_moistening_due_to_reservoir + total_precip_rate = total_precip_rate.where(total_precip_rate >= 0, 0) + reservoir_precip_rate = total_precip_rate - physics_precip_rate + return { + "total_precip_rate_res_interval_avg": total_precip_rate, + "physics_precip_rate_res_interval_avg": physics_precip_rate, + "reservoir_precip_rate_res_interval_avg": reservoir_precip_rate, + } + + def accumulated_precip_update( + self, + physics_precip_total_over_model_timestep, + reservoir_precip_rate_over_res_interval, + reservoir_timestep, + ): + # Since the reservoir correction is only applied every reservoir_timestep, + # all of the precip due to the reservoir is put into the accumulated precip + # in the model timestep at update time. + m_per_mm = 1 / 1000 + reservoir_total_precip = ( + reservoir_precip_rate_over_res_interval * reservoir_timestep * m_per_mm + ) + total_precip = physics_precip_total_over_model_timestep + reservoir_total_precip + total_precip.attrs["units"] = "m" + return total_precip + + class TimeAverageInputs: """ Copy of time averaging components from runtime.diagnostics.manager to @@ -170,6 +240,7 @@ def __init__( warm_start: bool = False, hydrostatic: bool = False, mse_conserving_limiter: bool = False, + tendency_precip_tracker: Optional[TendencyPrecipTracker] = None, ): self.model = model self.synchronize_steps = synchronize_steps @@ -181,6 +252,7 @@ def __init__( self.warm_start = warm_start self.hydrostatic = hydrostatic self.mse_conserving_limiter = mse_conserving_limiter + self.tendency_precip_tracker = tendency_precip_tracker if state_machine is None: state_machine = _FiniteStateMachine() @@ -313,6 +385,7 @@ def predict(self, inputs, state): self._state_machine(self._state_machine.PREDICT) result = self.model.predict(inputs) + output_state = rename_dataset_members(result, self.rename_mapping) diags = rename_dataset_members( @@ -360,12 +433,19 @@ def __call__(self, time, state): if self.input_averager is not None: self.input_averager.increment_running_average(inputs) + if self.tendency_precip_tracker is not None: + self.tendency_precip_tracker.increment_physics_precip_rate( + state[PHYSICS_PRECIP_RATE] + ) + + tendencies, diags, updated_state = {}, {}, {} + if self._is_rc_update_step(time): logger.info(f"Reservoir model predict at time {time}") if self.input_averager is not None: inputs.update(self.input_averager.get_averages()) - tendencies, diags, updated_state = self.predict(inputs, state) + _, diags, updated_state = self.predict(inputs, state) hybrid_diags = rename_dataset_members( inputs, {k: f"{self.rename_mapping.get(k, k)}_hyb_in" for k in inputs} @@ -375,14 +455,11 @@ def __call__(self, time, state): # This check is done on the _rc_out diags since those are always available. # This allows zero field diags to be returned on timesteps where the # reservoir is not updating the state. - diags_Tq_vars = {f"{v}_{self.DIAGS_OUTPUT_SUFFIX}" for v in [TEMP, SPHUM]} - - if diags_Tq_vars.issubset(list(diags.keys())): - # TODO: Currently the reservoir only predicts updated states and returns - # empty tendencies. If tendency predictions are implemented in the - # prognostic run, the limiter/conservation updates should be updated to - # take this option into account and use predicted tendencies directly. - tendencies_from_state_prediction = tendencies_from_state_updates( + # diags_Tq_vars = {f"{v}_{self.DIAGS_OUTPUT_SUFFIX}" for v in [TEMP, SPHUM]} + # if diags_Tq_vars.issubset(list(diags.keys())): + + if self.tendency_precip_tracker is not None: + tendencies_over_model_timestep = tendencies_from_state_updates( initial_state=state, updated_state=updated_state, dt=self.model_timestep, @@ -392,7 +469,7 @@ def __call__(self, time, state): diagnostics_updates_from_constraints, ) = enforce_heating_and_moistening_tendency_constraints( state=state, - tendency=tendencies_from_state_prediction, + tendency=tendencies_over_model_timestep, timestep=self.model_timestep, mse_conserving=self.mse_conserving_limiter, hydrostatic=self.hydrostatic, @@ -401,19 +478,63 @@ def __call__(self, time, state): zero_fill_missing_tendencies=True, ) + # net moistening from reservoir update is calculated using the + # difference from the last model timestep, but is interpreted + # as an update over the reservoir timestep + # Tendencies over model timesteps are popped- they are only + # used in the limiter and constraint adjustments + _, net_moistening_due_to_reservoir = self.get_diagnostics( + state, + { + "dQ1": tendency_updates_from_constraints.pop("dQ1"), + "dQ2": tendency_updates_from_constraints.pop("dQ2"), + }, + ) + net_moistening_res = net_moistening_due_to_reservoir * ( + self.model_timestep / self.timestep.total_seconds() + ) + diags.update( + {"net_moistening_due_to_reservoir_adjustment": net_moistening_res} + ) diags.update(diagnostics_updates_from_constraints) + updated_state = add_tendency( state=state, tendencies=tendency_updates_from_constraints, dt=self.model_timestep, ) - tendencies.update(tendency_updates_from_constraints) - else: - tendencies, diags, updated_state = {}, {}, {} + tendencies = self.tendency_precip_tracker.calculate_tendencies( + updated_state.get(TEMP, state[TEMP]), + updated_state.get(SPHUM, state[SPHUM]), + ) + + self.tendency_precip_tracker.update_tracked_state( + updated_state.get(TEMP, state[TEMP]), + updated_state.get(SPHUM, state[SPHUM]), + ) + diags.update(tendencies) return tendencies, diags, updated_state + def update_precip( + self, physics_precip, net_moistening_due_to_reservoir, + ): + diags = {} + + # running average gets reset in this call + precip_rates = self.tendency_precip_tracker.interval_avg_precip_rates( + net_moistening_due_to_reservoir + ) + diags.update(precip_rates) + + diags[TOTAL_PRECIP] = self.tendency_precip_tracker.accumulated_precip_update( + physics_precip, + diags["reservoir_precip_rate_res_interval_avg"], + self.timestep.total_seconds(), + ) + return diags + def get_diagnostics(self, state, tendency): diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic) return diags, diags[f"net_moistening_due_to_{self.label}"] @@ -463,6 +584,12 @@ def get_reservoir_steppers( model, config.time_average_inputs ) + _precip_tracker_kwargs = {} + if config.interval_average_precipitation: + _precip_tracker_kwargs["tendency_precip_tracker"] = TendencyPrecipTracker( + reservoir_timestep_seconds=rc_tdelta.total_seconds(), + ) + incrementer = ReservoirIncrementOnlyStepper( model, init_time, @@ -487,5 +614,6 @@ def get_reservoir_steppers( model_timestep=model_timestep, hydrostatic=config.hydrostatic, mse_conserving_limiter=config.mse_conserving_limiter, + **_precip_tracker_kwargs, ) return incrementer, predictor diff --git a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out index 65af8316d4..a705c72c36 100644 --- a/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out +++ b/workflows/prognostic_c48_run/tests/_regtest_outputs/test_prepare_config.test_prepare_ml_config_regression[reservoir].out @@ -436,6 +436,7 @@ radiation_scheme: null reservoir_corrector: diagnostic_only: false hydrostatic: false + interval_average_precipitation: false models: 0: gs://vcm-ml-scratch/rc-model-tile-0 1: gs://vcm-ml-scratch/rc-model-tile-1