Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/notebook/test_functionality.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@
")\n",
"\n",
"# useful for easily allocating distributed data storages (fields)\n",
"quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend)\n",
"quantity_factory = QuantityFactory(sizer=sizer, backend=backend)\n",
"\n",
"compilation_config = CompilationConfig(backend=backend, communicator=cs_communicator)\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions pyfv3/dycore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class DycoreState:
rdg_var: Quantity = field(
metadata={
"name": "gas constant for dry air over variable gravity (RDGAS / grav_var) for Whole Atmosphere calculations",
"units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units?
"units": "(J/kg/deg) / (m s^-2)", # JK TODO: What are the units?
"dims": [X_DIM, Y_DIM, Z_DIM],
"intent": "inout",
}
Expand Down Expand Up @@ -368,7 +368,7 @@ def init_from_numpy_arrays(
_field.metadata["units"],
origin=sizer.get_origin(dims),
extent=sizer.get_extent(dims),
gt4py_backend=backend,
backend=backend,
)
state = cls(**dict_state)
return state
Expand Down
2 changes: 1 addition & 1 deletion pyfv3/initialization/test_cases/initialize_baroclinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def init_baroclinic_state(
state = DycoreState.init_from_numpy_arrays(
numpy_state.__dict__,
sizer=quantity_factory.sizer,
backend=sample_quantity.metadata.gt4py_backend,
backend=sample_quantity.metadata.backend,
)

comm.halo_update(state.phis, n_points=NHALO)
Expand Down
2 changes: 1 addition & 1 deletion pyfv3/initialization/test_cases/initialize_rossby.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def init_rossby_state(
state = DycoreState.init_from_numpy_arrays(
numpy_state.__dict__,
sizer=quantity_factory.sizer,
backend=sample_quantity.metadata.gt4py_backend,
backend=sample_quantity.metadata.backend,
)

comm.halo_update(state.phis, n_points=NHALO)
Expand Down
2 changes: 1 addition & 1 deletion pyfv3/initialization/test_cases/initialize_tc.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def init_tc_state(
state = DycoreState.init_from_numpy_arrays(
numpy_state.__dict__,
sizer=quantity_factory.sizer,
backend=sample_quantity.metadata.gt4py_backend,
backend=sample_quantity.metadata.backend,
)

return state
2 changes: 1 addition & 1 deletion pyfv3/stencils/delnflux.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def calc_damp(damp_c: Quantity, da_min: Float, nord: Quantity) -> Quantity:
units="unknown",
origin=damp_c.origin,
extent=damp_c.extent,
gt4py_backend=damp_c.gt4py_backend,
backend=damp_c.backend,
)


Expand Down
44 changes: 20 additions & 24 deletions pyfv3/stencils/dyn_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import pyfv3.stencils.pe_halo as pe_halo
import pyfv3.stencils.ray_fast as ray_fast
import pyfv3.stencils.temperature_adjust as temperature_adjust
import pyfv3.stencils.rdg_adjust as rdg_adjust

# import pyfv3.stencils.rdg_adjust as rdg_adjust
import pyfv3.stencils.updatedzc as updatedzc
import pyfv3.stencils.updatedzd as updatedzd
import pyfv3.stencils.wam as wam
from ndsl import (
GridIndexing,
Quantity,
Expand Down Expand Up @@ -92,16 +94,6 @@ def zero_data(
diss_estd = 0.0


def average_gravity(grav_var: FloatField, grav_var_h: FloatField):
"""
Args:
grav_var (out): gravity field
grav_var_h (in): gravity value at height
"""
with computation(FORWARD), interval(...):
grav_var[0, 0, 0] = 0.5*(grav_var_h[0, 0, 0] + grav_var_h[0, 0, 1])


def gz_from_surface_height_and_thicknesses(
zs: FloatFieldIJ, delz: FloatField, gz: FloatField
):
Expand Down Expand Up @@ -666,16 +658,17 @@ def __init__(
grav_var_h=state.grav_var_h,
)

self._average_gravity = stencil_factory.from_origin_domain(
average_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
)
self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain(
rdg_adjust.neg_rdgas_div_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
)
# self._average_gravity = stencil_factory.from_origin_domain(
# wam.average_gravity_stencil_defn,
# origin=grid_indexing.origin_full(),
# domain=grid_indexing.domain_full(),
# )

# self._neg_rdgas_div_gravity = stencil_factory.from_origin_domain(
# wam.neg_rdgas_div_gravity,
# origin=grid_indexing.origin_full(),
# domain=grid_indexing.domain_full(),
# )

# See divergence_damping.py, _get_da_min for explanation of this function
@dace_inhibitor
Expand Down Expand Up @@ -815,8 +808,9 @@ def __call__(
if it == 0:
self._halo_updaters.delp__pt.wait()
self._halo_updaters.grav_var_h.update()
self._average_gravity(state.grav_var, state.grav_var_h)
self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var)
# ALREADY HAPPENING IN DYNAMICAL CORE CALL
# self._average_gravity(state.grav_var, state.grav_var_h)
# self._neg_rdgas_div_gravity(state.rdg_var, state.grav_var)

if it == n_split - 1 and end_step:
if self.config.use_old_omega:
Expand Down Expand Up @@ -892,6 +886,7 @@ def __call__(
self._gz,
self._pkc,
state.omga,
state.grav_var,
)

self._p_grad_c(
Expand Down Expand Up @@ -979,6 +974,7 @@ def __call__(
state.pk,
state.peln,
state.w,
state.grav_var,
)

self._halo_updaters.zh.start()
Expand Down Expand Up @@ -1058,5 +1054,5 @@ def __call__(
self._heat_source,
state.pt,
delt_time_factor,
state.rdg_var
state.rdg_var,
)
82 changes: 23 additions & 59 deletions pyfv3/stencils/fv_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@

import ndsl.dsl.gt4py_utils as utils
import pyfv3.stencils.moist_cv as moist_cv
import pyfv3.stencils.wam as wam
from ndsl import Quantity, QuantityFactory, StencilFactory, WrappedHaloUpdater
from ndsl.checkpointer import NullCheckpointer
from ndsl.comm.mpi import MPI
from ndsl.constants import KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR, GRAV, RADIUS
from ndsl.constants import GRAV, KAPPA, NQ, X_DIM, Y_DIM, Z_DIM, Z_INTERFACE_DIM, ZVIR
from ndsl.dsl.dace.orchestration import dace_inhibitor, orchestrate
from ndsl.dsl.gt4py import FORWARD, BACKWARD, PARALLEL, computation, interval
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ
from ndsl.dsl.gt4py import PARALLEL, computation, interval
from ndsl.dsl.typing import Float, FloatField
from ndsl.grid import DampingCoefficients, GridData
from ndsl.logging import ndsl_log
from ndsl.performance import NullTimer, Timer
from ndsl.stencils.basic_operations import copy_defn
from ndsl.stencils.basic_operations import copy_defn, set_value
from ndsl.stencils.c2l_ord import CubedToLatLon
from ndsl.typing import Checkpointer, Communicator
from pyfv3._config import DynamicalCoreConfig
Expand All @@ -25,7 +26,6 @@
from pyfv3.stencils.dyn_core import AcousticDynamics
from pyfv3.stencils.neg_adj3 import AdjustNegativeTracerMixingRatio
from pyfv3.stencils.remapping import LagrangianToEulerian
import pyfv3.stencils.rdg_adjust as rdg_adjust


def pt_to_potential_density_pt(
Expand Down Expand Up @@ -76,45 +76,6 @@ def fvdyn_temporaries(
tmps[name] = quantity
return tmps

def init_gravity(grav_var: FloatField):
"""
Args:
grav_var (out): gravity field
"""
with computation(PARALLEL), interval(...):
grav_var = GRAV

def init_gravity_h(grav_var_h: FloatField):
"""
Args:
grav_var_h (out): gravity field
"""
with computation(PARALLEL), interval(...):
grav_var_h = GRAV

def adjust_gravity(
grav_var: FloatField,
grav_var_h: FloatField,
phis: FloatFieldIJ,
delz: FloatField,
):
"""
Args:
grav_var (out): gravity field
grav_var_h (out): height based gravity
phis (out):
delz (out):
"""
with computation(FORWARD), interval(-1,None):
newrad = RADIUS + (phis/GRAV)
grav_var_h = GRAV*(RADIUS**2)/newrad**2

with computation(BACKWARD), interval(0,-1):
newrad = RADIUS + (phis/GRAV)
newrad = newrad - delz
grav_var_h = GRAV*(RADIUS**2)/newrad**2
grav_var = 0.5*(grav_var_h[0, 0, 1] + grav_var_h[0, 0, 0])


@dace_inhibitor
def log_on_rank_0(msg: str):
Expand Down Expand Up @@ -303,22 +264,17 @@ def __init__(
domain=grid_indexing.domain_full(),
)
self._init_gravity = stencil_factory.from_origin_domain(
init_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
)
self._init_gravity_h = stencil_factory.from_origin_domain(
init_gravity_h,
set_value,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
domain=grid_indexing.domain_full(add=(0, 0, 1)),
)
self._adjust_gravity = stencil_factory.from_origin_domain(
adjust_gravity,
wam.adjust_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(add=(0,0,1)),
domain=grid_indexing.domain_full(add=(0, 0, 1)),
)
self._adjust_rdg = stencil_factory.from_origin_domain(
rdg_adjust.neg_rdgas_div_gravity,
wam.neg_rdgas_div_gravity,
origin=grid_indexing.origin_full(),
domain=grid_indexing.domain_full(),
)
Expand Down Expand Up @@ -387,7 +343,10 @@ def __init__(
comm.get_scalar_halo_updater([full_xyz_spec]), state, ["omga"], comm=comm
)
self._gravity_halo_updater = WrappedHaloUpdater(
comm.get_scalar_halo_updater([full_xyz_spec]), state, ["grav_var"], comm=comm
comm.get_scalar_halo_updater([full_xyz_spec]),
state,
["grav_var"],
comm=comm,
)
self._n_split = config.n_split
self._k_split = config.k_split
Expand Down Expand Up @@ -544,11 +503,14 @@ def compute_preamble(self, state: DycoreState, is_root_rank: bool):
self._dp_initial,
)

self._init_gravity(state.grav_var)
self._init_gravity_h(state.grav_var_h)
self._init_gravity(state.grav_var, GRAV)
self._init_gravity(state.grav_var_h, GRAV)
Comment on lines +506 to +507
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like how the single _init_gravity can be used for both grav_var and grav_var_h now.


if self.config.enable_wam:
self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz)
self._adjust_gravity(
state.grav_var, state.grav_var_h, state.phis, state.delz
)
self._gravity_halo_updater.update()

self._adjust_rdg(state.rdg_var, state.grav_var)

Expand Down Expand Up @@ -677,7 +639,9 @@ def _compute(self, state: DycoreState, timer: Timer):
)
self._checkpoint_remapping_out(state)
if self.config.enable_wam:
self._adjust_gravity(state.grav_var, state.grav_var_h, state.phis, state.delz)
self._adjust_gravity(
state.grav_var, state.grav_var_h, state.phis, state.delz
)
# TODO: can we pull this block out of the loop intead of
# using an if-statement?
if last_step:
Expand Down
8 changes: 4 additions & 4 deletions pyfv3/stencils/fv_subgridz.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
CP_VAP,
CV_AIR,
CV_VAP,
GRAV,
RDGAS,
X_DIM,
Y_DIM,
Expand All @@ -28,7 +27,6 @@
from gt4py.cartesian.gtscript import __INLINED # isort:skip

RK = CP_AIR / RDGAS + 1.0
G2 = 0.5 * GRAV
T1_MIN = 160.0
T2_MIN = 165.0
T2_MAX = 315.0
Expand Down Expand Up @@ -95,6 +93,7 @@ def init(
qo3mr: FloatField,
qsgs_tke: FloatField,
qcld: FloatField,
grav_var: FloatField,
):
with computation(PARALLEL), interval(...):
t0 = ta
Expand All @@ -117,11 +116,11 @@ def init(
cpm, cvm = standard_cm(
cpm, cvm, q0_vapor, q0_liquid, q0_rain, q0_ice, q0_snow, q0_graupel
)
gz = gzh[0, 0, 1] - G2 * delz
gz = gzh[0, 0, 1] - grav_var * delz
tmp = tvol(gz, u0, v0, w0)
static_energy = cpm * t0 + tmp
total_energy = cvm * t0 + tmp
gzh = gzh[0, 0, 1] - GRAV * delz
gzh = gzh[0, 0, 1] - grav_var * delz


@gtfunction
Expand Down Expand Up @@ -903,6 +902,7 @@ def __call__(
state.qo3mr,
state.qsgs_tke,
state.qcld,
state.grav_var,
)

for n in range(self._m):
Expand Down
1 change: 1 addition & 0 deletions pyfv3/stencils/moist_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def compute_pkz_func(delp, delz, pt, cappa, rdg_var):
# TODO use the exponential form for closer answer matching
return exp(cappa * log(rdg_var * delp / delz * pt))


def moist_pkz(
qvapor: FloatField,
qliquid: FloatField,
Expand Down
17 changes: 0 additions & 17 deletions pyfv3/stencils/rdg_adjust.py

This file was deleted.

Loading