diff --git a/examples/standalone/runfile/dynamics.py b/examples/standalone/runfile/dynamics.py index 88326001..103161d1 100755 --- a/examples/standalone/runfile/dynamics.py +++ b/examples/standalone/runfile/dynamics.py @@ -263,6 +263,7 @@ def setup_dycore( config=dycore_config, phis=state.phis, state=state, + exclude_tracers=[], timestep=timedelta(seconds=dycore_config.dt_atmos), ) return dycore, state, stencil_factory diff --git a/pyFV3/dycore_state.py b/pyFV3/dycore_state.py index 9058e4d5..742fd55c 100644 --- a/pyFV3/dycore_state.py +++ b/pyFV3/dycore_state.py @@ -1,10 +1,10 @@ from dataclasses import asdict, dataclass, field, fields -from typing import Any, Dict, Mapping, Union +from typing import Any, Dict, List, Mapping, Union import xarray as xr import ndsl.dsl.gt4py_utils as gt_utils -from ndsl import GridSizer, Quantity, QuantityFactory +from ndsl import Quantity, QuantityFactory from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -16,6 +16,65 @@ from ndsl.dsl.typing import Float from ndsl.restart._legacy_restart import open_restart from ndsl.typing import Communicator +from pyFV3.tracers import Tracers + + +DEFAULT_TRACER_PROPERTIES = { + "specific_humidity": { + "pyFV3_key": "vapor", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "sphum", + "units": "g/kg", + }, + "cloud_liquid_water_mixing_ratio": { + "pyFV3_key": "liquid", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "liq_wat", + "units": "g/kg", + }, + "cloud_ice_mixing_ratio": { + "pyFV3_key": "ice", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "ice_wat", + "units": "g/kg", + }, + "rain_mixing_ratio": { + "pyFV3_key": "rain", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "rainwat", + "units": "g/kg", + }, + "snow_mixing_ratio": { + "pyFV3_key": "snow", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "snowwat", + "units": "g/kg", + }, + "graupel_mixing_ratio": { + "pyFV3_key": "graupel", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "graupel", + "units": "g/kg", + }, + "ozone_mixing_ratio": { + "pyFV3_key": "o3mr", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "o3mr", + "units": "g/kg", + }, + "turbulent_kinetic_energy": { + "pyFV3_key": "sgs_tke", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "sgs_tke", + "units": "g/kg", + }, + "cloud_fraction": { + "pyFV3_key": "cloud", + "dims": [Z_DIM, Y_DIM, X_DIM], + "restart_name": "cld_amt", + "units": "g/kg", + }, +} @dataclass() @@ -148,74 +207,10 @@ class DycoreState: "intent": "inout", } ) - qvapor: Quantity = field( - metadata={ - "name": "specific_humidity", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - } - ) - qliquid: Quantity = field( - metadata={ - "name": "cloud_water_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qice: Quantity = field( + tracers: Tracers = field( metadata={ - "name": "cloud_ice_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qrain: Quantity = field( - metadata={ - "name": "rain_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qsnow: Quantity = field( - metadata={ - "name": "snow_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qgraupel: Quantity = field( - metadata={ - "name": "graupel_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qo3mr: Quantity = field( - metadata={ - "name": "ozone_mixing_ratio", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "kg/kg", - "intent": "inout", - } - ) - qsgs_tke: Quantity = field( - metadata={ - "name": "turbulent_kinetic_energy", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "m**2/s**2", - "intent": "inout", - } - ) - qcld: Quantity = field( - metadata={ - "name": "cloud_fraction", - "dims": [X_DIM, Y_DIM, Z_DIM], - "units": "", + "name": "tracers", + "units": "g/kg", "intent": "inout", } ) @@ -297,6 +292,8 @@ class DycoreState: def __post_init__(self): for _field in fields(self): + if _field.name == "tracers": + continue for check_name in ["units", "dims"]: if check_name in _field.metadata: required = _field.metadata[check_name] @@ -310,7 +307,7 @@ def __post_init__(self): ) @classmethod - def init_zeros(cls, quantity_factory: QuantityFactory): + def init_zeros(cls, quantity_factory: QuantityFactory, tracer_list: List[str]): initial_storages = {} for _field in fields(cls): if "dims" in _field.metadata.keys(): @@ -319,13 +316,25 @@ def init_zeros(cls, quantity_factory: QuantityFactory): _field.metadata["units"], dtype=Float, ).data + for name in tracer_list: + initial_storages[name] = quantity_factory.zeros( + Tracers.dims, + Tracers.unit, + dtype=Float, + ).data return cls.init_from_storages( - storages=initial_storages, sizer=quantity_factory.sizer + storages=initial_storages, + quantity_factory=quantity_factory, + tracer_list=tracer_list, ) @classmethod def init_from_numpy_arrays( - cls, dict_of_numpy_arrays, sizer: GridSizer, backend: str + cls, + dict_of_numpy_arrays, + quantity_factory: QuantityFactory, + backend: str, + tracer_list: List[str], ): field_names = [_field.name for _field in fields(cls)] for variable_name in dict_of_numpy_arrays.keys(): @@ -341,10 +350,22 @@ def init_from_numpy_arrays( dict_of_numpy_arrays[_field.name], dims, _field.metadata["units"], - origin=sizer.get_origin(dims), - extent=sizer.get_extent(dims), + origin=quantity_factory.sizer.get_origin(dims), + extent=quantity_factory.sizer.get_extent(dims), gt4py_backend=backend, ) + elif issubclass(_field.type, Tracers): + if len(dict_of_numpy_arrays[_field.name]) != len(tracer_list): + raise ValueError( + "[pyFV3] DycoreState init:" + f" tracer list size ({len(tracer_list)})" + " doesn't match the inputs size" + f" ({len(dict_of_numpy_arrays[_field.name])})" + ) + dict_state[_field.name] = Tracers.make( + quantity_factory=quantity_factory, + tracer_mapping=tracer_list, + ) state = cls(**dict_state) # type: ignore return state @@ -352,7 +373,8 @@ def init_from_numpy_arrays( def init_from_storages( cls, storages: Mapping[str, Any], - sizer: GridSizer, + quantity_factory: QuantityFactory, + tracer_list: List[str], bdt: float = 0.0, mdt: float = 0.0, ): @@ -364,10 +386,19 @@ def init_from_storages( storages[_field.name], dims, _field.metadata["units"], - origin=sizer.get_origin(dims), - extent=sizer.get_extent(dims), + origin=quantity_factory.sizer.get_origin(dims), + extent=quantity_factory.sizer.get_extent(dims), ) inputs[_field.name] = quantity + elif "tracers" == _field.name: + tracers = Tracers.make( + quantity_factory=quantity_factory, + tracer_mapping=tracer_list, + ) + for name in tracer_list: + tracers[name].data[:] = storages[name][:] + inputs[_field.name] = tracers + return cls(**inputs, bdt=bdt, mdt=mdt) @classmethod @@ -381,10 +412,14 @@ def from_fortran_restart( state_dict: Mapping[str, Quantity] = open_restart( dirname=path, communicator=communicator, - tracer_properties=TRACER_PROPERTIES, + tracer_properties=DEFAULT_TRACER_PROPERTIES, + ) + new = cls.init_zeros( + quantity_factory=quantity_factory, + tracer_list=[ + str(prop["pyFV3_key"]) for prop in DEFAULT_TRACER_PROPERTIES.values() + ], ) - - new = cls.init_zeros(quantity_factory=quantity_factory) new.pt.view[:] = new.pt.np.asarray( state_dict["air_temperature"].transpose(new.pt.dims).view[:] ) @@ -405,31 +440,35 @@ def from_fortran_restart( new.v.view[:] = new.v.np.asarray( state_dict["y_wind"].transpose(new.v.dims).view[:] ) - new.qvapor.view[:] = new.qvapor.np.asarray( - state_dict["specific_humidity"].transpose(new.qvapor.dims).view[:] + new.tracers["vapor"].view[:] = new.tracers["vapor"].np.asarray( + state_dict["specific_humidity"].transpose(new.tracers["vapor"].dims).view[:] ) - new.qliquid.view[:] = new.qliquid.np.asarray( + new.tracers["liquid"].view[:] = new.tracers["liquid"].np.asarray( state_dict["cloud_liquid_water_mixing_ratio"] - .transpose(new.qliquid.dims) + .transpose(new.tracers["liquid"].dims) .view[:] ) - new.qice.view[:] = new.qice.np.asarray( - state_dict["cloud_ice_mixing_ratio"].transpose(new.qice.dims).view[:] + new.tracers["ice"].view[:] = new.tracers["ice"].np.asarray( + state_dict["cloud_ice_mixing_ratio"] + .transpose(new.tracers["ice"].dims) + .view[:] ) - new.qrain.view[:] = new.qrain.np.asarray( - state_dict["rain_mixing_ratio"].transpose(new.qrain.dims).view[:] + new.tracers["rain"].view[:] = new.tracers["rain"].np.asarray( + state_dict["rain_mixing_ratio"].transpose(new.tracers["rain"].dims).view[:] ) - new.qsnow.view[:] = new.qsnow.np.asarray( - state_dict["snow_mixing_ratio"].transpose(new.qsnow.dims).view[:] + new.tracers["snow"].view[:] = new.tracers["snow"].np.asarray( + state_dict["snow_mixing_ratio"].transpose(new.tracers["snow"].dims).view[:] ) - new.qgraupel.view[:] = new.qgraupel.np.asarray( - state_dict["graupel_mixing_ratio"].transpose(new.qgraupel.dims).view[:] + new.tracers["graupel"].view[:] = new.tracers["graupel"].np.asarray( + state_dict["graupel_mixing_ratio"] + .transpose(new.tracers["graupel"].dims) + .view[:] ) - new.qo3mr.view[:] = new.qo3mr.np.asarray( - state_dict["ozone_mixing_ratio"].transpose(new.qo3mr.dims).view[:] + new.tracers["o3mr"].view[:] = new.tracers["o3mr"].np.asarray( + state_dict["ozone_mixing_ratio"].transpose(new.tracers["o3mr"].dims).view[:] ) - new.qcld.view[:] = new.qcld.np.asarray( - state_dict["cloud_fraction"].transpose(new.qcld.dims).view[:] + new.tracers["cloud"].view[:] = new.tracers["cld"].np.asarray( + state_dict["cloud_fraction"].transpose(new.tracers["cld"].dims).view[:] ) new.delz.view[:] = new.delz.np.asarray( state_dict["vertical_thickness_of_atmospheric_layer"] @@ -439,22 +478,34 @@ def from_fortran_restart( return new + def _xr_dataarray_from_quantity(self, name: str, metadata: Dict[str, Any], data): + dims = [f"{dim_name}_{name}" for dim_name in metadata["dims"]] + return xr.DataArray( + gt_utils.asarray(data), + dims=dims, + attrs={ + "long_name": metadata["name"], + "units": metadata.get("units", "unknown"), + }, + ) + @property def xr_dataset(self): data_vars = {} for name, field_info in self.__dataclass_fields__.items(): if issubclass(field_info.type, Quantity): - dims = [ - f"{dim_name}_{name}" for dim_name in field_info.metadata["dims"] - ] - data_vars[name] = xr.DataArray( - gt_utils.asarray(getattr(self, name).data), - dims=dims, - attrs={ - "long_name": field_info.metadata["name"], - "units": field_info.metadata.get("units", "unknown"), - }, + data_vars[name] = self._xr_dataarray_from_quantity( + name=name, + metadata=field_info.metadata, + data=getattr(self, name).data, ) + if isinstance(field_info.type, Tracers): + for tracer in getattr(self, name).values(): + data_vars[name] = self._xr_dataarray_from_quantity( + name=name, + metadata=field_info.metadata, + data=tracer, + ) return xr.Dataset(data_vars=data_vars) def __getitem__(self, item): @@ -465,52 +516,3 @@ def as_dict(self, quantity_only=True) -> Dict[str, Union[Quantity, int]]: return {k: v for k, v in asdict(self).items() if isinstance(v, Quantity)} else: return {k: v for k, v in asdict(self).items()} - - -TRACER_PROPERTIES = { - "specific_humidity": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "sphum", - "units": "g/kg", - }, - "cloud_liquid_water_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "liq_wat", - "units": "g/kg", - }, - "cloud_ice_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "ice_wat", - "units": "g/kg", - }, - "rain_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "rainwat", - "units": "g/kg", - }, - "snow_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "snowwat", - "units": "g/kg", - }, - "graupel_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "graupel", - "units": "g/kg", - }, - "ozone_mixing_ratio": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "o3mr", - "units": "g/kg", - }, - "turbulent_kinetic_energy": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "sgs_tke", - "units": "g/kg", - }, - "cloud_fraction": { - "dims": [Z_DIM, Y_DIM, X_DIM], - "restart_name": "cld_amt", - "units": "g/kg", - }, -} diff --git a/pyFV3/initialization/test_cases/initialize_baroclinic.py b/pyFV3/initialization/test_cases/initialize_baroclinic.py index 17196057..1a4a92eb 100644 --- a/pyFV3/initialization/test_cases/initialize_baroclinic.py +++ b/pyFV3/initialization/test_cases/initialize_baroclinic.py @@ -339,9 +339,11 @@ def init_baroclinic_state( ) state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, - sizer=quantity_factory.sizer, + quantity_factory=quantity_factory, backend=sample_quantity.metadata.gt4py_backend, + tracer_list=["vapor", "liquid", "rain", "snow", "ice", "graupel", "cloud"], ) + state.tracers["vapor"].view[:] = numpy_state.qvapor[slice_3d] comm.halo_update(state.phis, n_points=NHALO) diff --git a/pyFV3/initialization/test_cases/initialize_tc.py b/pyFV3/initialization/test_cases/initialize_tc.py index 38a4a46c..a402bd48 100644 --- a/pyFV3/initialization/test_cases/initialize_tc.py +++ b/pyFV3/initialization/test_cases/initialize_tc.py @@ -561,7 +561,6 @@ def init_tc_state( numpy_state.pkz[:] = pkz numpy_state.ps[:] = pe[:, :, -1] numpy_state.pt[:] = pt - numpy_state.qvapor[:] = qvapor numpy_state.u[:] = ud numpy_state.ua[:] = ua numpy_state.v[:] = vd @@ -569,8 +568,10 @@ def init_tc_state( numpy_state.w[:] = w state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, - sizer=quantity_factory.sizer, + quantity_factory=quantity_factory, backend=sample_quantity.metadata.gt4py_backend, + tracer_list=["vapor", "liquid", "rain", "snow", "ice", "graupel", "cloud"], ) + state.tracers["vapor"].view[:] = qvapor return state diff --git a/pyFV3/stencils/dyn_core.py b/pyFV3/stencils/dyn_core.py index 530d7f26..5de8a5c9 100644 --- a/pyFV3/stencils/dyn_core.py +++ b/pyFV3/stencils/dyn_core.py @@ -460,6 +460,7 @@ def __init__( quantity_factory=quantity_factory, grid_data=grid_data, grid_type=config.grid_type, + use_logp=config.use_logp, ) ) self._akap = Float(constants.KAPPA) diff --git a/pyFV3/stencils/fillz.py b/pyFV3/stencils/fillz.py index 5cd5c239..a0c0156f 100644 --- a/pyFV3/stencils/fillz.py +++ b/pyFV3/stencils/fillz.py @@ -1,15 +1,14 @@ -import typing -from typing import Dict +from typing import List, no_type_check from gt4py.cartesian.gtscript import BACKWARD, FORWARD, PARALLEL, computation, interval -import ndsl.dsl.gt4py_utils as utils -from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, IntFieldIJ +from pyFV3.tracers import Tracers -@typing.no_type_check +@no_type_check def fix_tracer( q: FloatField, dp: FloatField, @@ -117,19 +116,18 @@ def __init__( self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory, - nq: int, - tracers: Dict[str, Quantity], + exclude_tracers: List[str], ): orchestrate( obj=self, config=stencil_factory.config.dace_config, dace_compiletime_args=["tracers"], ) - self._nq = int(nq) self._fix_tracer_stencil = stencil_factory.from_dims_halo( fix_tracer, compute_dims=[X_DIM, Y_DIM, Z_DIM], ) + self._exclude_tracers = exclude_tracers # Setting initial value of upper_fix to zero is only needed for validation. # The values in the compute domain are set to zero in the stencil. @@ -145,25 +143,22 @@ def __init__( dtype=Float, ) - self._filtered_tracer_dict = { - name: tracers[name] for name in utils.tracer_variables[0 : self._nq] - } - def __call__( self, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Tracers, ): """ Args: dp2 (in): pressure thickness of atmospheric layer tracers (inout): tracers to fix negative masses in """ - for tracer_name in self._filtered_tracer_dict.keys(): - self._fix_tracer_stencil( - tracers[tracer_name], - dp2, - self._zfix, - self._sum0, - self._sum1, - ) + for name, tracer in tracers.items(): + if name not in self._exclude_tracers: + self._fix_tracer_stencil( + tracer, + dp2, + self._zfix, + self._sum0, + self._sum1, + ) diff --git a/pyFV3/stencils/fv_dynamics.py b/pyFV3/stencils/fv_dynamics.py index 50ffba27..9a71d3c7 100644 --- a/pyFV3/stencils/fv_dynamics.py +++ b/pyFV3/stencils/fv_dynamics.py @@ -1,10 +1,9 @@ from datetime import timedelta -from typing import Mapping, Optional +from typing import List, Mapping, Optional from dace.frontend.python.interface import nounroll as dace_no_unroll from gt4py.cartesian.gtscript import PARALLEL, computation, interval -import ndsl.dsl.gt4py_utils as utils import pyFV3.stencils.moist_cv as moist_cv from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater from ndsl.checkpointer import NullCheckpointer @@ -98,6 +97,7 @@ def __init__( config: DynamicalCoreConfig, phis: Quantity, state: DycoreState, + exclude_tracers: List[str], timestep: timedelta, checkpointer: Optional[Checkpointer] = None, ): @@ -111,6 +111,8 @@ def __init__( the namelist in the Fortran model phis: surface geopotential height state: model state + exclude_tracer: List of named tracer to be excluded from the Advection, + and Remapping schemes timestep: model timestep checkpointer: if given, used to perform operations on model data at specific points in model execution, such as testing against @@ -198,6 +200,28 @@ def __init__( f" nwat=={config.nwat} is not implemented." " Only nwat=6 has been implemented." ) + + # Implemented dynamics options require those tracers to be present at minima + # this is a more granular list than carried by the `nwat` single integer + # but cover the same topic + required_tracers = [ + "vapor", + "liquid", + "rain", + "snow", + "ice", + "graupel", + "cloud", + ] + if not all(n in state.tracers.names() for n in required_tracers): + raise NotImplementedError( + "Dynamical core (fv_dynamics):" + " missing required tracers. Dynamics requires:\n" + f" {required_tracers}\n" + "but only the following where given:\n" + f" {state.tracers.names()}" + ) + self.comm_rank = comm.rank self.grid_data = grid_data self.grid_indexing = grid_indexing @@ -213,10 +237,6 @@ def __init__( hord=config.hord_tr, ) - self.tracers = {} - for name in utils.tracer_variables[0:NQ]: - self.tracers[name] = state.__dict__[name] - temporaries = fvdyn_temporaries(quantity_factory) self._te_2d = temporaries["te_2d"] self._te0_2d = temporaries["te0_2d"] @@ -231,7 +251,8 @@ def __init__( tracer_transport, self.grid_data, comm, - self.tracers, + state.tracers, + exclude_tracers=exclude_tracers, ) self._ak = grid_data.ak self._bk = grid_data.bk @@ -312,9 +333,9 @@ def __init__( quantity_factory=quantity_factory, config=config.remapping, area_64=grid_data.area_64, - nq=NQ, pfull=self._pfull, - tracers=self.tracers, + tracers=state.tracers, + exclude_tracers=exclude_tracers, checkpointer=checkpointer, ) @@ -352,7 +373,7 @@ def _checkpoint_fvdynamics(self, state: DycoreState, tag: str): va=state.va, uc=state.uc, vc=state.vc, - qvapor=state.qvapor, + qvapor=state.tracers["vapor"], ) def _checkpoint_remapping_in( @@ -462,12 +483,12 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool): log_on_rank_0("FV Setup") self._fv_setup_stencil( - state.qvapor, - state.qliquid, - state.qrain, - state.qsnow, - state.qice, - state.qgraupel, + state.tracers["vapor"], + state.tracers["liquid"], + state.tracers["rain"], + state.tracers["snow"], + state.tracers["ice"], + state.tracers["graupel"], state.q_con, self._cvm, state.pkz, @@ -536,7 +557,7 @@ def _compute(self, state: DycoreState, timer: Timer): with timer.clock("TracerAdvection"): self._checkpoint_tracer_advection_in(state) self.tracer_advection( - self.tracers, + state.tracers, self._dp_initial, state.mfxd, state.mfyd, @@ -573,7 +594,7 @@ def _compute(self, state: DycoreState, timer: Timer): # time # When NQ=8, we do need qcld passed explicitely self._lagrangian_to_eulerian_obj( - self.tracers, + state.tracers, state.pt, state.delp, state.delz, @@ -583,7 +604,6 @@ def _compute(self, state: DycoreState, timer: Timer): state.w, self._cappa, state.q_con, - state.qcld, state.pkz, state.pk, state.pe, @@ -625,13 +645,13 @@ def _compute(self, state: DycoreState, timer: Timer): if __debug__: log_on_rank_0("Neg Adj 3") self._adjust_tracer_mixing_ratio( - state.qvapor, - state.qliquid, - state.qrain, - state.qsnow, - state.qice, - state.qgraupel, - state.qcld, + state.tracers["vapor"], + state.tracers["liquid"], + state.tracers["rain"], + state.tracers["snow"], + state.tracers["ice"], + state.tracers["graupel"], + state.tracers["cloud"], state.pt, state.delp, ) diff --git a/pyFV3/stencils/mapn_tracer.py b/pyFV3/stencils/mapn_tracer.py index 0696d145..2b6911a9 100644 --- a/pyFV3/stencils/mapn_tracer.py +++ b/pyFV3/stencils/mapn_tracer.py @@ -1,11 +1,11 @@ -from typing import Dict +from typing import List -import ndsl.dsl.gt4py_utils as utils -from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float, FloatField from pyFV3.stencils.fillz import FillNegativeTracerValues from pyFV3.stencils.map_single import MapSingle +from pyFV3.tracers import Tracers class MapNTracer: @@ -18,43 +18,42 @@ def __init__( stencil_factory: StencilFactory, quantity_factory: QuantityFactory, kord: int, - nq: int, fill: bool, - tracers: Dict[str, Quantity], + tracers: Tracers, + exclude_tracers: List[str], ): orchestrate( obj=self, config=stencil_factory.config.dace_config, dace_compiletime_args=["tracers"], ) - self._nq = int(nq) + self._exclude_tracers = exclude_tracers self._qs = quantity_factory.zeros( [X_DIM, Y_DIM, Z_DIM], units="unknown", dtype=Float, ) - kord_tracer = [kord] * self._nq - kord_tracer[5] = 9 # qcld - - self._list_of_remap_objects = [ - MapSingle( + self._map_single = {} + for name in tracers.names(): + if name == "cloud": + this_kord = 9 + else: + this_kord = kord + self._map_single[name] = MapSingle( stencil_factory, quantity_factory, - kord_tracer[i], + this_kord, 0, dims=[X_DIM, Y_DIM, Z_DIM], ) - for i in range(len(kord_tracer)) - ] if fill: self._fill_negative_tracers = True self._fillz = FillNegativeTracerValues( stencil_factory, quantity_factory, - self._nq, - tracers, + exclude_tracers=self._exclude_tracers, ) else: self._fill_negative_tracers = False @@ -64,7 +63,7 @@ def __call__( pe1: FloatField, pe2: FloatField, dp2: FloatField, - tracers: Dict[str, Quantity], + tracers: Tracers, ): """ Remaps the tracer species onto the Eulerian grid @@ -77,8 +76,9 @@ def __call__( dp2 (in): Difference in pressure between Eulerian levels tracers (inout): tracers to be remapped """ - for i, q in enumerate(utils.tracer_variables[0 : self._nq]): - self._list_of_remap_objects[i](tracers[q], pe1, pe2, self._qs) + for name in tracers.names(): + if name not in self._exclude_tracers: + self._map_single[name](tracers[name], pe1, pe2, self._qs) if self._fill_negative_tracers is True: self._fillz(dp2, tracers) diff --git a/pyFV3/stencils/remapping.py b/pyFV3/stencils/remapping.py index a14c38a1..b81adcc6 100644 --- a/pyFV3/stencils/remapping.py +++ b/pyFV3/stencils/remapping.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import List, Optional from gt4py.cartesian.gtscript import ( __INLINED, @@ -13,7 +13,7 @@ region, ) -from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate +from ndsl import QuantityFactory, StencilFactory, orchestrate from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -31,6 +31,7 @@ from pyFV3.stencils.mapn_tracer import MapNTracer from pyFV3.stencils.moist_cv import moist_pt_func, moist_pt_last_step from pyFV3.stencils.saturation_adjustment import SatAdjust3d +from pyFV3.tracers import Tracers # TODO: Should this be set here or in global_constants? @@ -292,9 +293,9 @@ def __init__( quantity_factory: QuantityFactory, config: RemappingConfig, area_64, - nq, pfull, - tracers: Dict[str, Quantity], + tracers: Tracers, + exclude_tracers: List[str], checkpointer: Optional[Checkpointer] = None, ): orchestrate( @@ -314,7 +315,6 @@ def __init__( raise NotImplementedError("Hydrostatic is not implemented") self._t_min = 184.0 - self._nq = nq # do_omega = hydrostatic and last_step # TODO pull into inputs self._domain_jextra = ( grid_indexing.domain[0], @@ -410,9 +410,9 @@ def __init__( stencil_factory, quantity_factory, abs(config.kord_tr), - nq, fill=config.fill, tracers=tracers, + exclude_tracers=exclude_tracers, ) self._map_single_w = MapSingle( @@ -518,7 +518,7 @@ def __init__( def __call__( self, - tracers: Dict[str, Quantity], + tracers: Tracers, pt: FloatField, delp: FloatField, delz: FloatField, @@ -528,7 +528,6 @@ def __call__( w: FloatField, cappa: FloatField, q_con: FloatField, - q_cld: FloatField, pkz: FloatField, pk: FloatField, pe: FloatField, @@ -562,7 +561,6 @@ def __call__( va (inout): A-grid y-velocity cappa (inout): Power to raise pressure to q_con (out): Total condensate mixing ratio - q_cld (out): Cloud fraction pkz (in): Layer mean pressure raised to the power of Kappa pk (out): Interface pressure raised to power of kappa, final acoustic value pe (in): Pressure at layer edges @@ -593,12 +591,12 @@ def __call__( # pe2 is final Eulerian edge pressures self._moist_cv_pt_pressure( - tracers["qvapor"], - tracers["qliquid"], - tracers["qrain"], - tracers["qsnow"], - tracers["qice"], - tracers["qgraupel"], + tracers["vapor"], + tracers["liquid"], + tracers["rain"], + tracers["snow"], + tracers["ice"], + tracers["graupel"], q_con, pt, cappa, @@ -633,12 +631,12 @@ def __call__( # it clear the outputs are not needed until then? # or, are its outputs actually used? can we delete this stencil call? self._moist_cv_pkz( - tracers["qvapor"], - tracers["qliquid"], - tracers["qrain"], - tracers["qsnow"], - tracers["qice"], - tracers["qgraupel"], + tracers["vapor"], + tracers["liquid"], + tracers["rain"], + tracers["snow"], + tracers["ice"], + tracers["graupel"], q_con, self._gz, self._cvm, @@ -683,13 +681,13 @@ def __call__( fast_mp_consv = consv_te > CONSV_MIN self._saturation_adjustment( dp1, - tracers["qvapor"], - tracers["qliquid"], - tracers["qice"], - tracers["qrain"], - tracers["qsnow"], - tracers["qgraupel"], - q_cld, + tracers["vapor"], + tracers["liquid"], + tracers["ice"], + tracers["rain"], + tracers["snow"], + tracers["graupel"], + tracers["cloud"], hs, peln, delp, @@ -711,12 +709,12 @@ def __call__( # to the physics, but if we're staying in dynamics we need # to keep it as the virtual potential temperature self._moist_cv_last_step_stencil( - tracers["qvapor"], - tracers["qliquid"], - tracers["qrain"], - tracers["qsnow"], - tracers["qice"], - tracers["qgraupel"], + tracers["vapor"], + tracers["liquid"], + tracers["rain"], + tracers["snow"], + tracers["ice"], + tracers["graupel"], self._gz, pt, pkz, diff --git a/pyFV3/stencils/tracer_2d_1l.py b/pyFV3/stencils/tracer_2d_1l.py index 27514a37..d2fba074 100644 --- a/pyFV3/stencils/tracer_2d_1l.py +++ b/pyFV3/stencils/tracer_2d_1l.py @@ -1,16 +1,10 @@ import math -from typing import Dict +from typing import List import gt4py.cartesian.gtscript as gtscript from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region -from ndsl import ( - Quantity, - QuantityFactory, - StencilFactory, - WrappedHaloUpdater, - orchestrate, -) +from ndsl import QuantityFactory, StencilFactory, WrappedHaloUpdater, orchestrate from ndsl.constants import ( N_HALO_DEFAULT, X_DIM, @@ -22,6 +16,7 @@ from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.typing import Communicator from pyFV3.stencils.fvtp2d import FiniteVolumeTransport +from pyFV3.tracers import Tracers @gtscript.function @@ -192,7 +187,8 @@ def __init__( transport: FiniteVolumeTransport, grid_data, comm: Communicator, - tracers: Dict[str, Quantity], + tracers: Tracers, + exclude_tracers: List[str], ): orchestrate( obj=self, @@ -201,8 +197,9 @@ def __init__( ) grid_indexing = stencil_factory.grid_indexing self.grid_indexing = grid_indexing # needed for selective validation - self._tracer_count = len(tracers) + self._tracer_count = tracers.count self.grid_data = grid_data + self._exclude_tracers = exclude_tracers self._x_area_flux = quantity_factory.zeros( [X_INTERFACE_DIM, Y_DIM, Z_DIM], @@ -282,15 +279,23 @@ def __init__( n_halo=N_HALO_DEFAULT, dtype=Float, ) + + # We can exclude tracers from advecting and therefore also + # halo exchanging + advected_tracers = {} + for name, tracer in tracers.items(): + if name in exclude_tracers: + continue + advected_tracers[name] = tracer self._tracers_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([tracer_halo_spec] * self._tracer_count), - tracers, - [t for t in tracers.keys()], + comm.get_scalar_halo_updater([tracer_halo_spec] * len(advected_tracers)), + advected_tracers, + [t for t in advected_tracers.keys()], ) def __call__( self, - tracers: Dict[str, Quantity], + tracers: Tracers, dp1, x_mass_flux, y_mass_flux, @@ -392,26 +397,29 @@ def __call__( self.grid_data.rarea, dp2, ) - for q in tracers.values(): - self.finite_volume_transport( - q, - x_courant, - y_courant, - self._x_area_flux, - self._y_area_flux, - self._x_flux, - self._y_flux, - x_mass_flux=x_mass_flux, - y_mass_flux=y_mass_flux, - ) - self._apply_tracer_flux( - q, - dp1, - self._x_flux, - self._y_flux, - self.grid_data.rarea, - dp2, - ) + for name, q in tracers.items(): + if name in self._exclude_tracers: + pass + else: + self.finite_volume_transport( + q, + x_courant, + y_courant, + self._x_area_flux, + self._y_area_flux, + self._x_flux, + self._y_flux, + x_mass_flux=x_mass_flux, + y_mass_flux=y_mass_flux, + ) + self._apply_tracer_flux( + q, + dp1, + self._x_flux, + self._y_flux, + self.grid_data.rarea, + dp2, + ) if not last_call: self._tracers_halo_updater.update() # we can't use variable assignment to avoid a data copy diff --git a/pyFV3/testing/translate_dyncore.py b/pyFV3/testing/translate_dyncore.py index 8524d090..111d6844 100644 --- a/pyFV3/testing/translate_dyncore.py +++ b/pyFV3/testing/translate_dyncore.py @@ -140,7 +140,10 @@ def compute_parallel(self, inputs, communicator): grid_data.bk = inputs["bk"] grid_data.ptop = inputs["ptop"] self._base.make_storage_data_input_vars(inputs) - state = DycoreState.init_zeros(quantity_factory=self.grid.quantity_factory) + state = DycoreState.init_zeros( + quantity_factory=self.grid.quantity_factory, + tracer_list=[], # No tracers used in acoustics + ) wsd: Quantity = self.grid.quantity_factory.zeros( dims=[X_DIM, Y_DIM], units="unknown", diff --git a/pyFV3/testing/translate_fvdynamics.py b/pyFV3/testing/translate_fvdynamics.py index 5daea9d0..a4178640 100644 --- a/pyFV3/testing/translate_fvdynamics.py +++ b/pyFV3/testing/translate_fvdynamics.py @@ -5,7 +5,7 @@ import pytest import ndsl.dsl.gt4py_utils as utils -from ndsl import Namelist, Quantity, StencilFactory +from ndsl import Namelist, Quantity, QuantityFactory, StencilFactory from ndsl.constants import ( X_DIM, X_INTERFACE_DIM, @@ -33,6 +33,31 @@ def __init__( self.namelist = DynamicalCoreConfig.from_namelist(namelist) +TRACERS_IN_PYFV3 = [ + "vapor", + "liquid", + "ice", + "rain", + "snow", + "graupel", + "o3mr", + "sgs_tke", + "cloud", +] + +TRACERS_IN_FORTRAN = [ + "qvapor", + "qliquid", + "qice", + "qrain", + "qsnow", + "qgraupel", + "qo3mr", + "qsgs_tke", + "qcld", +] + + class TranslateFVDynamics(ParallelTranslateBaseSlicing): compute_grid_option = True inputs: Dict[str, Any] = { @@ -290,25 +315,50 @@ def __init__( self.max_error = 1e-5 self.ignore_near_zero_errors = {} - for qvar in utils.tracer_variables: - self.ignore_near_zero_errors[qvar] = True + self.ignore_near_zero_errors["qvapor"] = True + self.ignore_near_zero_errors["qliquid"] = True + self.ignore_near_zero_errors["qice"] = True + self.ignore_near_zero_errors["qrain"] = True + self.ignore_near_zero_errors["qsnow"] = True + self.ignore_near_zero_errors["qgraupel"] = True + self.ignore_near_zero_errors["qo3mr"] = True + self.ignore_near_zero_errors["qsgs_tke"] = True + self.ignore_near_zero_errors["qcld"] = True self.ignore_near_zero_errors["q_con"] = True self.dycore: Optional[fv_dynamics.DynamicalCore] = None self.stencil_factory = stencil_factory + self._quantity_factory = QuantityFactory.from_backend( + sizer=stencil_factory.grid_indexing._sizer, + backend=stencil_factory.backend, + ) self.namelist: DynamicalCoreConfig = DynamicalCoreConfig.from_namelist(namelist) def state_from_inputs(self, inputs): input_storages = super().state_from_inputs(inputs) + # extract tracers + input_storages["vapor"] = input_storages.pop("qvapor") + input_storages["liquid"] = input_storages.pop("qliquid") + input_storages["ice"] = input_storages.pop("qice") + input_storages["rain"] = input_storages.pop("qrain") + input_storages["snow"] = input_storages.pop("qsnow") + input_storages["graupel"] = input_storages.pop("qgraupel") + input_storages["o3mr"] = input_storages.pop("qo3mr") + input_storages["sgs_tke"] = input_storages.pop("qsgs_tke") + input_storages["cloud"] = input_storages.pop("qcld") # making sure we init DycoreState with the exact set of variables accepted_keys = [_field.name for _field in fields(DycoreState)] + accepted_keys += TRACERS_IN_PYFV3 todelete = [] for name in input_storages.keys(): if name not in accepted_keys: todelete.append(name) for name in todelete: del input_storages[name] - - state = DycoreState.init_from_storages(input_storages, sizer=self.grid.sizer) + state = DycoreState.init_from_storages( + storages=input_storages, + quantity_factory=self._quantity_factory, + tracer_list=TRACERS_IN_PYFV3, + ) return state def prepare_data(self, inputs) -> Tuple[DycoreState, GridData]: @@ -340,6 +390,7 @@ def compute_parallel(self, inputs, communicator): config=DynamicalCoreConfig.from_namelist(self.namelist), phis=state.phis, state=state, + exclude_tracers=["cloud"], timestep=timedelta(seconds=inputs["bdt"]), ) self.dycore.step_dynamics(state, NullTimer()) @@ -352,7 +403,10 @@ def outputs_from_state(self, state: dict): outputs = {} storages = {} for name, properties in self.outputs.items(): - if isinstance(state[name], Quantity): + if name in TRACERS_IN_FORTRAN: + idx = TRACERS_IN_FORTRAN.index(name) + storages[name] = state["tracers"][TRACERS_IN_PYFV3[idx]].data + elif isinstance(state[name], Quantity): storages[name] = state[name].data elif len(self.outputs[name]["dims"]) > 0: storages[name] = state[name] # assume it's a storage diff --git a/pyFV3/tracers.py b/pyFV3/tracers.py new file mode 100644 index 00000000..0407c5e6 --- /dev/null +++ b/pyFV3/tracers.py @@ -0,0 +1,61 @@ +from typing import TypeAlias +from ndsl import QuantityFactory +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.quantity.field_bundle import FieldBundle, FieldBundleType +from pyFV3.version import IS_GEOS + +# Defauult maopping for common models +_default_mapping_GEOS = { + "vapor": 0, + "liquid": 1, + "ice": 2, + "rain": 3, + "snow": 4, + "graupel": 5, + "cloud": 6, +} +_default_mapping_PACE = { + "vapor": 0, + "liquid": 1, + "rain": 2, + "ice": 3, + "snow": 4, + "graupel": 5, + "om3r": 6, + "cloud": 7, +} + + +TracersType: TypeAlias = FieldBundleType.T("Tracers") # type: ignore + + +def setup_tracers( + number_of_tracers: int, + quantity_factory: QuantityFactory, + mappings: dict[str, int] | None = None, +) -> FieldBundle: + """Setup a FieldBundle for tracers. Should be called only once.""" + + FieldBundleType.register("Tracers", (number_of_tracers,)) + + _unit = "g/kg" + _dims = [X_DIM, Y_DIM, Z_DIM, "tracers"] + + tracers_qty_factory = FieldBundle.extend_3D_quantity_factory( + quantity_factory, {"tracers": number_of_tracers} + ) + data = tracers_qty_factory.zeros(_dims, units=_unit) + + # Some default mappings for ease of use with commonly + # run models + if mappings is None: + if IS_GEOS: + mappings = _default_mapping_GEOS + else: + mappings = _default_mapping_PACE + + return FieldBundle( + "Tracers", + quantity=data, + mapping=mappings, + ) diff --git a/pyFV3/wrappers/geos_wrapper.py b/pyFV3/wrappers/geos_wrapper.py index 98feb79c..53ad91c5 100644 --- a/pyFV3/wrappers/geos_wrapper.py +++ b/pyFV3/wrappers/geos_wrapper.py @@ -10,6 +10,7 @@ from mpi4py import MPI import pyFV3 +import pyFV3.tracers from ndsl import ( CompilationConfig, CubedSphereCommunicator, @@ -36,6 +37,17 @@ from ndsl.utils import safe_assign_array +GEOS_TRACER_MAPPING = [ + "vapor", + "liquid", + "ice", + "rain", + "snow", + "graupel", + "cloud", +] + + class StencilBackendCompilerOverride: """Override the Pace global stencil JIT to allow for 9-rank build on any setup. @@ -104,8 +116,23 @@ def __init__( bdt: int, comm: Comm, backend: str, + water_tracers_count: int, + all_tracers_count: int, fortran_mem_space: MemorySpace = MemorySpace.HOST, ): + # Check for water species configuration not handled by the interface + if water_tracers_count != 6: + raise NotImplementedError( + "[pyFV3 Bridge] Bridge expect 6 water species," + f" got {water_tracers_count}." + ) + + # Build the full tracer mapping by appending None to the expected tracer list + # based on parameter + self._tracers_mapping = GEOS_TRACER_MAPPING + for i in range(all_tracers_count, len(GEOS_TRACER_MAPPING)): + self._tracers_mapping.append(f"tracer_#{i}") + # Look for an override to run on a single node gtfv3_single_rank_override = int(os.getenv("GTFV3_SINGLE_RANK_OVERRIDE", -1)) if gtfv3_single_rank_override >= 0: @@ -137,7 +164,7 @@ def __init__( metric_terms = MetricTerms( quantity_factory=quantity_factory, communicator=self.communicator, - eta_file=namelist["grid_config"]["config"]["eta_file"], + eta_file=namelist["grid_config"]["config"]["eta_file"], # type: ignore ) grid_data = GridData.new_from_metric_terms(metric_terms) @@ -173,7 +200,8 @@ def __init__( ) self.dycore_state = pyFV3.DycoreState.init_zeros( - quantity_factory=quantity_factory + quantity_factory=quantity_factory, + tracer_list=self._tracers_mapping, ) self.dycore_state.bdt = self.dycore_config.dt_atmos @@ -190,6 +218,7 @@ def __init__( timestep=timedelta(seconds=self.dycore_state.bdt), phis=self.dycore_state.phis, state=self.dycore_state, + exclude_tracers=[], ) self._fortran_mem_space = fortran_mem_space @@ -198,7 +227,6 @@ def __init__( ) self.output_dict: Dict[str, np.ndarray] = {} - self._allocate_output_dir() # Feedback information device_ordinal_info = ( @@ -368,15 +396,11 @@ def _put_fortran_data_in_dycore( safe_assign_array(state.omga.view[:], omga[isc:iec, jsc:jec, :]) safe_assign_array(state.diss_estd.view[:], diss_estd[isc:iec, jsc:jec, :]) - # tracer quantities should be a 4d array in order: - # vapor, liquid, ice, rain, snow, graupel, cloud - safe_assign_array(state.qvapor.view[:], q[isc:iec, jsc:jec, :, 0]) - safe_assign_array(state.qliquid.view[:], q[isc:iec, jsc:jec, :, 1]) - safe_assign_array(state.qice.view[:], q[isc:iec, jsc:jec, :, 2]) - safe_assign_array(state.qrain.view[:], q[isc:iec, jsc:jec, :, 3]) - safe_assign_array(state.qsnow.view[:], q[isc:iec, jsc:jec, :, 4]) - safe_assign_array(state.qgraupel.view[:], q[isc:iec, jsc:jec, :, 5]) - safe_assign_array(state.qcld.view[:], q[isc:iec, jsc:jec, :, 6]) + # Copy tracer data + for index, name in enumerate(self._tracers_mapping): + safe_assign_array( + state.tracers[name].view[:], q[isc:iec, jsc:jec, :, index] + ) return state @@ -388,6 +412,7 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]: jec = self._grid_indexing.jec + 1 if self._fortran_mem_space != self._pace_mem_space: + self._allocate_output_dir() safe_assign_array(output_dict["u"], self.dycore_state.u.data[:-1, :, :-1]) safe_assign_array(output_dict["v"], self.dycore_state.v.data[:, :-1, :-1]) safe_assign_array(output_dict["w"], self.dycore_state.w.data[:-1, :-1, :-1]) @@ -453,27 +478,8 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]: self.dycore_state.diss_estd.data[:-1, :-1, :-1], ) - safe_assign_array( - output_dict["qvapor"], self.dycore_state.qvapor.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qliquid"], self.dycore_state.qliquid.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qice"], self.dycore_state.qice.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qrain"], self.dycore_state.qrain.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qsnow"], self.dycore_state.qsnow.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qgraupel"], self.dycore_state.qgraupel.data[:-1, :-1, :-1] - ) - safe_assign_array( - output_dict["qcld"], self.dycore_state.qcld.data[:-1, :-1, :-1] - ) + # Copy tracer data + safe_assign_array(output_dict["q"], self.dycore_state.tracers.as_4D_array()) else: output_dict["u"] = self.dycore_state.u.data[:-1, :, :-1] output_dict["v"] = self.dycore_state.v.data[:, :-1, :-1] @@ -504,23 +510,18 @@ def _prep_outputs_for_geos(self) -> Dict[str, np.ndarray]: output_dict["q_con"] = self.dycore_state.q_con.data[:-1, :-1, :-1] output_dict["omga"] = self.dycore_state.omga.data[:-1, :-1, :-1] output_dict["diss_estd"] = self.dycore_state.diss_estd.data[:-1, :-1, :-1] - output_dict["qvapor"] = self.dycore_state.qvapor.data[:-1, :-1, :-1] - output_dict["qliquid"] = self.dycore_state.qliquid.data[:-1, :-1, :-1] - output_dict["qice"] = self.dycore_state.qice.data[:-1, :-1, :-1] - output_dict["qrain"] = self.dycore_state.qrain.data[:-1, :-1, :-1] - output_dict["qsnow"] = self.dycore_state.qsnow.data[:-1, :-1, :-1] - output_dict["qgraupel"] = self.dycore_state.qgraupel.data[:-1, :-1, :-1] - output_dict["qcld"] = self.dycore_state.qcld.data[:-1, :-1, :-1] + output_dict["q"] = self.dycore_state.tracers.as_4D_array() return output_dict def _allocate_output_dir(self): + if len(self.output_dict) != 0: + return if self._fortran_mem_space != self._pace_mem_space: nhalo = self._grid_indexing.n_halo shape_centered = self._grid_indexing.domain_full(add=(0, 0, 0)) shape_x_interface = self._grid_indexing.domain_full(add=(1, 0, 0)) shape_y_interface = self._grid_indexing.domain_full(add=(0, 1, 0)) - shape_z_interface = self._grid_indexing.domain_full(add=(0, 0, 1)) shape_2d = shape_centered[:-1] self.output_dict["u"] = np.empty((shape_y_interface)) @@ -573,34 +574,3 @@ def _allocate_output_dir(self): self.output_dict["qsnow"] = np.empty((shape_centered)) self.output_dict["qgraupel"] = np.empty((shape_centered)) self.output_dict["qcld"] = np.empty((shape_centered)) - else: - self.output_dict["u"] = None - self.output_dict["v"] = None - self.output_dict["w"] = None - self.output_dict["ua"] = None - self.output_dict["va"] = None - self.output_dict["uc"] = None - self.output_dict["vc"] = None - self.output_dict["delz"] = None - self.output_dict["pt"] = None - self.output_dict["delp"] = None - self.output_dict["mfxd"] = None - self.output_dict["mfyd"] = None - self.output_dict["cxd"] = None - self.output_dict["cyd"] = None - self.output_dict["ps"] = None - self.output_dict["pe"] = None - self.output_dict["pk"] = None - self.output_dict["peln"] = None - self.output_dict["pkz"] = None - self.output_dict["phis"] = None - self.output_dict["q_con"] = None - self.output_dict["omga"] = None - self.output_dict["diss_estd"] = None - self.output_dict["qvapor"] = None - self.output_dict["qliquid"] = None - self.output_dict["qice"] = None - self.output_dict["qrain"] = None - self.output_dict["qsnow"] = None - self.output_dict["qgraupel"] = None - self.output_dict["qcld"] = None diff --git a/tests/mpi/test_doubly_periodic.py b/tests/mpi/test_doubly_periodic.py index a264ea40..5b357bbf 100644 --- a/tests/mpi/test_doubly_periodic.py +++ b/tests/mpi/test_doubly_periodic.py @@ -128,6 +128,7 @@ def setup_dycore() -> Tuple[DynamicalCore, List[Any]]: config=config, phis=state.phis, state=state, + exclude_tracers=[], timestep=timedelta(seconds=255), ) # TODO compute from namelist diff --git a/tests/savepoint/translate/translate_fillz.py b/tests/savepoint/translate/translate_fillz.py index c08b5323..2e90ff8e 100644 --- a/tests/savepoint/translate/translate_fillz.py +++ b/tests/savepoint/translate/translate_fillz.py @@ -1,11 +1,13 @@ +from typing import List + import numpy as np -import ndsl.dsl.gt4py_utils as utils -from ndsl import Namelist, StencilFactory +from ndsl import Namelist, QuantityFactory, StencilFactory from ndsl.stencils.testing import pad_field_in_j from ndsl.utils import safe_assign_array -from pyFV3.stencils import fillz +from pyFV3.stencils.fillz import FillNegativeTracerValues from pyFV3.testing import TranslateDycoreFortranData2Py +from pyFV3.tracers import Tracers class TranslateFillz(TranslateDycoreFortranData2Py): @@ -33,18 +35,20 @@ def __init__( self.max_error = 1e-13 self.ignore_near_zero_errors = {"q2tracers": True} self.stencil_factory = stencil_factory + self._quantity_factory = QuantityFactory.from_backend( + sizer=stencil_factory.grid_indexing._sizer, + backend=stencil_factory.backend, + ) - def make_storage_data_input_vars(self, inputs, storage_vars=None): - if storage_vars is None: - storage_vars = self.storage_vars() + def make_storage_data_input_vars(self, inputs, tracer_mapping: List[str]): + storage_vars = self.storage_vars() info = storage_vars["dp2"] inputs["dp2"] = self.make_storage_data( np.squeeze(inputs["dp2"]), istart=info["istart"], axis=info["axis"] ) - inputs["tracers"] = {} info = storage_vars["q2tracers"] for i in range(int(inputs["nq"])): - inputs["tracers"][utils.tracer_variables[i]] = self.make_storage_data( + inputs["tracers"][tracer_mapping[i]] = self.make_storage_data( np.squeeze(inputs["q2tracers"][:, :, i]), istart=info["istart"], axis=info["axis"], @@ -52,7 +56,23 @@ def make_storage_data_input_vars(self, inputs, storage_vars=None): del inputs["q2tracers"] def compute(self, inputs): - self.make_storage_data_input_vars(inputs) + tracer_mapping = [ + "vapor", + "liquid", + "rain", + "ice", + "snow", + "graupel", + "o3mr", + "sgs_tke", + ] + tracers = Tracers.make( + quantity_factory=self._quantity_factory, + tracer_mapping=tracer_mapping, + ) + inputs["tracers"] = tracers + + self.make_storage_data_input_vars(inputs, tracer_mapping=tracer_mapping) for name, value in tuple(inputs.items()): if hasattr(value, "shape") and len(value.shape) > 1 and value.shape[1] == 1: inputs[name] = self.make_storage_data( @@ -67,18 +87,18 @@ def compute(self, inputs): value, self.grid.njd, backend=self.stencil_factory.backend ) ) - run_fillz = fillz.FillNegativeTracerValues( + inputs.pop("nq") + fillz = FillNegativeTracerValues( self.stencil_factory, self.grid.quantity_factory, - inputs.pop("nq"), - inputs["tracers"], + exclude_tracers=[], ) - run_fillz(**inputs) + fillz(**inputs) ds = self.grid.default_domain_dict() ds.update(self.out_vars["q2tracers"]) - tracers = np.zeros((self.grid.nic, self.grid.npz, len(inputs["tracers"]))) + tracers = np.zeros((self.grid.nic, self.grid.npz, inputs["tracers"].count)) for varname, data in inputs["tracers"].items(): - index = utils.tracer_variables.index(varname) + index = tracer_mapping.index(varname) data[self.grid.slice_dict(ds)] safe_assign_array( tracers[:, :, index], np.squeeze(data[self.grid.slice_dict(ds)]) diff --git a/tests/savepoint/translate/translate_init_case.py b/tests/savepoint/translate/translate_init_case.py index 9cd87f03..06d8fcc0 100644 --- a/tests/savepoint/translate/translate_init_case.py +++ b/tests/savepoint/translate/translate_init_case.py @@ -4,7 +4,6 @@ import pytest import ndsl.constants as constants -import ndsl.dsl.gt4py_utils as utils import pyFV3.initialization.analytic_init as analytic_init import pyFV3.initialization.init_utils as init_utils import pyFV3.initialization.test_cases.initialize_baroclinic as baroclinic_init @@ -20,7 +19,6 @@ ) from ndsl.grid import GridData, MetricTerms from ndsl.stencils.testing import ParallelTranslateBaseSlicing -from ndsl.stencils.testing.grid import TRACER_DIM # type: ignore from pyFV3.testing import TranslateDycoreFortranData2Py @@ -112,7 +110,7 @@ class TranslateInitCase(ParallelTranslateBaseSlicing): }, "q4d": { "name": "tracers", - "dims": [X_DIM, Y_DIM, Z_DIM, TRACER_DIM], + "dims": [X_DIM, Y_DIM, Z_DIM, "tracers"], "units": "kg/kg", }, } @@ -166,6 +164,10 @@ def __init__( self.ignore_near_zero_errors[var] = {"near_zero": 2e-13} self.namelist = namelist # type: ignore self.stencil_factory = stencil_factory + self._quantity_factory = QuantityFactory.from_backend( + sizer=stencil_factory.grid_indexing._sizer, + backend=stencil_factory.backend, + ) def compute_sequential(self, *args, **kwargs): pytest.skip( @@ -177,10 +179,8 @@ def outputs_from_state(self, state: dict): outputs = {} arrays = {} for name, properties in self.outputs.items(): - if isinstance(state[name], dict): - for tracer, quantity in state[name].items(): - state[name][tracer] = state[name][tracer].data - arrays[name] = state[name] + if name == "q4d": + arrays[name] = state["tracers"].as_4D_array() elif len(self.outputs[name]["dims"]) > 0: arrays[name] = state[name].data else: @@ -229,7 +229,6 @@ def compute_parallel(self, inputs, communicator): ) grid_data = GridData.new_from_metric_terms(metric_terms) - quantity_factory = QuantityFactory() state = analytic_init.init_analytic_state( analytic_init_case="baroclinic", @@ -241,9 +240,6 @@ def compute_parallel(self, inputs, communicator): comm=communicator, ) - state.q4d = {} - for tracer in utils.tracer_variables: - state.q4d[tracer] = getattr(state, tracer) return self.outputs_from_state(state.__dict__) diff --git a/tests/savepoint/translate/translate_remapping.py b/tests/savepoint/translate/translate_remapping.py index 43ddb27d..c644c6bb 100644 --- a/tests/savepoint/translate/translate_remapping.py +++ b/tests/savepoint/translate/translate_remapping.py @@ -1,9 +1,10 @@ import ndsl.dsl.gt4py_utils as utils -from ndsl import Namelist, StencilFactory +from ndsl import Namelist, QuantityFactory, StencilFactory from ndsl.constants import Z_DIM from pyFV3 import DynamicalCoreConfig from pyFV3.stencils import LagrangianToEulerian from pyFV3.testing import TranslateDycoreFortranData2Py +from pyFV3.tracers import Tracers class TranslateRemapping(TranslateDycoreFortranData2Py): @@ -97,6 +98,10 @@ def __init__( self.ignore_near_zero_errors = {"q_con": True, "tracers": True} self.stencil_factory = stencil_factory self.namelist = DynamicalCoreConfig.from_namelist(namelist) + self._quantity_factory = QuantityFactory.from_backend( + sizer=stencil_factory.grid_indexing._sizer, + backend=stencil_factory.backend, + ) def compute_from_storage(self, inputs): wsd_2d = utils.make_storage_from_shape( @@ -104,19 +109,35 @@ def compute_from_storage(self, inputs): ) wsd_2d[:, :] = inputs["wsd"][:, :, 0] inputs["wsd"] = wsd_2d - inputs["q_cld"] = inputs["tracers"]["qcld"] + tracers = Tracers.make_from_4D_array( + quantity_factory=self._quantity_factory, + tracer_mapping=[ + "vapor", + "liquid", + "rain", + "ice", + "snow", + "graupel", + "qo3mr", + "qsgs_tke", + "cloud", + ], + tracer_data=inputs["tracers"], + ) inputs["last_step"] = bool(inputs["last_step"]) pfull = self.grid.quantity_factory.zeros([Z_DIM], units="Pa") pfull.data[:] = pfull.np.asarray(inputs.pop("pfull")) + inputs.pop("nq") + inputs["tracers"] = tracers l_to_e_obj = LagrangianToEulerian( self.stencil_factory, quantity_factory=self.grid.quantity_factory, config=DynamicalCoreConfig.from_namelist(self.namelist).remapping, area_64=self.grid.area_64, - nq=inputs.pop("nq"), pfull=pfull, tracers=inputs["tracers"], + exclude_tracers=["cloud"], ) l_to_e_obj(**inputs) - inputs.pop("q_cld") + inputs["tracers"] = tracers.as_4D_array() return inputs diff --git a/tests/savepoint/translate/translate_tracer2d1l.py b/tests/savepoint/translate/translate_tracer2d1l.py index f3ad0f70..dcc9c104 100644 --- a/tests/savepoint/translate/translate_tracer2d1l.py +++ b/tests/savepoint/translate/translate_tracer2d1l.py @@ -1,10 +1,10 @@ import pytest -import ndsl.dsl.gt4py_utils as utils -from ndsl import Namelist, StencilFactory +from ndsl import Namelist, QuantityFactory, StencilFactory from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.stencils.testing import ParallelTranslate from pyFV3.stencils import FiniteVolumeTransport, TracerAdvection +from pyFV3.tracers import Tracers from pyFV3.utils.functional_validation import get_subset_func @@ -34,6 +34,10 @@ def __init__( self._base.in_vars["parameters"] = ["nq"] self._base.out_vars = self._base.in_vars["data_vars"] self.stencil_factory = stencil_factory + self._quantity_factory = QuantityFactory.from_backend( + sizer=stencil_factory.grid_indexing._sizer, + backend=stencil_factory.backend, + ) self.namelist = namelist self._subset = get_subset_func( self.grid.grid_indexing, @@ -46,11 +50,24 @@ def collect_input_data(self, serializer, savepoint): return input_data def compute_parallel(self, inputs, communicator): - self._base.make_storage_data_input_vars(inputs) - all_tracers = inputs["tracers"] - inputs["tracers"] = self.get_advected_tracer_dict( - inputs["tracers"], int(inputs.pop("nq")) + tracers = Tracers.make_from_4D_array( + quantity_factory=self._quantity_factory, + tracer_mapping=[ + "vapor", + "liquid", + "rain", + "ice", + "snow", + "graupel", + "o3mr", + "sgs_tke", + "cloud", + ], + tracer_data=inputs["tracers"], ) + self._base.make_storage_data_input_vars(inputs, dict_4d=False) + inputs.pop("tracers") + inputs.pop("nq") # Fortran NQ is intrinsic to Tracers (e.g Tracers.count) transport = FiniteVolumeTransport( stencil_factory=self.stencil_factory, quantity_factory=self.grid.quantity_factory, @@ -66,38 +83,26 @@ def compute_parallel(self, inputs, communicator): transport, self.grid.grid_data, communicator, - inputs["tracers"], + tracers, + exclude_tracers=["cloud"], ) inputs["x_mass_flux"] = inputs.pop("mfxd") inputs["y_mass_flux"] = inputs.pop("mfyd") inputs["x_courant"] = inputs.pop("cxd") inputs["y_courant"] = inputs.pop("cyd") - self.tracer_advection(**inputs) + self.tracer_advection(tracers=tracers, **inputs) inputs["mfxd"] = inputs.pop("x_mass_flux") inputs["mfyd"] = inputs.pop("y_mass_flux") inputs["cxd"] = inputs.pop("x_courant") inputs["cyd"] = inputs.pop("y_courant") - inputs[ - "tracers" - ] = all_tracers # some aren't advected, still need to be validated - # need to convert tracers dict to [x, y, z, n_tracer] array before subsetting + # Put back un-advected tracers + # Tracers have -1 on all cartesian because of NDSL padding + # Dev note: qcld is not advected in Pace dataset for some reason + inputs["tracers"] = tracers.as_4D_array() outputs = self._base.slice_output(inputs) outputs["tracers"] = self.subset_output("tracers", outputs["tracers"]) return outputs - def get_advected_tracer_dict(self, all_tracers, nq): - all_tracers = {**all_tracers} # make a new dict so we don't modify the input - properties = self.inputs["tracers"] - for name in utils.tracer_variables: - self.grid.quantity_dict_update( - all_tracers, - name, - dims=properties["dims"], - units=properties["units"], - ) - tracer_names = utils.tracer_variables[:nq] - return {name: all_tracers[name + "_quantity"] for name in tracer_names} - def compute_sequential(self, a, b): pytest.skip( f"{self.__class__} only has a mpirun implementation, "