diff --git a/examples/notebook/test_functionality.ipynb b/examples/notebook/test_functionality.ipynb index dc5ce3d8..e9aa077e 100644 --- a/examples/notebook/test_functionality.ipynb +++ b/examples/notebook/test_functionality.ipynb @@ -134,7 +134,7 @@ ")\n", "\n", "# useful for easily allocating distributed data storages (fields)\n", - "quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend)\n", + "quantity_factory = QuantityFactory(sizer=sizer, backend=backend)\n", "\n", "compilation_config = CompilationConfig(backend=backend, communicator=cs_communicator)\n", "\n", diff --git a/pyfv3/dycore_state.py b/pyfv3/dycore_state.py index 79c0fef2..9ea1d9ff 100644 --- a/pyfv3/dycore_state.py +++ b/pyfv3/dycore_state.py @@ -312,7 +312,7 @@ class DycoreState: rdg_var: Quantity = field( metadata={ "name": "gas constant for dry air over variable gravity (RDGAS / grav_var) for Whole Atmosphere calculations", - "units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units? + "units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units? "dims": [X_DIM, Y_DIM, Z_DIM], "intent": "inout", } @@ -368,7 +368,7 @@ def init_from_numpy_arrays( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), - gt4py_backend=backend, + backend=backend, ) state = cls(**dict_state) return state diff --git a/pyfv3/initialization/test_cases/initialize_baroclinic.py b/pyfv3/initialization/test_cases/initialize_baroclinic.py index 25784570..ef40e487 100644 --- a/pyfv3/initialization/test_cases/initialize_baroclinic.py +++ b/pyfv3/initialization/test_cases/initialize_baroclinic.py @@ -372,7 +372,7 @@ def init_baroclinic_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) comm.halo_update(state.phis, n_points=NHALO) diff --git a/pyfv3/initialization/test_cases/initialize_rossby.py b/pyfv3/initialization/test_cases/initialize_rossby.py index a5d85d34..d37683a9 100644 --- a/pyfv3/initialization/test_cases/initialize_rossby.py +++ b/pyfv3/initialization/test_cases/initialize_rossby.py @@ -203,7 +203,7 @@ def init_rossby_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) comm.halo_update(state.phis, n_points=NHALO) diff --git a/pyfv3/initialization/test_cases/initialize_tc.py b/pyfv3/initialization/test_cases/initialize_tc.py index 2c3d48ad..3b4622f6 100644 --- a/pyfv3/initialization/test_cases/initialize_tc.py +++ b/pyfv3/initialization/test_cases/initialize_tc.py @@ -570,7 +570,7 @@ def init_tc_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) return state diff --git a/pyfv3/stencils/delnflux.py b/pyfv3/stencils/delnflux.py index 4bc9653c..3acd1caf 100644 --- a/pyfv3/stencils/delnflux.py +++ b/pyfv3/stencils/delnflux.py @@ -27,7 +27,7 @@ def calc_damp(damp_c: Quantity, da_min: Float, nord: Quantity) -> Quantity: units="unknown", origin=damp_c.origin, extent=damp_c.extent, - gt4py_backend=damp_c.gt4py_backend, + backend=damp_c.backend, ) diff --git a/pyfv3/stencils/dyn_core.py b/pyfv3/stencils/dyn_core.py index fc5ad774..4d5453ca 100644 --- a/pyfv3/stencils/dyn_core.py +++ b/pyfv3/stencils/dyn_core.py @@ -10,9 +10,11 @@ import pyfv3.stencils.pe_halo as pe_halo import pyfv3.stencils.ray_fast as ray_fast import pyfv3.stencils.temperature_adjust as temperature_adjust -import pyfv3.stencils.rdg_adjust as rdg_adjust + +# import pyfv3.stencils.rdg_adjust as rdg_adjust import pyfv3.stencils.updatedzc as updatedzc import pyfv3.stencils.updatedzd as updatedzd +import pyfv3.stencils.wam as wam from ndsl import ( GridIndexing, Quantity, @@ -92,16 +94,6 @@ def zero_data( diss_estd = 0.0 -def average_gravity(grav_var: FloatField, grav_var_h: FloatField): - """ - Args: - grav_var (out): gravity field - grav_var_h (in): gravity value at height - """ - with computation(FORWARD), interval(...): - grav_var[0, 0, 0] = 0.5*(grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) - - def gz_from_surface_height_and_thicknesses( zs: FloatFieldIJ, delz: FloatField, gz: FloatField ): @@ -666,16 +658,17 @@ def __init__( grav_var_h=state.grav_var_h, ) - self._average_gravity = stencil_factory.from_origin_domain( - average_gravity, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) - self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain( - rdg_adjust.neg_rdgas_div_gravity, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) + # self._average_gravity = stencil_factory.from_origin_domain( + # wam.average_gravity_stencil_defn, + # origin=grid_indexing.origin_full(), + # domain=grid_indexing.domain_full(), + # ) + + # self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain( + # wam.neg_rdgas_div_gravity, + # origin=grid_indexing.origin_full(), + # domain=grid_indexing.domain_full(), + # ) # See divergence_damping.py, _get_da_min for explanation of this function @dace_inhibitor @@ -815,8 +808,9 @@ def __call__( if it == 0: self._halo_updaters.delp__pt.wait() self._halo_updaters.grav_var_h.update() - self._average_gravity(state.grav_var, state.grav_var_h) - self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var) + # ALREADY HAPPENING IN DYNAMICAL CORE CALL + # self._average_gravity(state.grav_var, state.grav_var_h) + # self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var) if it == n_split - 1 and end_step: if self.config.use_old_omega: @@ -892,6 +886,7 @@ def __call__( self._gz, self._pkc, state.omga, + state.grav_var, ) self._p_grad_c( @@ -979,6 +974,7 @@ def __call__( state.pk, state.peln, state.w, + state.grav_var, ) self._halo_updaters.zh.start() @@ -1058,5 +1054,5 @@ def __call__( self._heat_source, state.pt, delt_time_factor, - state.rdg_var + state.rdg_var, ) diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index 7a8c67c9..0f7db7fe 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -5,17 +5,18 @@ import ndsl.dsl.gt4py_utils as utils import pyfv3.stencils.moist_cv as moist_cv +import pyfv3.stencils.wam as wam from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater from ndsl.checkpointer import NullCheckpointer from ndsl.comm.mpi import MPI -from ndsl.constants import KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR, GRAV, RADIUS +from ndsl.constants import GRAV, KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR from ndsl.dsl.dace.orchestration import dace_inhibitor, orchestrate -from ndsl.dsl.gt4py import FORWARD, BACKWARD, PARALLEL, computation, interval -from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import Float, FloatField from ndsl.grid import DampingCoefficients, GridData from ndsl.logging import ndsl_log from ndsl.performance import NullTimer, Timer -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy_defn, set_value from ndsl.stencils.c2l_ord import CubedToLatLon from ndsl.typing import Checkpointer, Communicator from pyfv3._config import DynamicalCoreConfig @@ -25,7 +26,6 @@ from pyfv3.stencils.dyn_core import AcousticDynamics from pyfv3.stencils.neg_adj3 import AdjustNegativeTracerMixingRatio from pyfv3.stencils.remapping import LagrangianToEulerian -import pyfv3.stencils.rdg_adjust as rdg_adjust def pt_to_potential_density_pt( @@ -76,45 +76,6 @@ def fvdyn_temporaries( tmps[name] = quantity return tmps -def init_gravity(grav_var: FloatField): - """ - Args: - grav_var (out): gravity field - """ - with computation(PARALLEL), interval(...): - grav_var = GRAV - -def init_gravity_h(grav_var_h: FloatField): - """ - Args: - grav_var_h (out): gravity field - """ - with computation(PARALLEL), interval(...): - grav_var_h = GRAV - -def adjust_gravity( - grav_var: FloatField, - grav_var_h: FloatField, - phis: FloatFieldIJ, - delz: FloatField, -): - """ - Args: - grav_var (out): gravity field - grav_var_h (out): height based gravity - phis (out): - delz (out): - """ - with computation(FORWARD), interval(-1,None): - newrad = RADIUS + (phis/GRAV) - grav_var_h = GRAV*(RADIUS**2)/newrad**2 - - with computation(BACKWARD), interval(0,-1): - newrad = RADIUS + (phis/GRAV) - newrad = newrad - delz - grav_var_h = GRAV*(RADIUS**2)/newrad**2 - grav_var = 0.5*(grav_var_h[0, 0, 1] + grav_var_h[0, 0, 0]) - @dace_inhibitor def log_on_rank_0(msg: str): @@ -303,22 +264,17 @@ def __init__( domain=grid_indexing.domain_full(), ) self._init_gravity = stencil_factory.from_origin_domain( - init_gravity, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) - self._init_gravity_h = stencil_factory.from_origin_domain( - init_gravity_h, + set_value, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self._adjust_gravity = stencil_factory.from_origin_domain( - adjust_gravity, + wam.adjust_gravity, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(add=(0,0,1)), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self._adjust_rdg = stencil_factory.from_origin_domain( - rdg_adjust.neg_rdgas_div_gravity, + wam.neg_rdgas_div_gravity, origin=grid_indexing.origin_full(), domain=grid_indexing.domain_full(), ) @@ -387,7 +343,10 @@ def __init__( comm.get_scalar_halo_updater([full_xyz_spec]), state, ["omga"], comm=comm ) self._gravity_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([full_xyz_spec]), state, ["grav_var"], comm=comm + comm.get_scalar_halo_updater([full_xyz_spec]), + state, + ["grav_var"], + comm=comm, ) self._n_split = config.n_split self._k_split = config.k_split @@ -544,11 +503,14 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool): self._dp_initial, ) - self._init_gravity(state.grav_var) - self._init_gravity_h(state.grav_var_h) + self._init_gravity(state.grav_var, GRAV) + self._init_gravity(state.grav_var_h, GRAV) if self.config.enable_wam: - self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz) + self._adjust_gravity( + state.grav_var, state.grav_var_h, state.phis, state.delz + ) + self._gravity_halo_updater.update() self._adjust_rdg(state.rdg_var, state.grav_var) @@ -677,7 +639,9 @@ def _compute(self, state: DycoreState, timer: Timer): ) self._checkpoint_remapping_out(state) if self.config.enable_wam: - self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz) + self._adjust_gravity( + state.grav_var, state.grav_var_h, state.phis, state.delz + ) # TODO: can we pull this block out of the loop intead of # using an if-statement? if last_step: diff --git a/pyfv3/stencils/fv_subgridz.py b/pyfv3/stencils/fv_subgridz.py index 9090da16..6d6665b4 100644 --- a/pyfv3/stencils/fv_subgridz.py +++ b/pyfv3/stencils/fv_subgridz.py @@ -10,7 +10,6 @@ CP_VAP, CV_AIR, CV_VAP, - GRAV, RDGAS, X_DIM, Y_DIM, @@ -28,7 +27,6 @@ from gt4py.cartesian.gtscript import __INLINED # isort:skip RK = CP_AIR / RDGAS + 1.0 -G2 = 0.5 * GRAV T1_MIN = 160.0 T2_MIN = 165.0 T2_MAX = 315.0 @@ -95,6 +93,7 @@ def init( qo3mr: FloatField, qsgs_tke: FloatField, qcld: FloatField, + grav_var: FloatField, ): with computation(PARALLEL), interval(...): t0 = ta @@ -117,11 +116,11 @@ def init( cpm, cvm = standard_cm( cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel ) - gz = gzh[0, 0, 1] - G2 * delz + gz = gzh[0, 0, 1] - grav_var * delz tmp = tvol(gz, u0, v0, w0) static_energy = cpm * t0 + tmp total_energy = cvm * t0 + tmp - gzh = gzh[0, 0, 1] - GRAV * delz + gzh = gzh[0, 0, 1] - grav_var * delz @gtfunction @@ -903,6 +902,7 @@ def __call__( state.qo3mr, state.qsgs_tke, state.qcld, + state.grav_var, ) for n in range(self._m): diff --git a/pyfv3/stencils/moist_cv.py b/pyfv3/stencils/moist_cv.py index a0464734..7eef7152 100644 --- a/pyfv3/stencils/moist_cv.py +++ b/pyfv3/stencils/moist_cv.py @@ -122,6 +122,7 @@ def compute_pkz_func(delp, delz, pt, cappa, rdg_var): # TODO use the exponential form for closer answer matching return exp(cappa * log(rdg_var * delp / delz * pt)) + def moist_pkz( qvapor: FloatField, qliquid: FloatField, diff --git a/pyfv3/stencils/rdg_adjust.py b/pyfv3/stencils/rdg_adjust.py deleted file mode 100644 index d891fe1f..00000000 --- a/pyfv3/stencils/rdg_adjust.py +++ /dev/null @@ -1,17 +0,0 @@ -import ndsl.constants as constants -from ndsl.dsl.gt4py import FORWARD, computation, interval -from ndsl.dsl.typing import FloatField - - -def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): - """ - # JK TODO: Is there a better name than this? - Adjust rdg to be the negative RDGAS divided by the variable gravity - for Whole Atmosphere Modeling - - Args: - rdg (out): negative radiative gas divided by variable gravity - grav_var (in): variable gravity - """ - with computation(FORWARD), interval(...): - rdg = - constants.RDGAS / grav_var diff --git a/pyfv3/stencils/riem_solver3.py b/pyfv3/stencils/riem_solver3.py index 789904bf..38af40d5 100644 --- a/pyfv3/stencils/riem_solver3.py +++ b/pyfv3/stencils/riem_solver3.py @@ -31,6 +31,7 @@ def precompute( ptop: Float, peln1: Float, ptk: Float, + grav_var: FloatField, ): """ Args: @@ -48,6 +49,10 @@ def precompute( dz (out): p_gas (out): pressure defined at vertical mid levels due to gas-phase only, excluding condensates (Pa) + ptop (in): + peln1 (in): + ptk (in): + grav_var (in): variable gravity """ with computation(PARALLEL), interval(...): delta_mass = delp @@ -73,7 +78,7 @@ def precompute( pk3 = exp(constants.KAPPA * log_p_interface) with computation(PARALLEL), interval(...): gamma = 1.0 / (1.0 - cappa) # gamma, cp/cv - delta_mass = delta_mass * constants.RGRAV + delta_mass = delta_mass / grav_var with computation(PARALLEL), interval(0, -1): p_gas = (p_interface_gas[0, 0, 1] - p_interface_gas) / ( log_p_interface_gas[0, 0, 1] - log_p_interface_gas @@ -232,6 +237,7 @@ def __call__( pk: FloatField, log_p_interface: FloatField, w: FloatFieldIJ, + grav_var: FloatField, ): """ Solves for the nonhydrostatic terms for vertical velocity (w) @@ -262,6 +268,7 @@ def __call__( log_p_interface (out): logarithm of interface pressure, only written if last_call=True w (inout): vertical velocity + grav_var (in): variable gravity """ # TODO: propagate variable renaming for these into stencils here and @@ -298,6 +305,7 @@ def __call__( ptop, peln1, ptk, + grav_var, ) self._sim1_solve( diff --git a/pyfv3/stencils/riem_solver_c.py b/pyfv3/stencils/riem_solver_c.py index b2170a56..278935e3 100644 --- a/pyfv3/stencils/riem_solver_c.py +++ b/pyfv3/stencils/riem_solver_c.py @@ -22,6 +22,7 @@ def precompute( gm: FloatField, pm: FloatField, ptop: Float, + grav_var: FloatField, ): """ Args: @@ -40,6 +41,8 @@ def precompute( pm (out): hydrostatic cell mean pressure, derivation in documentation (Chapter 4? 7?) TODO: identify chapter reference, will be sent by Lucas + ptop (in): + grav_var (in): variable gravity """ with computation(PARALLEL), interval(...): dm = delpc @@ -57,7 +60,7 @@ def precompute( dz = gz[0, 0, 1] - gz with computation(PARALLEL), interval(...): gm = 1.0 / (1.0 - cappa) - dm /= constants.GRAV + dm /= grav_var with computation(PARALLEL), interval(0, -1): # (1) From \partial p*/\partial z = -\rho g, we can separate and integrate # over a layer to get @@ -87,6 +90,7 @@ def finalize( pef: FloatField, gz: FloatField, ptop: Float, + grav_var: FloatField, ): """ Enforce vertical boundary conditions. @@ -101,6 +105,8 @@ def finalize( dz (in): pef (out): gz (out): + ptop (in): + grav_var (in): variable gravity """ with computation(PARALLEL): with interval(0, 1): @@ -111,7 +117,7 @@ def finalize( with interval(-1, None): gz = hs with interval(0, -1): - gz = gz[0, 0, 1] - dz * constants.GRAV + gz = gz[0, 0, 1] - dz * grav_var class NonhydrostaticVerticalSolverCGrid: @@ -201,6 +207,7 @@ def __call__( gz: FloatField, pef: FloatField, w3: FloatField, + grav_var: FloatField, ): """ Solves for the nonhydrostatic terms for vertical velocity (w) @@ -219,6 +226,7 @@ def __call__( gz (inout): geopotential height pef (out): full hydrostatic pressure w3 (in): vertical velocity + grav_var (in): variable gravity """ # TODO: integrate these notes into comments/code, double-check: @@ -266,4 +274,6 @@ def __call__( ws, ) # pe is nonhydrostatic perturbation pressure defined on interfaces - self._finalize_stencil(self._pe, self._pem, hs, self._dz, pef, gz, ptop) + self._finalize_stencil( + self._pe, self._pem, hs, self._dz, pef, gz, ptop, grav_var + ) diff --git a/pyfv3/stencils/wam.py b/pyfv3/stencils/wam.py new file mode 100644 index 00000000..bfcefe57 --- /dev/null +++ b/pyfv3/stencils/wam.py @@ -0,0 +1,88 @@ +from ndsl import StencilFactory +from ndsl.constants import GRAV, RADIUS, RDGAS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation +from ndsl.dsl.gt4py import function as gtfunction +from ndsl.dsl.gt4py import interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ + + +@gtfunction +def average_gravity(grav_var: FloatField, grav_var_h: FloatField): + grav_var = 0.5 * (grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) + return grav_var + + +# May not need this stencil at all +def average_gravity_stencil_defn(grav_var: FloatField, grav_var_h: FloatField): + """ + Args: + grav_var (out): gravity field + grav_var_h (in): gravity value at interfaces + """ + with computation(FORWARD), interval(...): + grav_var = average_gravity(grav_var, grav_var_h) + + +def adjust_gravity( + grav_var: FloatField, + grav_var_h: FloatField, + phis: FloatFieldIJ, + delz: FloatField, +): + """ + Args: + grav_var (out): gravity field + grav_var_h (out): gravity value at interfaces + phis (in): geopotential + delz (in): change in vertical height + """ + with computation(FORWARD), interval(-1, None): + newrad = RADIUS + (phis / GRAV) + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + + with computation(BACKWARD), interval(0, -1): + newrad = RADIUS + (phis / GRAV) + newrad = newrad - delz + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + grav_var = average_gravity(grav_var, grav_var_h) + + +def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): + """ + # JK TODO: Is there a better name than this? + Adjust rdg to be the negative RDGAS divided by the variable gravity + for Whole Atmosphere Modeling + + Args: + rdg (out): negative radiative gas divided by variable gravity + grav_var (in): variable gravity + """ + with computation(FORWARD), interval(...): + rdg = -RDGAS / grav_var + + +class WholeAtmos: + def __init__(self, stencil_factory: StencilFactory): + self.constructed_average_gravity_stencil = stencil_factory.from_dims_halo( + func=average_gravity_stencil_defn, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.constructed_adjust_gravity_stencil = stencil_factory.from_dims_halo( + func=adjust_gravity, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.constructed_neg_rdgas_div_gravity_stencil = stencil_factory.from_dims_halo( + func=neg_rdgas_div_gravity, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__( + self, + grav_var: FloatField, + grav_var_h: FloatField, + rdg: FloatField, + phis: FloatFieldIJ, + delz: FloatField, + ): + self.constructed_adjust_gravity_stencil(grav_var, grav_var_h, phis, delz) + self.constructed_neg_rdgas_div_gravity_stencil(rdg, grav_var) diff --git a/pyfv3/wrappers/geos_wrapper.py b/pyfv3/wrappers/geos_wrapper.py index ab6b4bd3..271c29cb 100644 --- a/pyfv3/wrappers/geos_wrapper.py +++ b/pyfv3/wrappers/geos_wrapper.py @@ -143,7 +143,7 @@ def __init__( sizer = SubtileGridSizer.from_namelist( self.namelist, partitioner.tile, self.communicator.tile.rank ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer=sizer, backend=backend) # set up the metric terms and grid data metric_terms = MetricTerms( diff --git a/tests/main/test_wam.py b/tests/main/test_wam.py index e8138c8d..78d4decb 100644 --- a/tests/main/test_wam.py +++ b/tests/main/test_wam.py @@ -1,8 +1,10 @@ -from pathlib import Path -from dataclasses import field +import copy from datetime import timedelta +from pathlib import Path from typing import Tuple -import copy + +# use numpy for now until I figure out how to use FloatField +import numpy as np # JK TODO: Should I be using xumpy? import pyfv3.initialization.analytic_init as ai from ndsl import ( @@ -19,27 +21,27 @@ SubtileGridSizer, TilePartitioner, ) -from ndsl.grid import DampingCoefficients, GridData, MetricTerms -from ndsl.dsl.typing import Float, FloatField -from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM +from ndsl.constants import GRAV, RADIUS, RDGAS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.gt4py import stencil -from pyfv3 import DynamicalCore, DynamicalCoreConfig, DycoreState -from pyfv3.initialization import init_utils +from ndsl.dsl.typing import Float +from ndsl.grid import DampingCoefficients, GridData, MetricTerms +from ndsl.stencils.basic_operations import set_value +from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig from pyfv3.initialization.analytic_init import AnalyticCase -from pyfv3.stencils.dyn_core import AcousticDynamics -from pyfv3.stencils.rdg_adjust import neg_rdgas_div_gravity -from pyfv3.stencils.fv_dynamics import adjust_gravity, init_gravity, init_gravity_h -from pyfv3.stencils.dyn_core import average_gravity, compute_geopotential +from pyfv3.stencils.dyn_core import AcousticDynamics, compute_geopotential +from pyfv3.stencils.wam import ( + adjust_gravity, + average_gravity_stencil_defn, + neg_rdgas_div_gravity, +) -# use numpy for now until I figure out how to use FloatField -import numpy as np # JK TODO: Should I be using xumpy? -# JK NOTE TODO: Just sticking things in here for now, will distribute them into their right -# places in the future. +# JK NOTE TODO: Just sticking things in here for now, +# will distribute them into their right places in the future. def test_enable_wam() -> None: - # Set up dycore config with enable_wam = True + # Set up dycore config with enable_wam = True # Set up dycore # Check that gravity is variable and grav_var and grav_var_h are used somehow? # maybe call to compute_geopotential? @@ -51,8 +53,8 @@ def test_enable_wam() -> None: # TODO assert that something is different between wam_enabled = False # JK NOTE: I have to find out what to test..... - - assert False # TODO + + assert False # TODO ############################ dyncore_state.py @@ -117,7 +119,7 @@ def setup_dycore_state() -> DycoreState: tile_rank=communicator.tile.rank, ) quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - #eta_file = Path("tests/data/eta79.nc") + # eta_file = Path("tests/data/eta79.nc") eta_file = Path(__file__).resolve().parents[1] / "data" / "eta79.nc" metric_terms = MetricTerms( quantity_factory=quantity_factory, @@ -155,7 +157,7 @@ def test_neg_rdgas_div_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -173,17 +175,19 @@ def test_neg_rdgas_div_gravity() -> None: backend=example_backend, ) - init_gravity_stencil = stencil(backend=example_backend, definition=init_gravity) - init_gravity_stencil(grav_var) + init_gravity_stencil = stencil(backend=example_backend, definition=set_value) + init_gravity_stencil(grav_var, GRAV) expected_grav_var_np = copy.deepcopy(grav_var.field[:]) - neg_rdgas_div_gravity_stencil = stencil(backend=example_backend, definition=neg_rdgas_div_gravity) + neg_rdgas_div_gravity_stencil = stencil( + backend=example_backend, definition=neg_rdgas_div_gravity + ) neg_rdgas_div_gravity_stencil(rdg, grav_var) # grav_var_h should be unchanged by the stencil assert np.array_equal(grav_var.field[:], expected_grav_var_np) - expected_rdg_np = - (RDGAS / expected_grav_var_np[:]) + expected_rdg_np = -(RDGAS / expected_grav_var_np[:]) assert np.array_equal(rdg.field[:], expected_rdg_np) @@ -194,7 +198,7 @@ def test_average_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -204,7 +208,7 @@ def test_average_gravity() -> None: backend=example_backend, ) - grav_var_h_np = np.random.random((nx, ny, nz+1)) + grav_var_h_np = np.random.random((nx, ny, nz + 1)) expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) grav_var_h = Quantity( data=grav_var_h_np, @@ -213,14 +217,18 @@ def test_average_gravity() -> None: number_of_halo_points=n_halos, backend=example_backend, ) - - average_gravity_numpy = stencil(backend=example_backend, definition=average_gravity) + + average_gravity_numpy = stencil( + backend=example_backend, definition=average_gravity_stencil_defn + ) average_gravity_numpy(grav_var, grav_var_h) # grav_var_h should be unchanged by the stencil assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) - expected_grav_var_np = (expected_grav_var_h_np[:,:,:-1]+expected_grav_var_h_np[:,:,1:]) / 2 + expected_grav_var_np = ( + expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] + ) / 2 assert np.array_equal(grav_var.field[:], expected_grav_var_np) @@ -231,9 +239,9 @@ def test_compute_geopotential() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" - grav_var_h_np = np.random.random((nx, ny, nz+1)) + grav_var_h_np = np.random.random((nx, ny, nz + 1)) expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) grav_var_h = Quantity( data=grav_var_h_np, @@ -263,7 +271,9 @@ def test_compute_geopotential() -> None: backend=example_backend, ) - compute_geopotential_np = stencil(backend=example_backend, definition=compute_geopotential) + compute_geopotential_np = stencil( + backend=example_backend, definition=compute_geopotential + ) compute_geopotential_np(zh, gz, grav_var_h) # Check that zh and grav_var_h are unchanged @@ -273,7 +283,7 @@ def test_compute_geopotential() -> None: # Check that gz = zh * grav_var_h assert not np.array_equal(gz.field[:], gz_np_copy) # JK TODO: Is the expected_gz_np calculated correctly? double check with fortran or frank... - expected_gz_np = zh_np * grav_var_h_np[:,:,:-1] + expected_gz_np = zh_np * grav_var_h_np[:, :, :-1] assert np.array_equal(gz.field[:], expected_gz_np) @@ -388,98 +398,113 @@ def setup_acoustic_dynamics(npx, npy, n_halo) -> Tuple[AcousticDynamics, DycoreS # JK TODO simplify this please, if possible... return dycore.acoustic_dynamics, state -def test_acoustic_dynamics_init_average_gravity() -> None: - # Check that average gravity is called/used in AcousticDynamics initialization - nx = 12 - ny = 12 - nz = 79 - n_halo = 3 +# def test_acoustic_dynamics_init_average_gravity() -> None: +# # Check that average gravity is called/used in AcousticDynamics initialization - # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, _ = setup_acoustic_dynamics(nx-(2*n_halo)+1, ny-(2*n_halo)+1, n_halo) # nz is hard-coded to 79 +# nx = 12 +# ny = 12 +# nz = 79 +# n_halo = 3 - # JK TODO: switch from example grav_var, grav_var_h to state.grav_var, state.grav_var_h +# # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 +# ac_dyn, _ = setup_acoustic_dynamics( +# nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo +# ) # nz is hard-coded to 79 - example_dims = ["I", "J", "K"] - example_backend="numpy" +# # JK TODO: switch from example grav_var, grav_var_h to state.grav_var, state.grav_var_h - grav_var = Quantity( - data=np.zeros((nx, ny, nz)), - dims=example_dims, - units="grav_var units", - number_of_halo_points=n_halo, - backend=example_backend, - ) +# example_dims = ["I", "J", "K"] +# example_backend = "numpy" - grav_var_h_np = np.random.random((nx, ny, nz+1)) - expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) - grav_var_h = Quantity( - data=grav_var_h_np, - dims=example_dims, - units="grav_var_h units", - number_of_halo_points=n_halo, - backend=example_backend, - ) +# grav_var = Quantity( +# data=np.zeros((nx, ny, nz)), +# dims=example_dims, +# units="grav_var units", +# number_of_halo_points=n_halo, +# backend=example_backend, +# ) - # Call ad_dyn._average_gravity. This is what we're testing. - ac_dyn._average_gravity(grav_var, grav_var_h) +# grav_var_h_np = np.random.random((nx, ny, nz + 1)) +# expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) +# grav_var_h = Quantity( +# data=grav_var_h_np, +# dims=example_dims, +# units="grav_var_h units", +# number_of_halo_points=n_halo, +# backend=example_backend, +# ) - # grav_var_h should be unchanged by the stencil - assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) +# # Call ad_dyn._average_gravity. This is what we're testing. +# ac_dyn._average_gravity(grav_var, grav_var_h) - expected_grav_var_np = (expected_grav_var_h_np[:,:,:-1]+expected_grav_var_h_np[:,:,1:]) / 2 - assert np.array_equal(grav_var.field[:], expected_grav_var_np) +# # grav_var_h should be unchanged by the stencil +# assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) + +# expected_grav_var_np = ( +# expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] +# ) / 2 +# assert np.array_equal(grav_var.field[:], expected_grav_var_np) + + +# def test_acoustic_dynamics_call_average_gravity() -> None: +# # Check that average gravity is called/used in AcousticDynamics call +# nx = 12 +# ny = 12 +# nz = 79 +# n_halo = 3 +# timestep = 225 # JK TODO: Is this right? + +# # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 +# ac_dyn, state = setup_acoustic_dynamics( +# nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo +# ) # nz is hard-coded to 79 +# init_grav_var_np = copy.deepcopy(state.grav_var.field) +# init_grav_var_h_np = copy.deepcopy(state.grav_var_h.field) -def test_acoustic_dynamics_call_average_gravity() -> None: - # Check that average gravity is called/used in AcousticDynamics call - nx = 12 - ny = 12 - nz = 79 - n_halo = 3 - timestep = 225 # JK TODO: Is this right? +# ac_dyn(state, timestep) - # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, state = setup_acoustic_dynamics(nx-(2*n_halo)+1, ny-(2*n_halo)+1, n_halo) # nz is hard-coded to 79 +# # The state.grav_var_h should be unchanged by the stencil. +# assert np.array_equal(state.grav_var_h.field[:], init_grav_var_h_np) - init_grav_var_np = copy.deepcopy(state.grav_var.field) - init_grav_var_h_np = copy.deepcopy(state.grav_var_h.field) +# # Check that the state.grav_var values match expectation: +# expected_grav_var_np = ( +# init_grav_var_h_np[:, :, :-1] + init_grav_var_h_np[:, :, 1:] +# ) / 2 +# assert np.array_equal(state.grav_var.field[:], expected_grav_var_np) - ac_dyn(state, timestep) - - # The state.grav_var_h should be unchanged by the stencil. - assert np.array_equal(state.grav_var_h.field[:], init_grav_var_h_np) - # Check that the state.grav_var values match expectation: - expected_grav_var_np = (init_grav_var_h_np[:,:,:-1]+init_grav_var_h_np[:,:,1:]) / 2 - assert np.array_equal(state.grav_var.field[:], expected_grav_var_np) +# """ +# E + where False = ( -""" -E + where False = ( +# array([ +# [[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), -array([ -[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), +# array([[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]])) -array([[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]])) +# E + where = np.array_equal -E + where = np.array_equal +# """ -""" +# ############################ fv_dynamics.py -############################ fv_dynamics.py def test_init_gravity() -> None: # Check that init_gravity sets 3d grav_var to the constant GRAV for all vals backend = "numpy" nx_tile, ny_tile, nz, n_halo = 6, 6, 2, 3 - layout = (1,1) + layout = (1, 1) partitioner = CubedSpherePartitioner(TilePartitioner(layout)) mpi_comm = NullComm(rank=0, total_ranks=6, fill_value=0.0) communicator = CubedSphereCommunicator(mpi_comm, partitioner) - compilation_config = CompilationConfig(backend=backend, rebuild=False, validate_args=True) + compilation_config = CompilationConfig( + backend=backend, rebuild=False, validate_args=True + ) dace_config = DaceConfig(communicator=communicator, backend=backend) - stencil_config = StencilConfig(compilation_config=compilation_config, dace_config=dace_config) + stencil_config = StencilConfig( + compilation_config=compilation_config, dace_config=dace_config + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx_tile, ny_tile=ny_tile, @@ -487,23 +512,32 @@ def test_init_gravity() -> None: n_halo=n_halo, layout=layout, tile_partitioner=partitioner.tile, - tile_rank=communicator.tile.rank + tile_rank=communicator.tile.rank, + ) + grid_indexing = GridIndexing.from_sizer_and_communicator( + sizer=sizer, comm=communicator ) - grid_indexing = GridIndexing.from_sizer_and_communicator(sizer=sizer, comm=communicator) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - init_gravity_stencil = stencil_factory.from_dims_halo( - init_gravity, - compute_dims=[X_DIM, Y_DIM, Z_DIM], - compute_halos=(n_halo, n_halo), + init_gravity_stencil = stencil_factory.from_origin_domain( + set_value, + origin=grid_indexing.origin_full(), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) grav_var: Quantity = quantity_factory.zeros( [X_DIM, Y_DIM, Z_DIM], units="test", dtype=Float, ) - init_gravity_stencil(grav_var) + grav_var_h: Quantity = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_INTERFACE_DIM], + units="test", + dtype=Float, + ) + init_gravity_stencil(grav_var, GRAV) + init_gravity_stencil(grav_var_h, GRAV) assert np.all(grav_var.field == GRAV) + assert np.all(grav_var_h.field == GRAV) # JK TODO: There's so much setup... Find a simpler way to set of stencil and quantity? @@ -515,7 +549,7 @@ def test_adjust_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -526,7 +560,7 @@ def test_adjust_gravity() -> None: ) grav_var_h = Quantity( - data=np.zeros((nx, ny, nz+1)), + data=np.zeros((nx, ny, nz + 1)), dims=example_dims, units="grav_var_h units", number_of_halo_points=n_halos, @@ -559,24 +593,29 @@ def test_adjust_gravity() -> None: # Check that grav_var and grav_var_h are set appropriately newrad = np.zeros((nx, ny)) expected_grav_var_np = np.zeros((nx, ny, nz)) - expected_grav_var_h_np = np.zeros((nx, ny, nz+1)) - + expected_grav_var_h_np = np.zeros((nx, ny, nz + 1)) + assert np.array_equal(grav_var_h.field[:].shape, expected_grav_var_h_np.shape) for j in range(ny): for i in range(nx): for k in range(nz, -1, -1): if k == nz: - newrad[i,j] = RADIUS + (phis.field[i,j]/GRAV) + newrad[i, j] = RADIUS + (phis.field[i, j] / GRAV) else: - newrad[i,j] = newrad[i,j] - delz.field[i,j,k] - expected_grav_var_h_np[i,j,k] = GRAV*((RADIUS**2)/(newrad[i,j]**2)) + newrad[i, j] = newrad[i, j] - delz.field[i, j, k] + expected_grav_var_h_np[i, j, k] = GRAV * ( + (RADIUS**2) / (newrad[i, j] ** 2) + ) if k < nz: - expected_grav_var_np[i,j,k] = 0.5*(expected_grav_var_h_np[i,j,k+1]+expected_grav_var_h_np[i,j,k]) + expected_grav_var_np[i, j, k] = 0.5 * ( + expected_grav_var_h_np[i, j, k + 1] + + expected_grav_var_h_np[i, j, k] + ) # JK TODO: is there some rtol/atol threshold? the np vs stencil calculations are close but not exact. assert np.allclose(grav_var_h.field[:], expected_grav_var_h_np) assert np.allclose(grav_var.field[:], expected_grav_var_np) - + # TODO JK NOTE to self --- checkout log_on_rank_0 for values that might be useful for test (possibly) diff --git a/tests/mpi/test_doubly_periodic.py b/tests/mpi/test_doubly_periodic.py index fc825cb9..42215d66 100644 --- a/tests/mpi/test_doubly_periodic.py +++ b/tests/mpi/test_doubly_periodic.py @@ -97,7 +97,7 @@ def test_dycore_runs_one_step() -> None: grid_indexing = GridIndexing.from_sizer_and_communicator( sizer=sizer, comm=communicator ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer=sizer, backend=backend) metric_terms = MetricTerms( quantity_factory=quantity_factory, communicator=communicator, diff --git a/tests/savepoint/translate/translate_init_case.py b/tests/savepoint/translate/translate_init_case.py index 0e41f3a5..75f4336a 100644 --- a/tests/savepoint/translate/translate_init_case.py +++ b/tests/savepoint/translate/translate_init_case.py @@ -204,7 +204,7 @@ def compute_parallel(self, inputs, communicator): properties["units"], origin=self.grid.sizer.get_origin(dims), extent=self.grid.sizer.get_extent(dims), - gt4py_backend=self.stencil_factory.backend, + backend=self.stencil_factory.backend, ) metric_terms = MetricTerms.from_tile_sizing( @@ -225,12 +225,9 @@ def compute_parallel(self, inputs, communicator): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend( - sizer, backend=self.stencil_factory.backend - ) + quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) grid_data = GridData.new_from_metric_terms(metric_terms) - quantity_factory = QuantityFactory() state = analytic_init.init_analytic_state( analytic_init_case="baroclinic",