From 0f57ddbd675c8274be337bdb20550c989c0a12a5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 6 Jan 2026 13:49:57 +0100 Subject: [PATCH 1/7] refactor: use new constructor of `QuantityFactory` Prefer the new constructor of `QuantityFactory` over the deprecated call to `QuantityFactory.from_backend(...)`. This removes a bunch of deprecation warnings in tests. See https://github.com/NOAA-GFDL/NDSL/pull/228 for context. --- examples/notebook/test_functionality.ipynb | 2 +- pyfv3/wrappers/geos_wrapper.py | 2 +- tests/mpi/test_doubly_periodic.py | 2 +- tests/savepoint/translate/translate_init_case.py | 5 +---- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/notebook/test_functionality.ipynb b/examples/notebook/test_functionality.ipynb index dc5ce3d8..e9aa077e 100644 --- a/examples/notebook/test_functionality.ipynb +++ b/examples/notebook/test_functionality.ipynb @@ -134,7 +134,7 @@ ")\n", "\n", "# useful for easily allocating distributed data storages (fields)\n", - "quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend)\n", + "quantity_factory = QuantityFactory(sizer=sizer, backend=backend)\n", "\n", "compilation_config = CompilationConfig(backend=backend, communicator=cs_communicator)\n", "\n", diff --git a/pyfv3/wrappers/geos_wrapper.py b/pyfv3/wrappers/geos_wrapper.py index ab6b4bd3..271c29cb 100644 --- a/pyfv3/wrappers/geos_wrapper.py +++ b/pyfv3/wrappers/geos_wrapper.py @@ -143,7 +143,7 @@ def __init__( sizer = SubtileGridSizer.from_namelist( self.namelist, partitioner.tile, self.communicator.tile.rank ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer=sizer, backend=backend) # set up the metric terms and grid data metric_terms = MetricTerms( diff --git a/tests/mpi/test_doubly_periodic.py b/tests/mpi/test_doubly_periodic.py index fc825cb9..42215d66 100644 --- a/tests/mpi/test_doubly_periodic.py +++ b/tests/mpi/test_doubly_periodic.py @@ -97,7 +97,7 @@ def test_dycore_runs_one_step() -> None: grid_indexing = GridIndexing.from_sizer_and_communicator( sizer=sizer, comm=communicator ) - quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) + quantity_factory = QuantityFactory(sizer=sizer, backend=backend) metric_terms = MetricTerms( quantity_factory=quantity_factory, communicator=communicator, diff --git a/tests/savepoint/translate/translate_init_case.py b/tests/savepoint/translate/translate_init_case.py index 0e41f3a5..a9207353 100644 --- a/tests/savepoint/translate/translate_init_case.py +++ b/tests/savepoint/translate/translate_init_case.py @@ -225,12 +225,9 @@ def compute_parallel(self, inputs, communicator): tile_rank=communicator.tile.rank, ) - quantity_factory = QuantityFactory.from_backend( - sizer, backend=self.stencil_factory.backend - ) + quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) grid_data = GridData.new_from_metric_terms(metric_terms) - quantity_factory = QuantityFactory() state = analytic_init.init_analytic_state( analytic_init_case="baroclinic", From 594739f11a79981219f406943e9c353256a18de1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 6 Jan 2026 14:36:26 +0100 Subject: [PATCH 2/7] refactor: prefer `backend` over `gt4py_backend` On `Quantities`, we have currently allow both `backend` and `gt4py_backend` where the later is deprecated and about to be removed. See https://github.com/NOAA-GFDL/NDSL/pull/312 and https://github.com/NOAA-GFDL/NDSL/pull/314 for context. --- pyfv3/dycore_state.py | 2 +- pyfv3/initialization/test_cases/initialize_baroclinic.py | 2 +- pyfv3/initialization/test_cases/initialize_rossby.py | 2 +- pyfv3/initialization/test_cases/initialize_tc.py | 2 +- pyfv3/stencils/delnflux.py | 2 +- tests/savepoint/translate/translate_init_case.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pyfv3/dycore_state.py b/pyfv3/dycore_state.py index 6c9fb4fa..61a4b0f7 100644 --- a/pyfv3/dycore_state.py +++ b/pyfv3/dycore_state.py @@ -344,7 +344,7 @@ def init_from_numpy_arrays( _field.metadata["units"], origin=sizer.get_origin(dims), extent=sizer.get_extent(dims), - gt4py_backend=backend, + backend=backend, ) state = cls(**dict_state) return state diff --git a/pyfv3/initialization/test_cases/initialize_baroclinic.py b/pyfv3/initialization/test_cases/initialize_baroclinic.py index ee846a8c..739430ea 100644 --- a/pyfv3/initialization/test_cases/initialize_baroclinic.py +++ b/pyfv3/initialization/test_cases/initialize_baroclinic.py @@ -369,7 +369,7 @@ def init_baroclinic_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) comm.halo_update(state.phis, n_points=NHALO) diff --git a/pyfv3/initialization/test_cases/initialize_rossby.py b/pyfv3/initialization/test_cases/initialize_rossby.py index a5d85d34..d37683a9 100644 --- a/pyfv3/initialization/test_cases/initialize_rossby.py +++ b/pyfv3/initialization/test_cases/initialize_rossby.py @@ -203,7 +203,7 @@ def init_rossby_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) comm.halo_update(state.phis, n_points=NHALO) diff --git a/pyfv3/initialization/test_cases/initialize_tc.py b/pyfv3/initialization/test_cases/initialize_tc.py index 2c3d48ad..3b4622f6 100644 --- a/pyfv3/initialization/test_cases/initialize_tc.py +++ b/pyfv3/initialization/test_cases/initialize_tc.py @@ -570,7 +570,7 @@ def init_tc_state( state = DycoreState.init_from_numpy_arrays( numpy_state.__dict__, sizer=quantity_factory.sizer, - backend=sample_quantity.metadata.gt4py_backend, + backend=sample_quantity.metadata.backend, ) return state diff --git a/pyfv3/stencils/delnflux.py b/pyfv3/stencils/delnflux.py index 4bc9653c..3acd1caf 100644 --- a/pyfv3/stencils/delnflux.py +++ b/pyfv3/stencils/delnflux.py @@ -27,7 +27,7 @@ def calc_damp(damp_c: Quantity, da_min: Float, nord: Quantity) -> Quantity: units="unknown", origin=damp_c.origin, extent=damp_c.extent, - gt4py_backend=damp_c.gt4py_backend, + backend=damp_c.backend, ) diff --git a/tests/savepoint/translate/translate_init_case.py b/tests/savepoint/translate/translate_init_case.py index 0e41f3a5..f1831d0b 100644 --- a/tests/savepoint/translate/translate_init_case.py +++ b/tests/savepoint/translate/translate_init_case.py @@ -204,7 +204,7 @@ def compute_parallel(self, inputs, communicator): properties["units"], origin=self.grid.sizer.get_origin(dims), extent=self.grid.sizer.get_extent(dims), - gt4py_backend=self.stencil_factory.backend, + backend=self.stencil_factory.backend, ) metric_terms = MetricTerms.from_tile_sizing( From d5408fc405732ba37e4c8b782653d4e489ee1e75 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 7 Jan 2026 13:46:01 -0500 Subject: [PATCH 3/7] Removing init_gravity stencil definition and using set_value from NDSL instead --- pyfv3/stencils/fv_dynamics.py | 32 ++++++-------------------------- tests/main/test_wam.py | 26 +++++++++++++++++--------- 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index 4314c9ce..0cf24fb6 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -15,7 +15,7 @@ from ndsl.grid import DampingCoefficients, GridData from ndsl.logging import ndsl_log from ndsl.performance import NullTimer, Timer -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy_defn, set_value from ndsl.stencils.c2l_ord import CubedToLatLon from ndsl.typing import Checkpointer, Communicator from pyfv3._config import DynamicalCoreConfig @@ -75,22 +75,6 @@ def fvdyn_temporaries( tmps[name] = quantity return tmps -def init_gravity(grav_var: FloatField): - """ - Args: - grav_var (out): gravity field - """ - with computation(PARALLEL), interval(...): - grav_var = GRAV - -def init_gravity_h(grav_var_h: FloatField): - """ - Args: - grav_var_h (out): gravity field - """ - with computation(PARALLEL), interval(...): - grav_var_h = GRAV - def adjust_gravity( grav_var: FloatField, grav_var_h: FloatField, @@ -302,14 +286,9 @@ def __init__( domain=grid_indexing.domain_full(), ) self._init_gravity = stencil_factory.from_origin_domain( - init_gravity, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) - self._init_gravity_h = stencil_factory.from_origin_domain( - init_gravity_h, + set_value, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), + domain=grid_indexing.domain_full(add=(0,0,1)), ) self._adjust_gravity = stencil_factory.from_origin_domain( adjust_gravity, @@ -538,8 +517,9 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool): self._dp_initial, ) - self._init_gravity(state.grav_var) - self._init_gravity_h(state.grav_var_h) + # self._init_gravity(state.grav_var, state.grav_var_h) + self._init_gravity(state.grav_var, GRAV) + self._init_gravity(state.grav_var_h, GRAV) if self.config.enable_wam: self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz) diff --git a/tests/main/test_wam.py b/tests/main/test_wam.py index d830802c..582f7947 100644 --- a/tests/main/test_wam.py +++ b/tests/main/test_wam.py @@ -21,13 +21,14 @@ ) from ndsl.grid import DampingCoefficients, GridData, MetricTerms from ndsl.dsl.typing import Float, FloatField -from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM +from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.gt4py import stencil +from ndsl.stencils.basic_operations import set_value from pyfv3 import DynamicalCore, DynamicalCoreConfig, DycoreState from pyfv3.initialization import init_utils from pyfv3.initialization.analytic_init import AnalyticCase from pyfv3.stencils.dyn_core import AcousticDynamics -from pyfv3.stencils.fv_dynamics import adjust_gravity, init_gravity, init_gravity_h +from pyfv3.stencils.fv_dynamics import adjust_gravity from pyfv3.stencils.dyn_core import average_gravity, compute_geopotential # use numpy for now until I figure out how to use FloatField @@ -172,8 +173,8 @@ def test_neg_rdgas_div_gravity() -> None: backend=example_backend, ) - init_gravity_stencil = stencil(backend=example_backend, definition=init_gravity) - init_gravity_stencil(grav_var) + init_gravity_stencil = stencil(backend=example_backend, definition=set_value) + init_gravity_stencil(grav_var, GRAV) expected_grav_var_np = copy.deepcopy(grav_var.field[:]) neg_rdgas_div_gravity_stencil = stencil(backend=example_backend, definition=neg_rdgas_div_gravity) @@ -491,18 +492,25 @@ def test_init_gravity() -> None: grid_indexing = GridIndexing.from_sizer_and_communicator(sizer=sizer, comm=communicator) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - init_gravity_stencil = stencil_factory.from_dims_halo( - init_gravity, - compute_dims=[X_DIM, Y_DIM, Z_DIM], - compute_halos=(n_halo, n_halo), + init_gravity_stencil = stencil_factory.from_origin_domain( + set_value, + origin=grid_indexing.origin_full(), + domain=grid_indexing.domain_full(add=(0,0,1)), ) grav_var: Quantity = quantity_factory.zeros( [X_DIM, Y_DIM, Z_DIM], units="test", dtype=Float, ) - init_gravity_stencil(grav_var) + grav_var_h: Quantity = quantity_factory.zeros( + [X_DIM, Y_DIM, Z_INTERFACE_DIM], + units="test", + dtype=Float, + ) + init_gravity_stencil(grav_var, GRAV) + init_gravity_stencil(grav_var_h, GRAV) assert np.all(grav_var.field == GRAV) + assert np.all(grav_var_h.field == GRAV) # JK TODO: There's so much setup... Find a simpler way to set of stencil and quantity? From 1e2c9dc426de24b7442df0b2964884f40dcc03b7 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 8 Jan 2026 11:00:06 -0500 Subject: [PATCH 4/7] Moving gravity operations to its own module for reusability --- pyfv3/dycore_state.py | 2 +- pyfv3/stencils/dyn_core.py | 16 +--- pyfv3/stencils/fv_dynamics.py | 51 +++++-------- pyfv3/stencils/gravity.py | 45 +++++++++++ pyfv3/stencils/moist_cv.py | 1 + tests/main/test_wam.py | 138 ++++++++++++++++++++-------------- 6 files changed, 152 insertions(+), 101 deletions(-) create mode 100644 pyfv3/stencils/gravity.py diff --git a/pyfv3/dycore_state.py b/pyfv3/dycore_state.py index 79c0fef2..01e546b8 100644 --- a/pyfv3/dycore_state.py +++ b/pyfv3/dycore_state.py @@ -312,7 +312,7 @@ class DycoreState: rdg_var: Quantity = field( metadata={ "name": "gas constant for dry air over variable gravity (RDGAS / grav_var) for Whole Atmosphere calculations", - "units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units? + "units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units? "dims": [X_DIM, Y_DIM, Z_DIM], "intent": "inout", } diff --git a/pyfv3/stencils/dyn_core.py b/pyfv3/stencils/dyn_core.py index ac62eea4..5a423999 100644 --- a/pyfv3/stencils/dyn_core.py +++ b/pyfv3/stencils/dyn_core.py @@ -6,6 +6,7 @@ import ndsl.constants as constants import ndsl.stencils.basic_operations as basic import pyfv3.stencils.d_sw as d_sw +import pyfv3.stencils.gravity as gravity import pyfv3.stencils.nh_p_grad as nh_p_grad import pyfv3.stencils.pe_halo as pe_halo import pyfv3.stencils.ray_fast as ray_fast @@ -91,16 +92,6 @@ def zero_data( diss_estd = 0.0 -def average_gravity(grav_var: FloatField, grav_var_h: FloatField): - """ - Args: - grav_var (out): gravity field - grav_var_h (in): gravity value at height - """ - with computation(FORWARD), interval(...): - grav_var[0, 0, 0] = 0.5*(grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) - - def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): """ # JK TODO: Is there a better name than this? @@ -109,7 +100,7 @@ def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): grav_var (in): variable gravity """ with computation(FORWARD), interval(...): - rdg = - constants.RDGAS / grav_var + rdg = -constants.RDGAS / grav_var def gz_from_surface_height_and_thicknesses( @@ -677,10 +668,11 @@ def __init__( ) self._average_gravity = stencil_factory.from_origin_domain( - average_gravity, + gravity.average_gravity_stencil_defn, origin=grid_indexing.origin_full(), domain=grid_indexing.domain_full(), ) + self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain( neg_rdgas_div_gravity, origin=grid_indexing.origin_full(), diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index 0cf24fb6..da5a32e9 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -4,14 +4,15 @@ from dace.frontend.python.interface import nounroll as dace_no_unroll import ndsl.dsl.gt4py_utils as utils +import pyfv3.stencils.gravity as gravity import pyfv3.stencils.moist_cv as moist_cv from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater from ndsl.checkpointer import NullCheckpointer from ndsl.comm.mpi import MPI -from ndsl.constants import KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR, GRAV, RADIUS +from ndsl.constants import GRAV, KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR from ndsl.dsl.dace.orchestration import dace_inhibitor, orchestrate -from ndsl.dsl.gt4py import FORWARD, BACKWARD, PARALLEL, computation, interval -from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ +from ndsl.dsl.gt4py import PARALLEL, computation, interval +from ndsl.dsl.typing import Float, FloatField from ndsl.grid import DampingCoefficients, GridData from ndsl.logging import ndsl_log from ndsl.performance import NullTimer, Timer @@ -75,29 +76,6 @@ def fvdyn_temporaries( tmps[name] = quantity return tmps -def adjust_gravity( - grav_var: FloatField, - grav_var_h: FloatField, - phis: FloatFieldIJ, - delz: FloatField, -): - """ - Args: - grav_var (out): gravity field - grav_var_h (out): height based gravity - phis (out): - delz (out): - """ - with computation(FORWARD), interval(-1,None): - newrad = RADIUS + (phis/GRAV) - grav_var_h = GRAV*(RADIUS**2)/newrad**2 - - with computation(BACKWARD), interval(0,-1): - newrad = RADIUS + (phis/GRAV) - newrad = newrad - delz - grav_var_h = GRAV*(RADIUS**2)/newrad**2 - grav_var = 0.5*(grav_var_h[0, 0, 1] + grav_var_h[0, 0, 0]) - @dace_inhibitor def log_on_rank_0(msg: str): @@ -288,12 +266,12 @@ def __init__( self._init_gravity = stencil_factory.from_origin_domain( set_value, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(add=(0,0,1)), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self._adjust_gravity = stencil_factory.from_origin_domain( - adjust_gravity, + gravity.adjust_gravity, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(add=(0,0,1)), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self.acoustic_dynamics = AcousticDynamics( comm=comm, @@ -360,7 +338,10 @@ def __init__( comm.get_scalar_halo_updater([full_xyz_spec]), state, ["omga"], comm=comm ) self._gravity_halo_updater = WrappedHaloUpdater( - comm.get_scalar_halo_updater([full_xyz_spec]), state, ["grav_var"], comm=comm + comm.get_scalar_halo_updater([full_xyz_spec]), + state, + ["grav_var"], + comm=comm, ) self._n_split = config.n_split self._k_split = config.k_split @@ -517,12 +498,14 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool): self._dp_initial, ) - # self._init_gravity(state.grav_var, state.grav_var_h) self._init_gravity(state.grav_var, GRAV) self._init_gravity(state.grav_var_h, GRAV) if self.config.enable_wam: - self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz) + self._adjust_gravity( + state.grav_var, state.grav_var_h, state.phis, state.delz + ) + self._gravity_halo_updater.update() if self._conserve_total_energy > 0: raise NotImplementedError( @@ -649,7 +632,9 @@ def _compute(self, state: DycoreState, timer: Timer): ) self._checkpoint_remapping_out(state) if self.config.enable_wam: - self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz) + self._adjust_gravity( + state.grav_var, state.grav_var_h, state.phis, state.delz + ) # TODO: can we pull this block out of the loop intead of # using an if-statement? if last_step: diff --git a/pyfv3/stencils/gravity.py b/pyfv3/stencils/gravity.py new file mode 100644 index 00000000..a42e71e9 --- /dev/null +++ b/pyfv3/stencils/gravity.py @@ -0,0 +1,45 @@ +from ndsl.constants import GRAV, RADIUS +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation +from ndsl.dsl.gt4py import function as gtfunction +from ndsl.dsl.gt4py import interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ + + +@gtfunction +def average_gravity(grav_var: FloatField, grav_var_h: FloatField): + grav_var = 0.5 * (grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) + return grav_var + + +def average_gravity_stencil_defn(grav_var: FloatField, grav_var_h: FloatField): + """ + Args: + grav_var (out): gravity field + grav_var_h (in): gravity value at interfaces + """ + with computation(FORWARD), interval(...): + grav_var = average_gravity(grav_var, grav_var_h) + + +def adjust_gravity( + grav_var: FloatField, + grav_var_h: FloatField, + phis: FloatFieldIJ, + delz: FloatField, +): + """ + Args: + grav_var (out): gravity field + grav_var_h (out): gravity value at interfaces + phis (in): geopotential + delz (in): change in vertical height + """ + with computation(FORWARD), interval(-1, None): + newrad = RADIUS + (phis / GRAV) + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + + with computation(BACKWARD), interval(0, -1): + newrad = RADIUS + (phis / GRAV) + newrad = newrad - delz + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + grav_var = average_gravity(grav_var, grav_var_h) diff --git a/pyfv3/stencils/moist_cv.py b/pyfv3/stencils/moist_cv.py index a0464734..7eef7152 100644 --- a/pyfv3/stencils/moist_cv.py +++ b/pyfv3/stencils/moist_cv.py @@ -122,6 +122,7 @@ def compute_pkz_func(delp, delz, pt, cappa, rdg_var): # TODO use the exponential form for closer answer matching return exp(cappa * log(rdg_var * delp / delz * pt)) + def moist_pkz( qvapor: FloatField, qliquid: FloatField, diff --git a/tests/main/test_wam.py b/tests/main/test_wam.py index 582f7947..030c6f72 100644 --- a/tests/main/test_wam.py +++ b/tests/main/test_wam.py @@ -1,8 +1,10 @@ -from pathlib import Path -from dataclasses import field +import copy from datetime import timedelta +from pathlib import Path from typing import Tuple -import copy + +# use numpy for now until I figure out how to use FloatField +import numpy as np # JK TODO: Should I be using xumpy? import pyfv3.initialization.analytic_init as ai from ndsl import ( @@ -19,27 +21,23 @@ SubtileGridSizer, TilePartitioner, ) -from ndsl.grid import DampingCoefficients, GridData, MetricTerms -from ndsl.dsl.typing import Float, FloatField -from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM +from ndsl.constants import GRAV, RADIUS, RDGAS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.gt4py import stencil +from ndsl.dsl.typing import Float +from ndsl.grid import DampingCoefficients, GridData, MetricTerms from ndsl.stencils.basic_operations import set_value -from pyfv3 import DynamicalCore, DynamicalCoreConfig, DycoreState -from pyfv3.initialization import init_utils +from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig from pyfv3.initialization.analytic_init import AnalyticCase -from pyfv3.stencils.dyn_core import AcousticDynamics -from pyfv3.stencils.fv_dynamics import adjust_gravity -from pyfv3.stencils.dyn_core import average_gravity, compute_geopotential +from pyfv3.stencils.dyn_core import AcousticDynamics, compute_geopotential +from pyfv3.stencils.gravity import adjust_gravity, average_gravity_stencil_defn -# use numpy for now until I figure out how to use FloatField -import numpy as np # JK TODO: Should I be using xumpy? -# JK NOTE TODO: Just sticking things in here for now, will distribute them into their right -# places in the future. +# JK NOTE TODO: Just sticking things in here for now, +# will distribute them into their right places in the future. def test_enable_wam() -> None: - # Set up dycore config with enable_wam = True + # Set up dycore config with enable_wam = True # Set up dycore # Check that gravity is variable and grav_var and grav_var_h are used somehow? # maybe call to compute_geopotential? @@ -51,8 +49,8 @@ def test_enable_wam() -> None: # TODO assert that something is different between wam_enabled = False # JK NOTE: I have to find out what to test..... - - assert False # TODO + + assert False # TODO ############################ dyncore_state.py @@ -117,7 +115,7 @@ def setup_dycore_state() -> DycoreState: tile_rank=communicator.tile.rank, ) quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) - #eta_file = Path("tests/data/eta79.nc") + # eta_file = Path("tests/data/eta79.nc") eta_file = Path(__file__).resolve().parents[1] / "data" / "eta79.nc" metric_terms = MetricTerms( quantity_factory=quantity_factory, @@ -155,7 +153,7 @@ def test_neg_rdgas_div_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -177,13 +175,15 @@ def test_neg_rdgas_div_gravity() -> None: init_gravity_stencil(grav_var, GRAV) expected_grav_var_np = copy.deepcopy(grav_var.field[:]) - neg_rdgas_div_gravity_stencil = stencil(backend=example_backend, definition=neg_rdgas_div_gravity) + neg_rdgas_div_gravity_stencil = stencil( + backend=example_backend, definition=neg_rdgas_div_gravity + ) neg_rdgas_div_gravity_stencil(rdg, grav_var) # grav_var_h should be unchanged by the stencil assert np.array_equal(grav_var.field[:], expected_grav_var_np) - expected_rdg_np = - (RDGAS / expected_grav_var_np[:]) + expected_rdg_np = -(RDGAS / expected_grav_var_np[:]) assert np.array_equal(rdg.field[:], expected_rdg_np) @@ -194,7 +194,7 @@ def test_average_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -204,7 +204,7 @@ def test_average_gravity() -> None: backend=example_backend, ) - grav_var_h_np = np.random.random((nx, ny, nz+1)) + grav_var_h_np = np.random.random((nx, ny, nz + 1)) expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) grav_var_h = Quantity( data=grav_var_h_np, @@ -213,14 +213,18 @@ def test_average_gravity() -> None: number_of_halo_points=n_halos, backend=example_backend, ) - - average_gravity_numpy = stencil(backend=example_backend, definition=average_gravity) + + average_gravity_numpy = stencil( + backend=example_backend, definition=average_gravity_stencil_defn + ) average_gravity_numpy(grav_var, grav_var_h) # grav_var_h should be unchanged by the stencil assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) - expected_grav_var_np = (expected_grav_var_h_np[:,:,:-1]+expected_grav_var_h_np[:,:,1:]) / 2 + expected_grav_var_np = ( + expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] + ) / 2 assert np.array_equal(grav_var.field[:], expected_grav_var_np) @@ -231,9 +235,9 @@ def test_compute_geopotential() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" - grav_var_h_np = np.random.random((nx, ny, nz+1)) + grav_var_h_np = np.random.random((nx, ny, nz + 1)) expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) grav_var_h = Quantity( data=grav_var_h_np, @@ -263,7 +267,9 @@ def test_compute_geopotential() -> None: backend=example_backend, ) - compute_geopotential_np = stencil(backend=example_backend, definition=compute_geopotential) + compute_geopotential_np = stencil( + backend=example_backend, definition=compute_geopotential + ) compute_geopotential_np(zh, gz, grav_var_h) # Check that zh and grav_var_h are unchanged @@ -273,7 +279,7 @@ def test_compute_geopotential() -> None: # Check that gz = zh * grav_var_h assert not np.array_equal(gz.field[:], gz_np_copy) # JK TODO: Is the expected_gz_np calculated correctly? double check with fortran or frank... - expected_gz_np = zh_np * grav_var_h_np[:,:,:-1] + expected_gz_np = zh_np * grav_var_h_np[:, :, :-1] assert np.array_equal(gz.field[:], expected_gz_np) @@ -388,6 +394,7 @@ def setup_acoustic_dynamics(npx, npy, n_halo) -> Tuple[AcousticDynamics, DycoreS # JK TODO simplify this please, if possible... return dycore.acoustic_dynamics, state + def test_acoustic_dynamics_init_average_gravity() -> None: # Check that average gravity is called/used in AcousticDynamics initialization @@ -397,12 +404,14 @@ def test_acoustic_dynamics_init_average_gravity() -> None: n_halo = 3 # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, _ = setup_acoustic_dynamics(nx-(2*n_halo)+1, ny-(2*n_halo)+1, n_halo) # nz is hard-coded to 79 + ac_dyn, _ = setup_acoustic_dynamics( + nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo + ) # nz is hard-coded to 79 # JK TODO: switch from example grav_var, grav_var_h to state.grav_var, state.grav_var_h example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -412,7 +421,7 @@ def test_acoustic_dynamics_init_average_gravity() -> None: backend=example_backend, ) - grav_var_h_np = np.random.random((nx, ny, nz+1)) + grav_var_h_np = np.random.random((nx, ny, nz + 1)) expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) grav_var_h = Quantity( data=grav_var_h_np, @@ -428,7 +437,9 @@ def test_acoustic_dynamics_init_average_gravity() -> None: # grav_var_h should be unchanged by the stencil assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) - expected_grav_var_np = (expected_grav_var_h_np[:,:,:-1]+expected_grav_var_h_np[:,:,1:]) / 2 + expected_grav_var_np = ( + expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] + ) / 2 assert np.array_equal(grav_var.field[:], expected_grav_var_np) @@ -438,28 +449,33 @@ def test_acoustic_dynamics_call_average_gravity() -> None: ny = 12 nz = 79 n_halo = 3 - timestep = 225 # JK TODO: Is this right? + timestep = 225 # JK TODO: Is this right? # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, state = setup_acoustic_dynamics(nx-(2*n_halo)+1, ny-(2*n_halo)+1, n_halo) # nz is hard-coded to 79 + ac_dyn, state = setup_acoustic_dynamics( + nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo + ) # nz is hard-coded to 79 init_grav_var_np = copy.deepcopy(state.grav_var.field) init_grav_var_h_np = copy.deepcopy(state.grav_var_h.field) ac_dyn(state, timestep) - + # The state.grav_var_h should be unchanged by the stencil. assert np.array_equal(state.grav_var_h.field[:], init_grav_var_h_np) # Check that the state.grav_var values match expectation: - expected_grav_var_np = (init_grav_var_h_np[:,:,:-1]+init_grav_var_h_np[:,:,1:]) / 2 + expected_grav_var_np = ( + init_grav_var_h_np[:, :, :-1] + init_grav_var_h_np[:, :, 1:] + ) / 2 assert np.array_equal(state.grav_var.field[:], expected_grav_var_np) + """ E + where False = ( array([ -[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), +[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), array([[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]])) @@ -469,17 +485,22 @@ def test_acoustic_dynamics_call_average_gravity() -> None: ############################ fv_dynamics.py + def test_init_gravity() -> None: # Check that init_gravity sets 3d grav_var to the constant GRAV for all vals backend = "numpy" nx_tile, ny_tile, nz, n_halo = 6, 6, 2, 3 - layout = (1,1) + layout = (1, 1) partitioner = CubedSpherePartitioner(TilePartitioner(layout)) mpi_comm = NullComm(rank=0, total_ranks=6, fill_value=0.0) communicator = CubedSphereCommunicator(mpi_comm, partitioner) - compilation_config = CompilationConfig(backend=backend, rebuild=False, validate_args=True) + compilation_config = CompilationConfig( + backend=backend, rebuild=False, validate_args=True + ) dace_config = DaceConfig(communicator=communicator, backend=backend) - stencil_config = StencilConfig(compilation_config=compilation_config, dace_config=dace_config) + stencil_config = StencilConfig( + compilation_config=compilation_config, dace_config=dace_config + ) sizer = SubtileGridSizer.from_tile_params( nx_tile=nx_tile, ny_tile=ny_tile, @@ -487,15 +508,17 @@ def test_init_gravity() -> None: n_halo=n_halo, layout=layout, tile_partitioner=partitioner.tile, - tile_rank=communicator.tile.rank + tile_rank=communicator.tile.rank, + ) + grid_indexing = GridIndexing.from_sizer_and_communicator( + sizer=sizer, comm=communicator ) - grid_indexing = GridIndexing.from_sizer_and_communicator(sizer=sizer, comm=communicator) stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend) init_gravity_stencil = stencil_factory.from_origin_domain( set_value, origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(add=(0,0,1)), + domain=grid_indexing.domain_full(add=(0, 0, 1)), ) grav_var: Quantity = quantity_factory.zeros( [X_DIM, Y_DIM, Z_DIM], @@ -522,7 +545,7 @@ def test_adjust_gravity() -> None: n_halos = 3 example_dims = ["I", "J", "K"] - example_backend="numpy" + example_backend = "numpy" grav_var = Quantity( data=np.zeros((nx, ny, nz)), @@ -533,7 +556,7 @@ def test_adjust_gravity() -> None: ) grav_var_h = Quantity( - data=np.zeros((nx, ny, nz+1)), + data=np.zeros((nx, ny, nz + 1)), dims=example_dims, units="grav_var_h units", number_of_halo_points=n_halos, @@ -566,24 +589,29 @@ def test_adjust_gravity() -> None: # Check that grav_var and grav_var_h are set appropriately newrad = np.zeros((nx, ny)) expected_grav_var_np = np.zeros((nx, ny, nz)) - expected_grav_var_h_np = np.zeros((nx, ny, nz+1)) - + expected_grav_var_h_np = np.zeros((nx, ny, nz + 1)) + assert np.array_equal(grav_var_h.field[:].shape, expected_grav_var_h_np.shape) for j in range(ny): for i in range(nx): for k in range(nz, -1, -1): if k == nz: - newrad[i,j] = RADIUS + (phis.field[i,j]/GRAV) + newrad[i, j] = RADIUS + (phis.field[i, j] / GRAV) else: - newrad[i,j] = newrad[i,j] - delz.field[i,j,k] - expected_grav_var_h_np[i,j,k] = GRAV*((RADIUS**2)/(newrad[i,j]**2)) + newrad[i, j] = newrad[i, j] - delz.field[i, j, k] + expected_grav_var_h_np[i, j, k] = GRAV * ( + (RADIUS**2) / (newrad[i, j] ** 2) + ) if k < nz: - expected_grav_var_np[i,j,k] = 0.5*(expected_grav_var_h_np[i,j,k+1]+expected_grav_var_h_np[i,j,k]) + expected_grav_var_np[i, j, k] = 0.5 * ( + expected_grav_var_h_np[i, j, k + 1] + + expected_grav_var_h_np[i, j, k] + ) # JK TODO: is there some rtol/atol threshold? the np vs stencil calculations are close but not exact. assert np.allclose(grav_var_h.field[:], expected_grav_var_h_np) assert np.allclose(grav_var.field[:], expected_grav_var_np) - + # TODO JK NOTE to self --- checkout log_on_rank_0 for values that might be useful for test (possibly) From 1ae997d456c3a95958bb4096f611662caa0f6680 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 8 Jan 2026 13:52:09 -0500 Subject: [PATCH 5/7] Added in use of variable gravity for Riem_Solver3, Riem_Solver_C, and DryConvectiveAdjustment --- pyfv3/stencils/dyn_core.py | 2 ++ pyfv3/stencils/fv_subgridz.py | 8 ++++---- pyfv3/stencils/riem_solver3.py | 10 +++++++++- pyfv3/stencils/riem_solver_c.py | 16 +++++++++++++--- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/pyfv3/stencils/dyn_core.py b/pyfv3/stencils/dyn_core.py index 5a423999..92867aa0 100644 --- a/pyfv3/stencils/dyn_core.py +++ b/pyfv3/stencils/dyn_core.py @@ -894,6 +894,7 @@ def __call__( self._gz, self._pkc, state.omga, + state.grav_var, ) self._p_grad_c( @@ -981,6 +982,7 @@ def __call__( state.pk, state.peln, state.w, + state.grav_var, ) self._halo_updaters.zh.start() diff --git a/pyfv3/stencils/fv_subgridz.py b/pyfv3/stencils/fv_subgridz.py index 9090da16..6d6665b4 100644 --- a/pyfv3/stencils/fv_subgridz.py +++ b/pyfv3/stencils/fv_subgridz.py @@ -10,7 +10,6 @@ CP_VAP, CV_AIR, CV_VAP, - GRAV, RDGAS, X_DIM, Y_DIM, @@ -28,7 +27,6 @@ from gt4py.cartesian.gtscript import __INLINED # isort:skip RK = CP_AIR / RDGAS + 1.0 -G2 = 0.5 * GRAV T1_MIN = 160.0 T2_MIN = 165.0 T2_MAX = 315.0 @@ -95,6 +93,7 @@ def init( qo3mr: FloatField, qsgs_tke: FloatField, qcld: FloatField, + grav_var: FloatField, ): with computation(PARALLEL), interval(...): t0 = ta @@ -117,11 +116,11 @@ def init( cpm, cvm = standard_cm( cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel ) - gz = gzh[0, 0, 1] - G2 * delz + gz = gzh[0, 0, 1] - grav_var * delz tmp = tvol(gz, u0, v0, w0) static_energy = cpm * t0 + tmp total_energy = cvm * t0 + tmp - gzh = gzh[0, 0, 1] - GRAV * delz + gzh = gzh[0, 0, 1] - grav_var * delz @gtfunction @@ -903,6 +902,7 @@ def __call__( state.qo3mr, state.qsgs_tke, state.qcld, + state.grav_var, ) for n in range(self._m): diff --git a/pyfv3/stencils/riem_solver3.py b/pyfv3/stencils/riem_solver3.py index 789904bf..38af40d5 100644 --- a/pyfv3/stencils/riem_solver3.py +++ b/pyfv3/stencils/riem_solver3.py @@ -31,6 +31,7 @@ def precompute( ptop: Float, peln1: Float, ptk: Float, + grav_var: FloatField, ): """ Args: @@ -48,6 +49,10 @@ def precompute( dz (out): p_gas (out): pressure defined at vertical mid levels due to gas-phase only, excluding condensates (Pa) + ptop (in): + peln1 (in): + ptk (in): + grav_var (in): variable gravity """ with computation(PARALLEL), interval(...): delta_mass = delp @@ -73,7 +78,7 @@ def precompute( pk3 = exp(constants.KAPPA * log_p_interface) with computation(PARALLEL), interval(...): gamma = 1.0 / (1.0 - cappa) # gamma, cp/cv - delta_mass = delta_mass * constants.RGRAV + delta_mass = delta_mass / grav_var with computation(PARALLEL), interval(0, -1): p_gas = (p_interface_gas[0, 0, 1] - p_interface_gas) / ( log_p_interface_gas[0, 0, 1] - log_p_interface_gas @@ -232,6 +237,7 @@ def __call__( pk: FloatField, log_p_interface: FloatField, w: FloatFieldIJ, + grav_var: FloatField, ): """ Solves for the nonhydrostatic terms for vertical velocity (w) @@ -262,6 +268,7 @@ def __call__( log_p_interface (out): logarithm of interface pressure, only written if last_call=True w (inout): vertical velocity + grav_var (in): variable gravity """ # TODO: propagate variable renaming for these into stencils here and @@ -298,6 +305,7 @@ def __call__( ptop, peln1, ptk, + grav_var, ) self._sim1_solve( diff --git a/pyfv3/stencils/riem_solver_c.py b/pyfv3/stencils/riem_solver_c.py index b2170a56..278935e3 100644 --- a/pyfv3/stencils/riem_solver_c.py +++ b/pyfv3/stencils/riem_solver_c.py @@ -22,6 +22,7 @@ def precompute( gm: FloatField, pm: FloatField, ptop: Float, + grav_var: FloatField, ): """ Args: @@ -40,6 +41,8 @@ def precompute( pm (out): hydrostatic cell mean pressure, derivation in documentation (Chapter 4? 7?) TODO: identify chapter reference, will be sent by Lucas + ptop (in): + grav_var (in): variable gravity """ with computation(PARALLEL), interval(...): dm = delpc @@ -57,7 +60,7 @@ def precompute( dz = gz[0, 0, 1] - gz with computation(PARALLEL), interval(...): gm = 1.0 / (1.0 - cappa) - dm /= constants.GRAV + dm /= grav_var with computation(PARALLEL), interval(0, -1): # (1) From \partial p*/\partial z = -\rho g, we can separate and integrate # over a layer to get @@ -87,6 +90,7 @@ def finalize( pef: FloatField, gz: FloatField, ptop: Float, + grav_var: FloatField, ): """ Enforce vertical boundary conditions. @@ -101,6 +105,8 @@ def finalize( dz (in): pef (out): gz (out): + ptop (in): + grav_var (in): variable gravity """ with computation(PARALLEL): with interval(0, 1): @@ -111,7 +117,7 @@ def finalize( with interval(-1, None): gz = hs with interval(0, -1): - gz = gz[0, 0, 1] - dz * constants.GRAV + gz = gz[0, 0, 1] - dz * grav_var class NonhydrostaticVerticalSolverCGrid: @@ -201,6 +207,7 @@ def __call__( gz: FloatField, pef: FloatField, w3: FloatField, + grav_var: FloatField, ): """ Solves for the nonhydrostatic terms for vertical velocity (w) @@ -219,6 +226,7 @@ def __call__( gz (inout): geopotential height pef (out): full hydrostatic pressure w3 (in): vertical velocity + grav_var (in): variable gravity """ # TODO: integrate these notes into comments/code, double-check: @@ -266,4 +274,6 @@ def __call__( ws, ) # pe is nonhydrostatic perturbation pressure defined on interfaces - self._finalize_stencil(self._pe, self._pem, hs, self._dz, pef, gz, ptop) + self._finalize_stencil( + self._pe, self._pem, hs, self._dz, pef, gz, ptop, grav_var + ) From aba07739d8dfa6db979bd158d9718c82b00fd44b Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 13 Jan 2026 15:54:48 -0500 Subject: [PATCH 6/7] Centralized wam methods into one module, introduced potential wam stencil class --- pyfv3/stencils/dyn_core.py | 40 ++++------ pyfv3/stencils/fv_dynamics.py | 7 +- pyfv3/stencils/gravity.py | 45 ----------- pyfv3/stencils/rdg_adjust.py | 17 ----- pyfv3/stencils/wam.py | 79 ++++++++++++++++++++ tests/main/test_wam.py | 137 +++++++++++++++++----------------- 6 files changed, 165 insertions(+), 160 deletions(-) delete mode 100644 pyfv3/stencils/gravity.py delete mode 100644 pyfv3/stencils/rdg_adjust.py create mode 100644 pyfv3/stencils/wam.py diff --git a/pyfv3/stencils/dyn_core.py b/pyfv3/stencils/dyn_core.py index 07e565c9..4ace486d 100644 --- a/pyfv3/stencils/dyn_core.py +++ b/pyfv3/stencils/dyn_core.py @@ -6,12 +6,12 @@ import ndsl.constants as constants import ndsl.stencils.basic_operations as basic import pyfv3.stencils.d_sw as d_sw -import pyfv3.stencils.gravity as gravity +import pyfv3.stencils.wam as wam import pyfv3.stencils.nh_p_grad as nh_p_grad import pyfv3.stencils.pe_halo as pe_halo import pyfv3.stencils.ray_fast as ray_fast import pyfv3.stencils.temperature_adjust as temperature_adjust -import pyfv3.stencils.rdg_adjust as rdg_adjust +# import pyfv3.stencils.rdg_adjust as rdg_adjust import pyfv3.stencils.updatedzc as updatedzc import pyfv3.stencils.updatedzd as updatedzd from ndsl import ( @@ -92,17 +92,6 @@ def zero_data( heat_source = 0.0 diss_estd = 0.0 - -def average_gravity(grav_var: FloatField, grav_var_h: FloatField): - """ - Args: - grav_var (out): gravity field - grav_var_h (in): gravity value at height - """ - with computation(FORWARD), interval(...): - grav_var[0, 0, 0] = 0.5*(grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) - - def gz_from_surface_height_and_thicknesses( zs: FloatFieldIJ, delz: FloatField, gz: FloatField ): @@ -667,17 +656,17 @@ def __init__( grav_var_h=state.grav_var_h, ) - self._average_gravity = stencil_factory.from_origin_domain( - gravity.average_gravity_stencil_defn, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) + # self._average_gravity = stencil_factory.from_origin_domain( + # wam.average_gravity_stencil_defn, + # origin=grid_indexing.origin_full(), + # domain=grid_indexing.domain_full(), + # ) - self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain( - rdg_adjust.neg_rdgas_div_gravity, - origin=grid_indexing.origin_full(), - domain=grid_indexing.domain_full(), - ) + # self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain( + # wam.neg_rdgas_div_gravity, + # origin=grid_indexing.origin_full(), + # domain=grid_indexing.domain_full(), + # ) # See divergence_damping.py, _get_da_min for explanation of this function @dace_inhibitor @@ -817,8 +806,9 @@ def __call__( if it == 0: self._halo_updaters.delp__pt.wait() self._halo_updaters.grav_var_h.update() - self._average_gravity(state.grav_var, state.grav_var_h) - self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var) + # ALREADY HAPPENING IN DYNAMICAL CORE CALL + # self._average_gravity(state.grav_var, state.grav_var_h) + # self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var) if it == n_split - 1 and end_step: if self.config.use_old_omega: diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index d41539ac..355a2ffa 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -4,7 +4,7 @@ from dace.frontend.python.interface import nounroll as dace_no_unroll import ndsl.dsl.gt4py_utils as utils -import pyfv3.stencils.gravity as gravity +import pyfv3.stencils.wam as wam import pyfv3.stencils.moist_cv as moist_cv from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater from ndsl.checkpointer import NullCheckpointer @@ -26,7 +26,6 @@ from pyfv3.stencils.dyn_core import AcousticDynamics from pyfv3.stencils.neg_adj3 import AdjustNegativeTracerMixingRatio from pyfv3.stencils.remapping import LagrangianToEulerian -import pyfv3.stencils.rdg_adjust as rdg_adjust def pt_to_potential_density_pt( @@ -270,12 +269,12 @@ def __init__( domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self._adjust_gravity = stencil_factory.from_origin_domain( - gravity.adjust_gravity, + wam.adjust_gravity, origin=grid_indexing.origin_full(), domain=grid_indexing.domain_full(add=(0, 0, 1)), ) self._adjust_rdg = stencil_factory.from_origin_domain( - rdg_adjust.neg_rdgas_div_gravity, + wam.neg_rdgas_div_gravity, origin=grid_indexing.origin_full(), domain=grid_indexing.domain_full(), ) diff --git a/pyfv3/stencils/gravity.py b/pyfv3/stencils/gravity.py deleted file mode 100644 index a42e71e9..00000000 --- a/pyfv3/stencils/gravity.py +++ /dev/null @@ -1,45 +0,0 @@ -from ndsl.constants import GRAV, RADIUS -from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation -from ndsl.dsl.gt4py import function as gtfunction -from ndsl.dsl.gt4py import interval -from ndsl.dsl.typing import FloatField, FloatFieldIJ - - -@gtfunction -def average_gravity(grav_var: FloatField, grav_var_h: FloatField): - grav_var = 0.5 * (grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) - return grav_var - - -def average_gravity_stencil_defn(grav_var: FloatField, grav_var_h: FloatField): - """ - Args: - grav_var (out): gravity field - grav_var_h (in): gravity value at interfaces - """ - with computation(FORWARD), interval(...): - grav_var = average_gravity(grav_var, grav_var_h) - - -def adjust_gravity( - grav_var: FloatField, - grav_var_h: FloatField, - phis: FloatFieldIJ, - delz: FloatField, -): - """ - Args: - grav_var (out): gravity field - grav_var_h (out): gravity value at interfaces - phis (in): geopotential - delz (in): change in vertical height - """ - with computation(FORWARD), interval(-1, None): - newrad = RADIUS + (phis / GRAV) - grav_var_h = GRAV * (RADIUS**2) / newrad**2 - - with computation(BACKWARD), interval(0, -1): - newrad = RADIUS + (phis / GRAV) - newrad = newrad - delz - grav_var_h = GRAV * (RADIUS**2) / newrad**2 - grav_var = average_gravity(grav_var, grav_var_h) diff --git a/pyfv3/stencils/rdg_adjust.py b/pyfv3/stencils/rdg_adjust.py deleted file mode 100644 index d891fe1f..00000000 --- a/pyfv3/stencils/rdg_adjust.py +++ /dev/null @@ -1,17 +0,0 @@ -import ndsl.constants as constants -from ndsl.dsl.gt4py import FORWARD, computation, interval -from ndsl.dsl.typing import FloatField - - -def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): - """ - # JK TODO: Is there a better name than this? - Adjust rdg to be the negative RDGAS divided by the variable gravity - for Whole Atmosphere Modeling - - Args: - rdg (out): negative radiative gas divided by variable gravity - grav_var (in): variable gravity - """ - with computation(FORWARD), interval(...): - rdg = - constants.RDGAS / grav_var diff --git a/pyfv3/stencils/wam.py b/pyfv3/stencils/wam.py new file mode 100644 index 00000000..52ef7c78 --- /dev/null +++ b/pyfv3/stencils/wam.py @@ -0,0 +1,79 @@ +from ndsl.constants import GRAV, RADIUS, RDGAS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM +from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation +from ndsl.dsl.gt4py import function as gtfunction +from ndsl.dsl.gt4py import interval +from ndsl.dsl.typing import FloatField, FloatFieldIJ +from ndsl import StencilFactory + + +@gtfunction +def average_gravity(grav_var: FloatField, grav_var_h: FloatField): + grav_var = 0.5 * (grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1]) + return grav_var + + +# May not need this stencil at all +def average_gravity_stencil_defn(grav_var: FloatField, grav_var_h: FloatField): + """ + Args: + grav_var (out): gravity field + grav_var_h (in): gravity value at interfaces + """ + with computation(FORWARD), interval(...): + grav_var = average_gravity(grav_var, grav_var_h) + + +def adjust_gravity( + grav_var: FloatField, + grav_var_h: FloatField, + phis: FloatFieldIJ, + delz: FloatField, +): + """ + Args: + grav_var (out): gravity field + grav_var_h (out): gravity value at interfaces + phis (in): geopotential + delz (in): change in vertical height + """ + with computation(FORWARD), interval(-1, None): + newrad = RADIUS + (phis / GRAV) + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + + with computation(BACKWARD), interval(0, -1): + newrad = RADIUS + (phis / GRAV) + newrad = newrad - delz + grav_var_h = GRAV * (RADIUS**2) / newrad**2 + grav_var = average_gravity(grav_var, grav_var_h) + +def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): + """ + # JK TODO: Is there a better name than this? + Adjust rdg to be the negative RDGAS divided by the variable gravity + for Whole Atmosphere Modeling + + Args: + rdg (out): negative radiative gas divided by variable gravity + grav_var (in): variable gravity + """ + with computation(FORWARD), interval(...): + rdg = - RDGAS / grav_var + +class WholeAtmos: + def __init__(self, stencil_factory: StencilFactory): + self.constructed_average_gravity_stencil = stencil_factory.from_dims_halo( + func=average_gravity_stencil_defn, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.constructed_adjust_gravity_stencil = stencil_factory.from_dims_halo( + func=adjust_gravity, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.constructed_neg_rdgas_div_gravity_stencil = stencil_factory.from_dims_halo( + func=neg_rdgas_div_gravity, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, grav_var: FloatField, grav_var_h: FloatField, rdg: FloatField, phis: FloatFieldIJ, delz: FloatField): + self.constructed_adjust_gravity_stencil(grav_var, grav_var_h, phis, delz) + self.constructed_neg_rdgas_div_gravity_stencil(rdg, grav_var) \ No newline at end of file diff --git a/tests/main/test_wam.py b/tests/main/test_wam.py index b4207303..75ab447e 100644 --- a/tests/main/test_wam.py +++ b/tests/main/test_wam.py @@ -29,9 +29,8 @@ from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig from pyfv3.initialization.analytic_init import AnalyticCase from pyfv3.stencils.dyn_core import AcousticDynamics -from pyfv3.stencils.rdg_adjust import neg_rdgas_div_gravity -from pyfv3.stencils.fv_dynamics import adjust_gravity, init_gravity, init_gravity_h -from pyfv3.stencils.dyn_core import average_gravity, compute_geopotential +from pyfv3.stencils.wam import adjust_gravity, average_gravity_stencil_defn, neg_rdgas_div_gravity +from pyfv3.stencils.dyn_core import compute_geopotential # JK NOTE TODO: Just sticking things in here for now, @@ -397,95 +396,95 @@ def setup_acoustic_dynamics(npx, npy, n_halo) -> Tuple[AcousticDynamics, DycoreS return dycore.acoustic_dynamics, state -def test_acoustic_dynamics_init_average_gravity() -> None: - # Check that average gravity is called/used in AcousticDynamics initialization +# def test_acoustic_dynamics_init_average_gravity() -> None: +# # Check that average gravity is called/used in AcousticDynamics initialization - nx = 12 - ny = 12 - nz = 79 - n_halo = 3 +# nx = 12 +# ny = 12 +# nz = 79 +# n_halo = 3 - # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, _ = setup_acoustic_dynamics( - nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo - ) # nz is hard-coded to 79 +# # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 +# ac_dyn, _ = setup_acoustic_dynamics( +# nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo +# ) # nz is hard-coded to 79 - # JK TODO: switch from example grav_var, grav_var_h to state.grav_var, state.grav_var_h +# # JK TODO: switch from example grav_var, grav_var_h to state.grav_var, state.grav_var_h - example_dims = ["I", "J", "K"] - example_backend = "numpy" +# example_dims = ["I", "J", "K"] +# example_backend = "numpy" - grav_var = Quantity( - data=np.zeros((nx, ny, nz)), - dims=example_dims, - units="grav_var units", - number_of_halo_points=n_halo, - backend=example_backend, - ) +# grav_var = Quantity( +# data=np.zeros((nx, ny, nz)), +# dims=example_dims, +# units="grav_var units", +# number_of_halo_points=n_halo, +# backend=example_backend, +# ) - grav_var_h_np = np.random.random((nx, ny, nz + 1)) - expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) - grav_var_h = Quantity( - data=grav_var_h_np, - dims=example_dims, - units="grav_var_h units", - number_of_halo_points=n_halo, - backend=example_backend, - ) +# grav_var_h_np = np.random.random((nx, ny, nz + 1)) +# expected_grav_var_h_np = copy.deepcopy(grav_var_h_np) +# grav_var_h = Quantity( +# data=grav_var_h_np, +# dims=example_dims, +# units="grav_var_h units", +# number_of_halo_points=n_halo, +# backend=example_backend, +# ) - # Call ad_dyn._average_gravity. This is what we're testing. - ac_dyn._average_gravity(grav_var, grav_var_h) +# # Call ad_dyn._average_gravity. This is what we're testing. +# ac_dyn._average_gravity(grav_var, grav_var_h) - # grav_var_h should be unchanged by the stencil - assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) +# # grav_var_h should be unchanged by the stencil +# assert np.array_equal(grav_var_h.field[:], expected_grav_var_h_np) - expected_grav_var_np = ( - expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] - ) / 2 - assert np.array_equal(grav_var.field[:], expected_grav_var_np) +# expected_grav_var_np = ( +# expected_grav_var_h_np[:, :, :-1] + expected_grav_var_h_np[:, :, 1:] +# ) / 2 +# assert np.array_equal(grav_var.field[:], expected_grav_var_np) -def test_acoustic_dynamics_call_average_gravity() -> None: - # Check that average gravity is called/used in AcousticDynamics call - nx = 12 - ny = 12 - nz = 79 - n_halo = 3 - timestep = 225 # JK TODO: Is this right? +# def test_acoustic_dynamics_call_average_gravity() -> None: +# # Check that average gravity is called/used in AcousticDynamics call +# nx = 12 +# ny = 12 +# nz = 79 +# n_halo = 3 +# timestep = 225 # JK TODO: Is this right? - # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 - ac_dyn, state = setup_acoustic_dynamics( - nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo - ) # nz is hard-coded to 79 +# # JK TODO: Why does the config need npx = nx-(2*n_halo)+1 +# ac_dyn, state = setup_acoustic_dynamics( +# nx - (2 * n_halo) + 1, ny - (2 * n_halo) + 1, n_halo +# ) # nz is hard-coded to 79 - init_grav_var_np = copy.deepcopy(state.grav_var.field) - init_grav_var_h_np = copy.deepcopy(state.grav_var_h.field) +# init_grav_var_np = copy.deepcopy(state.grav_var.field) +# init_grav_var_h_np = copy.deepcopy(state.grav_var_h.field) - ac_dyn(state, timestep) +# ac_dyn(state, timestep) - # The state.grav_var_h should be unchanged by the stencil. - assert np.array_equal(state.grav_var_h.field[:], init_grav_var_h_np) +# # The state.grav_var_h should be unchanged by the stencil. +# assert np.array_equal(state.grav_var_h.field[:], init_grav_var_h_np) - # Check that the state.grav_var values match expectation: - expected_grav_var_np = ( - init_grav_var_h_np[:, :, :-1] + init_grav_var_h_np[:, :, 1:] - ) / 2 - assert np.array_equal(state.grav_var.field[:], expected_grav_var_np) +# # Check that the state.grav_var values match expectation: +# expected_grav_var_np = ( +# init_grav_var_h_np[:, :, :-1] + init_grav_var_h_np[:, :, 1:] +# ) / 2 +# assert np.array_equal(state.grav_var.field[:], expected_grav_var_np) -""" -E + where False = ( +# """ +# E + where False = ( -array([ -[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), +# array([ +# [[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]]), -array([[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]])) +# array([[[0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n ...\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.],\n [0., 0., 0., ..., 0., 0., 0.]]])) -E + where = np.array_equal +# E + where = np.array_equal -""" +# """ -############################ fv_dynamics.py +# ############################ fv_dynamics.py def test_init_gravity() -> None: From 4db8304f68cfc7f89824a65ed4c613396a110624 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Tue, 13 Jan 2026 15:56:18 -0500 Subject: [PATCH 7/7] Lint --- pyfv3/stencils/dyn_core.py | 6 ++++-- pyfv3/stencils/fv_dynamics.py | 2 +- pyfv3/stencils/wam.py | 17 +++++++++++++---- tests/main/test_wam.py | 9 ++++++--- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/pyfv3/stencils/dyn_core.py b/pyfv3/stencils/dyn_core.py index 4ace486d..4d5453ca 100644 --- a/pyfv3/stencils/dyn_core.py +++ b/pyfv3/stencils/dyn_core.py @@ -6,14 +6,15 @@ import ndsl.constants as constants import ndsl.stencils.basic_operations as basic import pyfv3.stencils.d_sw as d_sw -import pyfv3.stencils.wam as wam import pyfv3.stencils.nh_p_grad as nh_p_grad import pyfv3.stencils.pe_halo as pe_halo import pyfv3.stencils.ray_fast as ray_fast import pyfv3.stencils.temperature_adjust as temperature_adjust + # import pyfv3.stencils.rdg_adjust as rdg_adjust import pyfv3.stencils.updatedzc as updatedzc import pyfv3.stencils.updatedzd as updatedzd +import pyfv3.stencils.wam as wam from ndsl import ( GridIndexing, Quantity, @@ -92,6 +93,7 @@ def zero_data( heat_source = 0.0 diss_estd = 0.0 + def gz_from_surface_height_and_thicknesses( zs: FloatFieldIJ, delz: FloatField, gz: FloatField ): @@ -1052,5 +1054,5 @@ def __call__( self._heat_source, state.pt, delt_time_factor, - state.rdg_var + state.rdg_var, ) diff --git a/pyfv3/stencils/fv_dynamics.py b/pyfv3/stencils/fv_dynamics.py index 355a2ffa..0f7db7fe 100644 --- a/pyfv3/stencils/fv_dynamics.py +++ b/pyfv3/stencils/fv_dynamics.py @@ -4,8 +4,8 @@ from dace.frontend.python.interface import nounroll as dace_no_unroll import ndsl.dsl.gt4py_utils as utils -import pyfv3.stencils.wam as wam import pyfv3.stencils.moist_cv as moist_cv +import pyfv3.stencils.wam as wam from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater from ndsl.checkpointer import NullCheckpointer from ndsl.comm.mpi import MPI diff --git a/pyfv3/stencils/wam.py b/pyfv3/stencils/wam.py index 52ef7c78..bfcefe57 100644 --- a/pyfv3/stencils/wam.py +++ b/pyfv3/stencils/wam.py @@ -1,9 +1,9 @@ +from ndsl import StencilFactory from ndsl.constants import GRAV, RADIUS, RDGAS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM from ndsl.dsl.gt4py import BACKWARD, FORWARD, computation from ndsl.dsl.gt4py import function as gtfunction from ndsl.dsl.gt4py import interval from ndsl.dsl.typing import FloatField, FloatFieldIJ -from ndsl import StencilFactory @gtfunction @@ -46,6 +46,7 @@ def adjust_gravity( grav_var_h = GRAV * (RADIUS**2) / newrad**2 grav_var = average_gravity(grav_var, grav_var_h) + def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): """ # JK TODO: Is there a better name than this? @@ -57,7 +58,8 @@ def neg_rdgas_div_gravity(rdg: FloatField, grav_var: FloatField): grav_var (in): variable gravity """ with computation(FORWARD), interval(...): - rdg = - RDGAS / grav_var + rdg = -RDGAS / grav_var + class WholeAtmos: def __init__(self, stencil_factory: StencilFactory): @@ -74,6 +76,13 @@ def __init__(self, stencil_factory: StencilFactory): compute_dims=[X_DIM, Y_DIM, Z_DIM], ) - def __call__(self, grav_var: FloatField, grav_var_h: FloatField, rdg: FloatField, phis: FloatFieldIJ, delz: FloatField): + def __call__( + self, + grav_var: FloatField, + grav_var_h: FloatField, + rdg: FloatField, + phis: FloatFieldIJ, + delz: FloatField, + ): self.constructed_adjust_gravity_stencil(grav_var, grav_var_h, phis, delz) - self.constructed_neg_rdgas_div_gravity_stencil(rdg, grav_var) \ No newline at end of file + self.constructed_neg_rdgas_div_gravity_stencil(rdg, grav_var) diff --git a/tests/main/test_wam.py b/tests/main/test_wam.py index 75ab447e..78d4decb 100644 --- a/tests/main/test_wam.py +++ b/tests/main/test_wam.py @@ -28,9 +28,12 @@ from ndsl.stencils.basic_operations import set_value from pyfv3 import DycoreState, DynamicalCore, DynamicalCoreConfig from pyfv3.initialization.analytic_init import AnalyticCase -from pyfv3.stencils.dyn_core import AcousticDynamics -from pyfv3.stencils.wam import adjust_gravity, average_gravity_stencil_defn, neg_rdgas_div_gravity -from pyfv3.stencils.dyn_core import compute_geopotential +from pyfv3.stencils.dyn_core import AcousticDynamics, compute_geopotential +from pyfv3.stencils.wam import ( + adjust_gravity, + average_gravity_stencil_defn, + neg_rdgas_div_gravity, +) # JK NOTE TODO: Just sticking things in here for now,