Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 6 additions & 26 deletions pyfv3/stencils/fv_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ndsl.grid import DampingCoefficients, GridData
from ndsl.logging import ndsl_log
from ndsl.performance import NullTimer, Timer
from ndsl.stencils.basic_operations import copy_defn
from ndsl.stencils.basic_operations import copy_defn, set_value
from ndsl.stencils.c2l_ord import CubedToLatLon
from ndsl.typing import Checkpointer, Communicator
from pyfv3._config import DynamicalCoreConfig
Expand Down Expand Up @@ -75,22 +75,6 @@ def fvdyn_temporaries(
tmps[name] = quantity
return tmps

def init_gravity(grav_var: FloatField):
"""
Args:
grav_var (out): gravity field
"""
with computation(PARALLEL), interval(...):
grav_var = GRAV

def init_gravity_h(grav_var_h: FloatField):
"""
Args:
grav_var_h (out): gravity field
"""
with computation(PARALLEL), interval(...):
grav_var_h = GRAV

def adjust_gravity(
grav_var: FloatField,
grav_var_h: FloatField,
Expand Down Expand Up @@ -302,14 +286,9 @@ def __init__(
domain=grid_indexing.domain_full(),
)
self._init_gravity = stencil_factory.from_origin_domain(
init_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
)
self._init_gravity_h = stencil_factory.from_origin_domain(
init_gravity_h,
set_value,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
domain=grid_indexing.domain_full(add=(0,0,1)),
)
self._adjust_gravity = stencil_factory.from_origin_domain(
adjust_gravity,
Expand Down Expand Up @@ -538,8 +517,9 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool):
self._dp_initial,
)

self._init_gravity(state.grav_var)
self._init_gravity_h(state.grav_var_h)
# self._init_gravity(state.grav_var, state.grav_var_h)
self._init_gravity(state.grav_var, GRAV)
self._init_gravity(state.grav_var_h, GRAV)
Comment on lines +506 to +507
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like how the single _init_gravity can be used for both grav_var and grav_var_h now.


if self.config.enable_wam:
self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz)
Expand Down
26 changes: 17 additions & 9 deletions tests/main/test_wam.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
)
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
from ndsl.dsl.typing import Float, FloatField
from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.constants import GRAV, RDGAS, RADIUS, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM
from ndsl.dsl.gt4py import stencil
from ndsl.stencils.basic_operations import set_value
from pyfv3 import DynamicalCore, DynamicalCoreConfig, DycoreState
from pyfv3.initialization import init_utils
from pyfv3.initialization.analytic_init import AnalyticCase
from pyfv3.stencils.dyn_core import AcousticDynamics
from pyfv3.stencils.fv_dynamics import adjust_gravity, init_gravity, init_gravity_h
from pyfv3.stencils.fv_dynamics import adjust_gravity
from pyfv3.stencils.dyn_core import average_gravity, compute_geopotential

# use numpy for now until I figure out how to use FloatField
Expand Down Expand Up @@ -172,8 +173,8 @@ def test_neg_rdgas_div_gravity() -> None:
backend=example_backend,
)

init_gravity_stencil = stencil(backend=example_backend, definition=init_gravity)
init_gravity_stencil(grav_var)
init_gravity_stencil = stencil(backend=example_backend, definition=set_value)
init_gravity_stencil(grav_var, GRAV)
expected_grav_var_np = copy.deepcopy(grav_var.field[:])

neg_rdgas_div_gravity_stencil = stencil(backend=example_backend, definition=neg_rdgas_div_gravity)
Expand Down Expand Up @@ -491,18 +492,25 @@ def test_init_gravity() -> None:
grid_indexing = GridIndexing.from_sizer_and_communicator(sizer=sizer, comm=communicator)
stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing)
quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend)
init_gravity_stencil = stencil_factory.from_dims_halo(
init_gravity,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
compute_halos=(n_halo, n_halo),
init_gravity_stencil = stencil_factory.from_origin_domain(
set_value,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(add=(0,0,1)),
)
grav_var: Quantity = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_DIM],
units="test",
dtype=Float,
)
init_gravity_stencil(grav_var)
grav_var_h: Quantity = quantity_factory.zeros(
[X_DIM, Y_DIM, Z_INTERFACE_DIM],
units="test",
dtype=Float,
)
init_gravity_stencil(grav_var, GRAV)
init_gravity_stencil(grav_var_h, GRAV)
assert np.all(grav_var.field == GRAV)
assert np.all(grav_var_h.field == GRAV)
# JK TODO: There's so much setup... Find a simpler way to set of stencil and quantity?


Expand Down