diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py index a0d9b2fb45..74e3d6f062 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/saturation_adjustment.py @@ -15,14 +15,17 @@ from icon4py.model.atmosphere.subgrid_scale_physics.microphysics.stencils import ( saturation_adjustment_stencils as satad_stencils, ) -from icon4py.model.common import field_type_aliases as fa, model_options, type_alias as ta +from icon4py.model.common import ( + field_type_aliases as fa, + model_backends, + model_options, + type_alias as ta, +) from icon4py.model.common.grid import horizontal as h_grid from icon4py.model.common.utils import data_allocation as data_alloc if TYPE_CHECKING: - import gt4py.next.typing as gtx_typing - from icon4py.model.common.grid import icon as icon_grid, vertical as v_grid from icon4py.model.common.states import model @@ -100,14 +103,15 @@ def __init__( grid: icon_grid.IconGrid, vertical_params: v_grid.VerticalGrid, metric_state: MetricStateSaturationAdjustment, - backend: gtx_typing.Backend | None, + backend: model_backends.BackendLike, ): self._backend = backend + self._allocator = model_backends.get_allocator(backend) self.config = config self._grid = grid self._vertical_params: v_grid.VerticalGrid = vertical_params self._metric_state: MetricStateSaturationAdjustment = metric_state - self._xp = data_alloc.import_array_ns(self._backend) + self._xp = data_alloc.import_array_ns(self._allocator) self._allocate_local_variables() self._determine_horizontal_domains() @@ -123,23 +127,23 @@ def output_properties(self) -> dict[str, model.FieldMetaData]: def _allocate_local_variables(self): #: it was originally named as tworkold in ICON. Old temperature before iteration. self._temperature1 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) #: it was originally named as twork in ICON. New temperature before iteration. self._temperature2 = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) #: A mask that indicates whether the grid cell is subsaturated or not. self._subsaturated_mask = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._allocator ) #: A mask that indicates whether next Newton iteration is required. self._newton_iteration_mask = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=bool, allocator=self._allocator ) #: latent heat vaporization / dry air heat capacity at constant volume self._lwdocvd = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) def _initialize_gt4py_programs(self): @@ -203,7 +207,7 @@ def _initialize_gt4py_programs(self): def _determine_horizontal_domains(self): cell_domain = h_grid.domain(dims.CellDim) self._start_cell_nudging = self._grid.start_index(cell_domain(h_grid.Zone.NUDGING)) - self._end_cell_local = self._grid.start_index(cell_domain(h_grid.Zone.END)) + self._end_cell_local = self._grid.end_index(cell_domain(h_grid.Zone.LOCAL)) def _not_converged(self) -> bool: return self._xp.any( @@ -215,14 +219,14 @@ def _not_converged(self) -> bool: def run( self, - dtime: ta.wpfloat, + temperature_tendency: fa.CellKField[ta.wpfloat], + qv_tendency: fa.CellKField[ta.wpfloat], + qc_tendency: fa.CellKField[ta.wpfloat], rho: fa.CellKField[ta.wpfloat], temperature: fa.CellKField[ta.wpfloat], qv: fa.CellKField[ta.wpfloat], qc: fa.CellKField[ta.wpfloat], - temperature_tendency: fa.CellKField[ta.wpfloat], - qv_tendency: fa.CellKField[ta.wpfloat], - qc_tendency: fa.CellKField[ta.wpfloat], + dtime: ta.wpfloat, ): """ Adjust saturation at each grid point. @@ -238,14 +242,14 @@ def run( Originally inspired from satad_v_3D of ICON. Args: - dtime: time step [s] + temperature_tendency: air temperature tendency [K s-1] + qv_tendency: specific humidity tendency [s-1] + qc_tendency: specific cloud water content tendency [s-1] rho: air density [kg m-3] temperature: air temperature [K] qv: specific humidity [kg kg-1] qc: specific cloud water content [kg kg-1] - temperature_tendency: air temperature tendency [K s-1] - qv_tendency: specific humidity tendency [s-1] - qc_tendency: specific cloud water content tendency [s-1] + dtime: time step [s] """ temperature_pair = common_utils.TimeStepPair(self._temperature1, self._temperature2) diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py index a62413314a..4f0ac32d5a 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/src/icon4py/model/atmosphere/subgrid_scale_physics/microphysics/single_moment_six_class_gscp_graupel.py @@ -22,6 +22,7 @@ constants as physics_constants, dimension as dims, field_type_aliases as fa, + model_backends, model_options, type_alias as ta, ) @@ -30,8 +31,6 @@ if TYPE_CHECKING: - import gt4py.next.typing as gtx_typing - from icon4py.model.common.grid import icon as icon_grid, vertical as v_grid @@ -86,6 +85,8 @@ class SingleMomentSixClassIconGraupelConfig: rain_n0: ta.wpfloat = 1.0 #: coefficient for snow-graupel conversion by riming. Originally defined as csg in mo_nwp_tuning_config.f90 in ICON. snow2graupel_riming_coeff: ta.wpfloat = 0.5 + #: cloud number concentration [1/m3] + cloud_number_concentration: ta.wpfloat = 100.0e6 @dataclasses.dataclass @@ -96,18 +97,19 @@ class MetricStateIconGraupel: class SingleMomentSixClassIconGraupel: def __init__( self, - graupel_config: SingleMomentSixClassIconGraupelConfig, + config: SingleMomentSixClassIconGraupelConfig, grid: icon_grid.IconGrid, metric_state: MetricStateIconGraupel, vertical_params: v_grid.VerticalGrid, - backend: gtx_typing.Backend | None, + backend: model_backends.BackendLike, ): - self.config = graupel_config + self.config = config self._initialize_configurable_parameters() self._grid = grid self._metric_state = metric_state self.vertical_params = vertical_params self._backend = backend + self._allocator = model_backends.get_allocator(backend) self._initialize_local_fields() self._determine_horizontal_domains() @@ -218,50 +220,53 @@ def _initialize_configurable_parameters(self): ) def _initialize_local_fields(self): + self.qnc = data_alloc.zero_field( + self._grid, dims.CellDim, dtype=ta.wpfloat, allocator=self._allocator + ) self.rhoqrv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.rhoqsv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.rhoqgv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.rhoqiv_old_kup = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.vnew_r = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.vnew_s = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.vnew_g = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.vnew_i = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.rain_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.snow_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.graupel_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.ice_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) self.total_precipitation_flux = data_alloc.zero_field( - self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._backend + self._grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator ) def _determine_horizontal_domains(self): cell_domain = h_grid.domain(dims.CellDim) self._start_cell_nudging = self._grid.start_index(cell_domain(h_grid.Zone.NUDGING)) - self._end_cell_local = self._grid.start_index(cell_domain(h_grid.Zone.END)) + self._end_cell_local = self._grid.end_index(cell_domain(h_grid.Zone.LOCAL)) def _initialize_gt4py_programs(self): self._icon_graupel = model_options.setup_program( @@ -346,7 +351,6 @@ def run( qi: fa.CellKField[ta.wpfloat], qs: fa.CellKField[ta.wpfloat], qg: fa.CellKField[ta.wpfloat], - qnc: fa.CellField[ta.wpfloat], temperature_tendency: fa.CellKField[ta.wpfloat], qv_tendency: fa.CellKField[ta.wpfloat], qc_tendency: fa.CellKField[ta.wpfloat], @@ -389,7 +393,7 @@ def run( qr=qr, qs=qs, qg=qg, - qnc=qnc, + qnc=self.qnc, temperature_tendency=temperature_tendency, qv_tendency=qv_tendency, qc_tendency=qc_tendency, diff --git a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py index 20757bd7a1..69de0f0baa 100644 --- a/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py +++ b/model/atmosphere/subgrid_scale_physics/microphysics/tests/microphysics/integration_tests/test_single_moment_six_class_gscp_graupel.py @@ -15,7 +15,7 @@ microphysics_options as mphys_options, single_moment_six_class_gscp_graupel as graupel, ) -from icon4py.model.common import dimension as dims, type_alias as ta +from icon4py.model.common import dimension as dims, model_backends, type_alias as ta from icon4py.model.common.grid import vertical as v_grid from icon4py.model.common.states import ( diagnostic_state as diagnostics, @@ -94,11 +94,15 @@ def test_graupel( temperature=entry_savepoint.temperature(), virtual_temperature=None, pressure=entry_savepoint.pressure(), - pressure_ifc=None, + pressure_at_half_levels=None, u=None, v=None, ) + allocator = model_backends.get_allocator(backend) + xp = data_alloc.import_array_ns(allocator) + qnc = xp.mean(entry_savepoint.qnc()) + graupel_config = graupel.SingleMomentSixClassIconGraupelConfig( liquid_autoconversion_option=mphys_options.LiquidAutoConversionType.SEIFERT_BEHENG, ice_stickeff_min=0.01, @@ -108,18 +112,17 @@ def test_graupel( rain_mu=0.0, rain_n0=1.0, snow2graupel_riming_coeff=0.5, + cloud_number_concentration=qnc, ) graupel_microphysics = graupel.SingleMomentSixClassIconGraupel( - graupel_config=graupel_config, + config=graupel_config, grid=icon_grid, metric_state=metric_state, vertical_params=vertical_params, backend=backend, ) - qnc = entry_savepoint.qnc() - temperature_tendency = data_alloc.zero_field( icon_grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=backend ) @@ -153,7 +156,6 @@ def test_graupel( tracer_state.qi, tracer_state.qs, tracer_state.qg, - qnc, temperature_tendency, qv_tendency, qc_tendency, diff --git a/model/common/src/icon4py/model/common/diagnostic_calculations/stencils/__init__.py b/model/common/src/icon4py/model/common/diagnostic_calculations/__init__.py similarity index 100% rename from model/common/src/icon4py/model/common/diagnostic_calculations/stencils/__init__.py rename to model/common/src/icon4py/model/common/diagnostic_calculations/__init__.py diff --git a/model/common/src/icon4py/model/common/diagnostic_calculations/stencils.py b/model/common/src/icon4py/model/common/diagnostic_calculations/stencils.py new file mode 100644 index 0000000000..3d4018c4f4 --- /dev/null +++ b/model/common/src/icon4py/model/common/diagnostic_calculations/stencils.py @@ -0,0 +1,585 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Final + +import gt4py.next as gtx +from gt4py.next import exp, log, neighbor_sum, sqrt, where + +from icon4py.model.common import ( + constants as phy_const, + dimension as dims, + field_type_aliases as fa, + type_alias as ta, +) +from icon4py.model.common.dimension import E2C, E2CDim, Koff +from icon4py.model.common.states import tracer_state +from icon4py.model.common.type_alias import wpfloat + + +physics_constants: Final = phy_const.PhysicsConstants() + + +@gtx.field_operator +def _diagnose_surface_pressure( + exner: fa.CellKField[ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + ddqz_z_full: fa.CellKField[ta.wpfloat], +) -> fa.CellKField[ta.wpfloat]: + """ + Diagnose surface pressure by assuming hydrostatic balance (dp/dz = -rho g = - p g / Rd / Tv). + Note that virtual temperature is used in the equation to include the moist effect. + + Args: + exner: exner function + virtual_temperature): virtual temperature [K] + ddqz_z_full: vertical grid spacing at full levels [m] + + Returns: + surface pressure: air pressure on the surface (model bottom boundary) [Pa] + """ + surface_pressure = physics_constants.p0ref * exp( + physics_constants.cpd_o_rd * log(exner(Koff[-3])) + + physics_constants.grav_o_rd + * ( + ddqz_z_full(Koff[-1]) / virtual_temperature(Koff[-1]) + + ddqz_z_full(Koff[-2]) / virtual_temperature(Koff[-2]) + + 0.5 * ddqz_z_full(Koff[-3]) / virtual_temperature(Koff[-3]) + ) + ) + return surface_pressure + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def diagnose_surface_pressure( + surface_pressure: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + ddqz_z_full: fa.CellKField[ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _diagnose_surface_pressure( + exner, + virtual_temperature, + ddqz_z_full, + out=surface_pressure, + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.scan_operator(axis=dims.KDim, forward=False, init=(0.0, 0.0, True)) +def _scan_pressure( + state: tuple[ta.wpfloat, ta.wpfloat, bool], + ddqz_z_full: ta.wpfloat, + virtual_temperature: ta.wpfloat, + surface_pressure: ta.wpfloat, +) -> tuple[ta.wpfloat, ta.wpfloat, bool]: + """ + Diagnose pressure at the model full and half levels by assuming hydrostatic balance (dp/dz = -rho g = - p g / Rd / Tv). + Note that virtual temperature is used in the equation to include the moist effect. + The hydrostatic balance is integrated from half levels k-1/2 to k+1/2, and we can obtain the pressure at k+1/2 half level given the pressure at k-1/2 half level. + The pressure at full level k is diagnosed by assuming the geometric mean of the pressure at two adjacent half levels. + + Args: + state: a tuple of (pressure at full levels, pressure at half levels, switch), where switch is True when the current level is the bottommost model level (scan from bottom to top) [Pa] + ddqz_z_full: vertical grid spacing at full levels [m] + virtual_temperature: virtual temperature [K] + surface_pressure: air pressure on the surface (model bottom boundary) [Pa] + + Returns: + pressure at full levels, pressure at half levels [Pa] + """ + pressure_interface = ( + surface_pressure * exp(-physics_constants.grav_o_rd * ddqz_z_full / virtual_temperature) + if state[2] + else state[1] * exp(-physics_constants.grav_o_rd * ddqz_z_full / virtual_temperature) + ) + pressure = ( + sqrt(surface_pressure * pressure_interface) + if state[2] + else sqrt(state[1] * pressure_interface) + ) + return pressure, pressure_interface, False + + +@gtx.field_operator +def _diagnose_pressure( + surface_pressure: gtx.Field[gtx.Dims[dims.CellDim], ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + ddqz_z_full: fa.CellKField[ta.wpfloat], +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + """ + Update pressure by assuming hydrostatic balance (dp/dz = -rho g = - p g / Rd / Tv). + Note that virtual temperature is used in the equation. + + Args: + ddqz_z_full: vertical grid spacing at full levels [m] + virtual_temperature: air virtual temperature [K] + surface_pressure: air pressure on the surface (model bottom boundary) [Pa] + Returns: + pressure at full levels, pressure at half levels (excluding surface level) [Pa] + """ + pressure, pressure_at_half_levels, _ = _scan_pressure( + ddqz_z_full, virtual_temperature, surface_pressure + ) + return pressure, pressure_at_half_levels + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def diagnose_pressure( + pressure: fa.CellKField[ta.wpfloat], + pressure_at_half_levels: fa.CellKField[ta.wpfloat], + surface_pressure: fa.CellField[ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + ddqz_z_full: fa.CellKField[ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _diagnose_pressure( + surface_pressure, + virtual_temperature, + ddqz_z_full, + out=(pressure, pressure_at_half_levels), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _total_hydrometeors( + tracers: tracer_state.TracerState, +) -> fa.CellKField[ta.wpfloat]: + """ + Summation of all hydrometeor mixing ratios. + + Args: + tracers: tracer state containing the mixing ratios of water vapor and all hydrometeors [kg kg-1] + Returns: + total hydrometeor mixing ratio [kg kg-1] + """ + qsum = tracers.qc + tracers.qi + tracers.qr + tracers.qs + tracers.qg + return qsum + + +@gtx.field_operator +def _diagnose_temperature( + tracers: tracer_state.TracerState, + virtual_temperature: fa.CellKField[ta.wpfloat], +) -> fa.CellKField[ta.wpfloat]: + """ + Diagnose temperature . + + Args: + tracers: tracer state containing the mixing ratios of water vapor and all hydrometeors [kg kg-1] + Returns: + total hydrometeor mixing ratio [kg kg-1] + """ + temperature = virtual_temperature / ( + wpfloat("1.0") + + physics_constants.rv_o_rd_minus_1 * tracers.qv + - _total_hydrometeors(tracers) + ) + return temperature + + +@gtx.field_operator +def _diagnose_virtual_temperature( + tracers: tracer_state.TracerState, + temperature: fa.CellKField[ta.wpfloat], +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + virtual_temperature = temperature * ( + wpfloat("1.0") + + physics_constants.rv_o_rd_minus_1 * tracers.qv + - _total_hydrometeors(tracers) + ) + return virtual_temperature, temperature + + +@gtx.field_operator +def _diagnose_virtual_temperature_and_temperature_from_exner( + tracers: tracer_state.TracerState, + theta_v: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + virtual_temperature = theta_v * exner + temperature = _diagnose_temperature(tracers, virtual_temperature) + return virtual_temperature, temperature + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def diagnose_virtual_temperature_and_temperature_from_exner( + virtual_temperature: fa.CellKField[ta.wpfloat], + temperature: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + theta_v: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _diagnose_virtual_temperature_and_temperature_from_exner( + tracers, + theta_v, + exner, + out=(virtual_temperature, temperature), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _diagnose_exner_from_virtual_temperature_and_rho( + virtual_temperature: fa.CellKField[ta.wpfloat], + rho: fa.CellKField[ta.wpfloat], +) -> fa.CellKField[ta.wpfloat]: + exner = exp( + physics_constants.rd_o_cpd * log(physics_constants.rd_o_p0ref * rho * virtual_temperature) + ) + return exner + + +@gtx.field_operator +def _diagnose_exner_from_virtual_temperature( + virtual_temperature: fa.CellKField[ta.wpfloat], + old_virtual_temperature: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], +) -> fa.CellKField[ta.wpfloat]: + exner = exner * ( + wpfloat("1.0") + + physics_constants.rd_o_cpd + * (virtual_temperature / old_virtual_temperature - wpfloat("1.0")) + ) + return exner + + +@gtx.field_operator +def _diagnose_theta_v( + virtual_temperature: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], +) -> fa.CellKField[ta.wpfloat]: + theta_v = virtual_temperature / exner + return theta_v + + +@gtx.field_operator +def _diagnose_exner_and_theta_v_from_virtual_temperature( + perturbed_exner: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature: fa.CellKField[ta.wpfloat], + rho: fa.CellKField[ta.wpfloat], + previous_exner: fa.CellKField[ta.wpfloat], +) -> tuple[ + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], +]: + virtual_temperature = _diagnose_virtual_temperature(tracers, temperature) + exner = _diagnose_exner_from_virtual_temperature_and_rho(virtual_temperature, rho) + perturbed_exner = perturbed_exner + exner - previous_exner + theta_v = _diagnose_theta_v(virtual_temperature, exner) + return virtual_temperature, exner, perturbed_exner, theta_v + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def diagnose_exner_and_theta_v_from_virtual_temperature( + virtual_temperature: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], + perturbed_exner: fa.CellKField[ta.wpfloat], + theta_v: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature: fa.CellKField[ta.wpfloat], + rho: fa.CellKField[ta.wpfloat], + previous_exner: fa.CellKField[ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _diagnose_exner_and_theta_v_from_virtual_temperature( + perturbed_exner, + tracers, + temperature, + rho, + previous_exner, + out=(virtual_temperature, exner, perturbed_exner, theta_v), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _diagnose_virtual_temperature_and_exner( + virtual_temperature: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature: fa.CellKField[ta.wpfloat], +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + new_virtual_temperature = _diagnose_virtual_temperature(tracers, temperature) + new_exner = _diagnose_exner_from_virtual_temperature( + new_virtual_temperature, virtual_temperature, exner + ) + return new_virtual_temperature, new_exner + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def diagnose_virtual_temperature_and_exner( + virtual_temperature: fa.CellKField[ta.wpfloat], + exner: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature: fa.CellKField[ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _diagnose_virtual_temperature_and_exner( + virtual_temperature, + exner, + tracers, + temperature, + out=(virtual_temperature, exner), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _update_exner_and_theta_v_from_virtual_temperature_in_halo( + exner: fa.CellKField[ta.wpfloat], + theta_v: fa.CellKField[ta.wpfloat], + rho: fa.CellKField[ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + mask_prog_halo_c: fa.CellField[bool], +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + exner = where( + mask_prog_halo_c, + _diagnose_exner_from_virtual_temperature_and_rho(virtual_temperature, rho), + exner, + ) + theta_v = where( + mask_prog_halo_c, + _diagnose_theta_v(virtual_temperature=virtual_temperature, exner=exner), + theta_v, + ) + return exner, theta_v + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def update_exner_and_theta_v_from_virtual_temperature_in_halo( + exner: fa.CellKField[ta.wpfloat], + theta_v: fa.CellKField[ta.wpfloat], + rho: fa.CellKField[ta.wpfloat], + virtual_temperature: fa.CellKField[ta.wpfloat], + mask_prog_halo_c: fa.CellField[bool], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _update_exner_and_theta_v_from_virtual_temperature_in_halo( + exner, + theta_v, + rho, + virtual_temperature, + mask_prog_halo_c, + out=(exner, theta_v), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _update_vn_from_u_v_tendencies( + vn: fa.EdgeKField[ta.wpfloat], + u_tendency: fa.CellKField[ta.wpfloat], + v_tendency: fa.CellKField[ta.wpfloat], + dt: ta.wpfloat, + c_lin_e: gtx.Field[gtx.Dims[dims.EdgeDim, E2CDim], ta.wpfloat], + primal_normal_cell_x: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat], + primal_normal_cell_y: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat], +) -> fa.EdgeKField[ta.wpfloat]: + new_vn = vn + dt * neighbor_sum( + c_lin_e * (u_tendency(E2C) * primal_normal_cell_x + v_tendency(E2C) * primal_normal_cell_y), + axis=E2CDim, + ) + return new_vn + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def update_vn_from_u_v_tendencies( + vn: fa.EdgeKField[ta.wpfloat], + u_tendency: fa.CellKField[ta.wpfloat], + v_tendency: fa.CellKField[ta.wpfloat], + dt: ta.wpfloat, + c_lin_e: gtx.Field[gtx.Dims[dims.EdgeDim, E2CDim], ta.wpfloat], + primal_normal_cell_x: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat], + primal_normal_cell_y: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat], + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _update_vn_from_u_v_tendencies( + vn, + u_tendency, + v_tendency, + dt, + c_lin_e, + primal_normal_cell_x, + primal_normal_cell_y, + out=vn, + domain={ + dims.EdgeDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _update_satad_output_from_tendency( + temperature: fa.CellKField[ta.wpfloat], + qv: fa.CellKField[ta.wpfloat], + qc: fa.CellKField[ta.wpfloat], + temperature_tendency: fa.CellKField[ta.wpfloat], + qv_tendency: fa.CellKField[ta.wpfloat], + qc_tendency: fa.CellKField[ta.wpfloat], + dtime: ta.wpfloat, +) -> tuple[fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat], fa.CellKField[ta.wpfloat]]: + """ + Update temperature, qv, and qc from their tendency. + + Args: + temperature: air temperature [K] + qv: specific humidity [kg kg-1] + qc: specific cloud water content [kg kg-1] + temperature_tendency: temperature tendency [K s-1] + qv_tendency: specific humidity tendency [s-1] + qc_tendency: specific cloud water content tendency [s-1] + dtime: time step [s] + Returns: + updated temperature, qv, qc + """ + return ( + temperature + temperature_tendency * dtime, + qv + qv_tendency * dtime, + qc + qc_tendency * dtime, + ) + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def update_satad_output_from_tendency( + temperature: fa.CellKField[ta.wpfloat], + qv: fa.CellKField[ta.wpfloat], + qc: fa.CellKField[ta.wpfloat], + temperature_tendency: fa.CellKField[ta.wpfloat], + qv_tendency: fa.CellKField[ta.wpfloat], + qc_tendency: fa.CellKField[ta.wpfloat], + dtime: ta.wpfloat, + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _update_satad_output_from_tendency( + temperature, + qv, + qc, + temperature_tendency, + qv_tendency, + qc_tendency, + dtime, + out=(temperature, qv, qc), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) + + +@gtx.field_operator +def _update_microphysics_output_from_tendency( + temperature: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature_tendency: fa.CellKField[ta.wpfloat], + tracer_tendency: tracer_state.TracerStateTendency, + dtime: ta.wpfloat, +) -> tuple[ + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], + fa.CellKField[ta.wpfloat], +]: + """ + Update temperature and all hydrometeros from their tendency. + + Args: + temperature: air temperature [K] + tracers: hydrometeor mixing ratios [kg kg-1] + temperature_tendency: temperature tendency [K s-1], + tracer_tendency: tendency of hydrometeor mixing ratios [s-1] + dtime: time step [s] + Returns: + updated temperature, hydrometeor mixing ratios + """ + return ( + temperature + temperature_tendency * dtime, + tracers.qv + tracer_tendency.qv_tendency * dtime, + tracers.qc + tracer_tendency.qc_tendency * dtime, + tracers.qr + tracer_tendency.qr_tendency * dtime, + tracers.qi + tracer_tendency.qi_tendency * dtime, + tracers.qs + tracer_tendency.qs_tendency * dtime, + tracers.qg + tracer_tendency.qg_tendency * dtime, + ) + + +@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) +def update_microphysics_output_from_tendency( + temperature: fa.CellKField[ta.wpfloat], + tracers: tracer_state.TracerState, + temperature_tendency: fa.CellKField[ta.wpfloat], + tracer_tendency: tracer_state.TracerStateTendency, + dtime: ta.wpfloat, + horizontal_start: gtx.int32, + horizontal_end: gtx.int32, + vertical_start: gtx.int32, + vertical_end: gtx.int32, +): + _update_microphysics_output_from_tendency( + temperature, + tracers, + temperature_tendency, + tracer_tendency, + dtime, + out=(temperature, tracers.qv, tracers.qc, tracers.qr, tracers.qi, tracers.qs, tracers.qg), + domain={ + dims.CellDim: (horizontal_start, horizontal_end), + dims.KDim: (vertical_start, vertical_end), + }, + ) diff --git a/model/common/src/icon4py/model/common/diagnostic_calculations/stencils/calculate_tendency.py b/model/common/src/icon4py/model/common/diagnostic_calculations/stencils/calculate_tendency.py deleted file mode 100644 index 72ac0b8f1f..0000000000 --- a/model/common/src/icon4py/model/common/diagnostic_calculations/stencils/calculate_tendency.py +++ /dev/null @@ -1,189 +0,0 @@ -# ICON4Py - ICON inspired code in Python and GT4Py -# -# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Final - -from gt4py import next as gtx - -from icon4py.model.common import ( - constants as phy_const, - dimension as dims, - field_type_aliases as fa, - type_alias as ta, -) -from icon4py.model.common.type_alias import wpfloat - - -physics_constants: Final = phy_const.PhysicsConstants() - - -@gtx.field_operator -def _calculate_virtual_temperature_tendency( - dtime: ta.wpfloat, - qv: fa.CellKField[ta.wpfloat], - qc: fa.CellKField[ta.wpfloat], - qi: fa.CellKField[ta.wpfloat], - qr: fa.CellKField[ta.wpfloat], - qs: fa.CellKField[ta.wpfloat], - qg: fa.CellKField[ta.wpfloat], - temperature: fa.CellKField[ta.wpfloat], - virtual_temperature: fa.CellKField[ta.wpfloat], -) -> fa.CellKField[ta.wpfloat]: - """ - Update virtual temperature tendency. - - Args: - dtime: time step [s] - qv: specific humidity [kg kg-1] - qc: specific cloud water content [kg kg-1] - qi: specific cloud ice content [kg kg-1] - qr: specific rain water content [kg kg-1] - qs: specific snow content [kg kg-1] - qg: specific graupel content [kg kg-1] - temperature: air temperature [K] - virtual_temperature: air virtual temperature [K] - Returns: - virtual temperature tendency [K s-1], exner tendency [s-1], new exner, new virtual temperature [K] - """ - qsum = qc + qi + qr + qs + qg - - new_virtual_temperature = temperature * ( - wpfloat("1.0") + physics_constants.rv_o_rd_minus_1 * qv - qsum - ) - - return (new_virtual_temperature - virtual_temperature) / dtime - - -@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) -def calculate_virtual_temperature_tendency( - dtime: ta.wpfloat, - qv: fa.CellKField[ta.wpfloat], - qc: fa.CellKField[ta.wpfloat], - qi: fa.CellKField[ta.wpfloat], - qr: fa.CellKField[ta.wpfloat], - qs: fa.CellKField[ta.wpfloat], - qg: fa.CellKField[ta.wpfloat], - temperature: fa.CellKField[ta.wpfloat], - virtual_temperature: fa.CellKField[ta.wpfloat], - virtual_temperature_tendency: fa.CellKField[ta.wpfloat], - horizontal_start: gtx.int32, - horizontal_end: gtx.int32, - vertical_start: gtx.int32, - vertical_end: gtx.int32, -): - _calculate_virtual_temperature_tendency( - dtime, - qv, - qc, - qi, - qr, - qs, - qg, - temperature, - virtual_temperature, - out=virtual_temperature_tendency, - domain={ - dims.CellDim: (horizontal_start, horizontal_end), - dims.KDim: (vertical_start, vertical_end), - }, - ) - - -@gtx.field_operator -def _calculate_exner_tendency( - dtime: ta.wpfloat, - virtual_temperature: fa.CellKField[ta.wpfloat], - virtual_temperature_tendency: fa.CellKField[ta.wpfloat], - exner: fa.CellKField[ta.wpfloat], -) -> fa.CellKField[ta.wpfloat]: - """ - Update exner tendency. - - Args: - dtime: time step [s] - virtual_temperature: air virtual temperature [K] - virtual_temperature_tendency: air virtual temperature tendency [K s-1] - exner: exner function - Returns: - exner tendency [s-1] - """ - - new_exner = exner * ( - wpfloat("1.0") - + physics_constants.rd_o_cpd * virtual_temperature_tendency / virtual_temperature * dtime - ) - - return (new_exner - exner) / dtime - - -@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) -def calculate_exner_tendency( - dtime: ta.wpfloat, - virtual_temperature: fa.CellKField[ta.wpfloat], - virtual_temperature_tendency: fa.CellKField[ta.wpfloat], - exner: fa.CellKField[ta.wpfloat], - exner_tendency: fa.CellKField[ta.wpfloat], - horizontal_start: gtx.int32, - horizontal_end: gtx.int32, - vertical_start: gtx.int32, - vertical_end: gtx.int32, -): - _calculate_exner_tendency( - dtime, - virtual_temperature, - virtual_temperature_tendency, - exner, - out=exner_tendency, - domain={ - dims.CellDim: (horizontal_start, horizontal_end), - dims.KDim: (vertical_start, vertical_end), - }, - ) - - -@gtx.field_operator -def _calculate_cell_kdim_field_tendency( - dtime: ta.wpfloat, - old_field: fa.CellKField[ta.wpfloat], - new_field: fa.CellKField[ta.wpfloat], -) -> fa.CellKField[ta.wpfloat]: - """ - Update tendency of Cell-K Dim field. - - Args: - dtime: time step [s] - old_field: any old Cell-K Dim field [unit] - new_field: any new Cell-K Dim field [unit] - Returns: - tendency [unit s-1] - """ - - return (new_field - old_field) / dtime - - -@gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) -def calculate_cell_kdim_field_tendency( - dtime: ta.wpfloat, - old_field: fa.CellKField[ta.wpfloat], - new_field: fa.CellKField[ta.wpfloat], - tendency: fa.CellKField[ta.wpfloat], - horizontal_start: gtx.int32, - horizontal_end: gtx.int32, - vertical_start: gtx.int32, - vertical_end: gtx.int32, -) -> None: - _calculate_cell_kdim_field_tendency( - dtime, - old_field, - new_field, - out=tendency, - domain={ - dims.CellDim: (horizontal_start, horizontal_end), - dims.KDim: (vertical_start, vertical_end), - }, - ) diff --git a/model/common/src/icon4py/model/common/grid/icon.py b/model/common/src/icon4py/model/common/grid/icon.py index eaafcc56fa..74af2a7c65 100644 --- a/model/common/src/icon4py/model/common/grid/icon.py +++ b/model/common/src/icon4py/model/common/grid/icon.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import functools import logging import math from collections.abc import Callable @@ -139,6 +140,42 @@ class IconGrid(base.Grid): default=None, kw_only=True ) + @functools.cached_property + def cell_start_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.CellDim) + cell_start_index_dict = {zone: self.start_index(domain(zone)) for zone in h_grid.Zone} + return cell_start_index_dict + + @functools.cached_property + def cell_end_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.CellDim) + cell_end_index_dict = {zone: self.end_index(domain(zone)) for zone in h_grid.Zone} + return cell_end_index_dict + + @functools.cached_property + def edge_start_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.EdgeDim) + edge_start_index_dict = {zone: self.start_index(domain(zone)) for zone in h_grid.Zone} + return edge_start_index_dict + + @functools.cached_property + def edge_end_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.EdgeDim) + edge_end_index_dict = {zone: self.end_index(domain(zone)) for zone in h_grid.Zone} + return edge_end_index_dict + + @functools.cached_property + def vertex_start_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.VertexDim) + vertex_start_index_dict = {zone: self.start_index(domain(zone)) for zone in h_grid.Zone} + return vertex_start_index_dict + + @functools.cached_property + def vertex_end_index(self) -> dict[h_grid.Zone, gtx.int32]: + domain = h_grid.domain(dims.VertexDim) + vertex_end_index_dict = {zone: self.end_index(domain(zone)) for zone in h_grid.Zone} + return vertex_end_index_dict + def _has_skip_values(offset: gtx.FieldOffset, limited_area_or_distributed: bool) -> bool: """ diff --git a/model/common/src/icon4py/model/common/interpolation/stencils/edge_2_cell_vector_rbf_interpolation.py b/model/common/src/icon4py/model/common/interpolation/stencils/compute_edge_2_cell_vector_interpolation.py similarity index 86% rename from model/common/src/icon4py/model/common/interpolation/stencils/edge_2_cell_vector_rbf_interpolation.py rename to model/common/src/icon4py/model/common/interpolation/stencils/compute_edge_2_cell_vector_interpolation.py index 1fe16025fd..877589c35c 100644 --- a/model/common/src/icon4py/model/common/interpolation/stencils/edge_2_cell_vector_rbf_interpolation.py +++ b/model/common/src/icon4py/model/common/interpolation/stencils/compute_edge_2_cell_vector_interpolation.py @@ -13,7 +13,7 @@ @gtx.field_operator -def _edge_2_cell_vector_rbf_interpolation( +def _compute_edge_2_cell_vector_interpolation( p_e_in: fa.EdgeKField[ta.wpfloat], ptr_coeff_1: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat], ptr_coeff_2: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat], @@ -24,7 +24,6 @@ def _edge_2_cell_vector_rbf_interpolation( The theory is described in Narcowich and Ward (Math Comp. 1994) and Bonaventura and Baudisch (Mox Report n. 75). It takes edge based variables as input and combines them into three dimensional cartesian vectors at each cell center. - TODO(OngChia): This stencil actually just use the c2e2c2e connectivity and the corresponding coefficients to compute cell-center value without knowledge of how the coefficients are computed. A better name is perferred. Args: p_e_in: Input values at edge center. @@ -39,7 +38,7 @@ def _edge_2_cell_vector_rbf_interpolation( @gtx.program(grid_type=gtx.GridType.UNSTRUCTURED) -def edge_2_cell_vector_rbf_interpolation( +def compute_edge_2_cell_vector_interpolation( p_e_in: fa.EdgeKField[ta.wpfloat], ptr_coeff_1: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat], ptr_coeff_2: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat], @@ -50,7 +49,7 @@ def edge_2_cell_vector_rbf_interpolation( vertical_start: gtx.int32, vertical_end: gtx.int32, ) -> None: - _edge_2_cell_vector_rbf_interpolation( + _compute_edge_2_cell_vector_interpolation( p_e_in, ptr_coeff_1, ptr_coeff_2, diff --git a/model/common/src/icon4py/model/common/states/diagnostic_state.py b/model/common/src/icon4py/model/common/states/diagnostic_state.py index 8db53df9e9..cef6cf5ede 100644 --- a/model/common/src/icon4py/model/common/states/diagnostic_state.py +++ b/model/common/src/icon4py/model/common/states/diagnostic_state.py @@ -33,7 +33,7 @@ class DiagnosticState: #: air pressure [Pa] at cell center and full levels, originally defined as pres in ICON pressure: fa.CellKField[ta.wpfloat] #: air pressure [Pa] at cell center and half levels, originally defined as pres_ifc and pres_sfc for surface pressure in ICON. - pressure_ifc: fa.CellKField[ta.wpfloat] + pressure_at_half_levels: fa.CellKField[ta.wpfloat] #: air temperature [K] at cell center, originally defined as temp in ICON temperature: fa.CellKField[ta.wpfloat] #: air virtual temperature [K] at cell center, originally defined as tempv in ICON @@ -45,16 +45,7 @@ class DiagnosticState: @property def surface_pressure(self) -> fa.CellField[ta.wpfloat]: - return gtx.as_field((dims.CellDim,), self.pressure_ifc.ndarray[:, -1]) - - -@dataclasses.dataclass -class DiagnosticMetricState: - """Class that contains the diagnostic metric state for computing the diagnostic state.""" - - ddqz_z_full: fa.CellKField[ta.wpfloat] - rbf_vec_coeff_c1: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat] - rbf_vec_coeff_c2: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat] + return gtx.as_field((dims.CellDim,), self.pressure_at_half_levels.ndarray[:, -1]) def initialize_diagnostic_state( @@ -69,7 +60,7 @@ def initialize_diagnostic_state( allocator=allocator, dtype=ta.wpfloat, ) - pressure_ifc = data_alloc.zero_field( + pressure_at_half_levels = data_alloc.zero_field( grid, dims.CellDim, dims.KDim, @@ -107,7 +98,7 @@ def initialize_diagnostic_state( ) return DiagnosticState( pressure=pressure, - pressure_ifc=pressure_ifc, + pressure_at_half_levels=pressure_at_half_levels, temperature=temperature, virtual_temperature=virtual_temperature, u=u, diff --git a/model/common/src/icon4py/model/common/states/prognostic_state.py b/model/common/src/icon4py/model/common/states/prognostic_state.py index 3e220e0d39..8506998e26 100644 --- a/model/common/src/icon4py/model/common/states/prognostic_state.py +++ b/model/common/src/icon4py/model/common/states/prognostic_state.py @@ -35,9 +35,6 @@ class PrognosticState: ] # horizontal wind normal to edges, vn(nproma, nlev, nblks_e) [m/s] exner: fa.CellKField[ta.wpfloat] # exner function, exner(nrpoma, nlev, nblks_c) theta_v: fa.CellKField[ta.wpfloat] # virtual temperature, (nproma, nlev, nlbks_c) [K] - tracer: list[fa.CellKField[ta.wpfloat]] = dataclasses.field( - default_factory=list - ) # tracer concentration (nproma,nlev,nblks_c,ntracer) [kg/kg] @property def w_1(self) -> fa.CellField[ta.wpfloat]: @@ -47,7 +44,6 @@ def w_1(self) -> fa.CellField[ta.wpfloat]: def initialize_prognostic_state( grid: icon_grid.IconGrid, allocator: gtx_typing.Allocator, - ntracer: int = 0, ) -> PrognosticState: """Initialize the prognostic state with zero fields.""" rho = data_alloc.zero_field( @@ -86,14 +82,4 @@ def initialize_prognostic_state( allocator=allocator, dtype=ta.wpfloat, ) - tracer = [ - data_alloc.zero_field( - grid, - dims.CellDim, - dims.KDim, - allocator=allocator, - dtype=ta.wpfloat, - ) - for _ in range(ntracer) - ] - return PrognosticState(rho=rho, w=w, vn=vn, exner=exner, theta_v=theta_v, tracer=tracer) + return PrognosticState(rho=rho, w=w, vn=vn, exner=exner, theta_v=theta_v) diff --git a/model/common/src/icon4py/model/common/states/static_coefficients.py b/model/common/src/icon4py/model/common/states/static_coefficients.py new file mode 100644 index 0000000000..e72d2c3d8b --- /dev/null +++ b/model/common/src/icon4py/model/common/states/static_coefficients.py @@ -0,0 +1,29 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + import gt4py.next as gtx + + from icon4py.model.common import dimension as dims, field_type_aliases as fa, type_alias as ta + + +@dataclasses.dataclass +class StaticCoeff: + """Class that contains the coefficients to update the prognostic/diagnostic states.""" + + ddqz_z_full: fa.CellKField[ta.wpfloat] + rbf_vec_coeff_c1: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat] + rbf_vec_coeff_c2: gtx.Field[gtx.Dims[dims.CellDim, dims.C2E2C2EDim], ta.wpfloat] + c_lin_e: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat] + primal_normal_cell_x: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat] + primal_normal_cell_y: gtx.Field[gtx.Dims[dims.EdgeDim, dims.E2CDim], ta.wpfloat] diff --git a/model/common/src/icon4py/model/common/states/tracer_state.py b/model/common/src/icon4py/model/common/states/tracer_state.py index befbbe84f1..41d342559b 100644 --- a/model/common/src/icon4py/model/common/states/tracer_state.py +++ b/model/common/src/icon4py/model/common/states/tracer_state.py @@ -8,8 +8,16 @@ import dataclasses +from typing import TYPE_CHECKING -from icon4py.model.common import field_type_aliases as fa, type_alias as ta +from icon4py.model.common import dimension as dims, field_type_aliases as fa, type_alias as ta +from icon4py.model.common.utils import data_allocation as data_alloc + + +if TYPE_CHECKING: + import gt4py.next.typing as gtx_typing + + from icon4py.model.common.grid import icon as icon_grid @dataclasses.dataclass @@ -31,3 +39,98 @@ class TracerState: qs: fa.CellKField[ta.wpfloat] #: specific graupel content [kg/kg] at cell center qg: fa.CellKField[ta.wpfloat] + + def __iter__(self): + for f in dataclasses.fields(self): + yield getattr(self, f.name) + + +@dataclasses.dataclass +class TracerStateTendency: + """ + Class that contains the tendency of the tracer state which includes hydrometeors and aerosols. + """ + + #: specific humidity [kg/kg] at cell center + qv_tendency: fa.CellKField[ta.wpfloat] + #: specific cloud water content [kg/kg] at cell center + qc_tendency: fa.CellKField[ta.wpfloat] + #: specific rain content [kg/kg] at cell center + qr_tendency: fa.CellKField[ta.wpfloat] + #: specific cloud ice content [kg/kg] at cell center + qi_tendency: fa.CellKField[ta.wpfloat] + #: specific snow content [kg/kg] at cell center + qs_tendency: fa.CellKField[ta.wpfloat] + #: specific graupel content [kg/kg] at cell center + qg_tendency: fa.CellKField[ta.wpfloat] + + +@dataclasses.dataclass +class TracerStateScalar: + """ + Class that contains the scalar of the tracer state which includes hydrometeors and aerosols. + """ + + #: specific humidity [kg/kg] + qv: ta.wpfloat + #: specific cloud water content [kg/kg] + qc: ta.wpfloat + #: specific rain content [kg/kg] + qr: ta.wpfloat + #: specific cloud ice content [kg/kg] + qi: ta.wpfloat + #: specific snow content [kg/kg] + qs: ta.wpfloat + #: specific graupel content [kg/kg] + qg: ta.wpfloat + + +def initialize_tracer_state( + grid: icon_grid.IconGrid, + allocator: gtx_typing.Allocator, +) -> TracerState: + """Initialize the tracer state with zero fields.""" + qv = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + qc = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + qr = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + qi = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + qs = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + qg = data_alloc.zero_field( + grid, + dims.CellDim, + dims.KDim, + allocator=allocator, + dtype=ta.wpfloat, + ) + + return TracerState(qv=qv, qc=qc, qr=qr, qi=qi, qs=qs, qg=qg) diff --git a/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py b/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py index b3e6dbb257..7594193097 100644 --- a/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py +++ b/model/common/tests/common/diagnostic_calculations/unit_tests/test_diagnostic_calculations.py @@ -14,15 +14,10 @@ import icon4py.model.common.grid.horizontal as h_grid from icon4py.model.common import dimension as dims -from icon4py.model.common.diagnostic_calculations.stencils import ( - calculate_tendency, - diagnose_pressure, - diagnose_surface_pressure, - diagnose_temperature, -) +from icon4py.model.common.diagnostic_calculations import stencils as diagnostic_stencils from icon4py.model.common.grid import vertical as v_grid -from icon4py.model.common.interpolation.stencils import edge_2_cell_vector_rbf_interpolation as rbf -from icon4py.model.common.states import diagnostic_state as diagnostics, tracer_state as tracers +from icon4py.model.common.interpolation.stencils import compute_edge_2_cell_vector_interpolation +from icon4py.model.common.states import diagnostic_state as diagnostics, tracer_state from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.testing import definitions, test_utils from icon4py.model.testing.fixtures.datatest import ( @@ -62,25 +57,35 @@ def test_diagnose_temperature( virtual_temperature = data_alloc.zero_field( icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend ) + tracers = tracer_state.TracerState( + qv=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + qc=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + qr=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + qi=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + qs=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + qg=data_alloc.zero_field( + icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend + ), + ) - qv = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - qc = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - qr = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - qi = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - qs = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - qg = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, dtype=float, allocator=backend) - - diagnose_temperature.diagnose_virtual_temperature_and_temperature.with_backend(backend)( - qv=qv, - qc=qc, - qr=qr, - qi=qi, - qs=qs, - qg=qg, - theta_v=theta_v, - exner=exner, + diagnostic_stencils.diagnose_virtual_temperature_and_temperature_from_exner.with_backend( + backend + )( virtual_temperature=virtual_temperature, temperature=temperature, + tracers=tracers, + theta_v=theta_v, + exner=exner, horizontal_start=0, horizontal_end=icon_grid.end_index(h_grid.domain(dims.CellDim)(h_grid.Zone.END)), vertical_start=0, @@ -125,7 +130,9 @@ def test_diagnose_meridional_and_zonal_winds( ) end_cell_end = icon_grid.end_index(cell_domain(h_grid.Zone.END)) - rbf.edge_2_cell_vector_rbf_interpolation.with_backend(backend)( + compute_edge_2_cell_vector_interpolation.compute_edge_2_cell_vector_interpolation.with_backend( + backend + )( p_e_in=vn, ptr_coeff_1=rbv_vec_coeff_c1, ptr_coeff_2=rbv_vec_coeff_c2, @@ -173,11 +180,11 @@ def test_diagnose_surface_pressure( cell_domain = h_grid.domain(dims.CellDim) - diagnose_surface_pressure.diagnose_surface_pressure.with_backend(backend)( + diagnostic_stencils.diagnose_surface_pressure.with_backend(backend)( + surface_pressure=surface_pressure, exner=exner, virtual_temperature=virtual_temperature, ddqz_z_full=ddqz_z_full, - surface_pressure=surface_pressure, horizontal_start=0, horizontal_end=icon_grid.end_index(cell_domain(h_grid.Zone.END)), vertical_start=icon_grid.num_levels, @@ -213,18 +220,18 @@ def test_diagnose_pressure( ) cell_domain = h_grid.domain(dims.CellDim) - pressure_ifc = data_alloc.zero_field( + pressure_at_half_levels = data_alloc.zero_field( icon_grid, dims.CellDim, dims.KDim, dtype=float, extend={dims.KDim: 1}, allocator=backend ) - pressure_ifc.ndarray[:, -1] = surface_pressure.ndarray + pressure_at_half_levels.ndarray[:, -1] = surface_pressure.ndarray - diagnose_pressure.diagnose_pressure.with_backend(backend)( - ddqz_z_full, - virtual_temperature, - surface_pressure, - pressure, - pressure_ifc, + diagnostic_stencils.diagnose_pressure.with_backend(backend)( + pressure=pressure, + pressure_at_half_levels=pressure_at_half_levels, + surface_pressure=surface_pressure, + virtual_temperature=virtual_temperature, + ddqz_z_full=ddqz_z_full, horizontal_start=0, horizontal_end=icon_grid.end_index(cell_domain(h_grid.Zone.END)), vertical_start=0, @@ -232,7 +239,7 @@ def test_diagnose_pressure( offset_provider={}, ) - assert test_utils.dallclose(pressure_ifc_ref, pressure_ifc.asnumpy()) + assert test_utils.dallclose(pressure_ifc_ref, pressure_at_half_levels.asnumpy()) assert test_utils.dallclose( pressure_ref, @@ -252,9 +259,6 @@ def test_diagnose_pressure( def test_diagnostic_update_after_saturation_adjustement( location: str, date: str, - model_top_height: float, # TODO(havogt): unused? - damping_height: float, # TODO(havogt): unused? - stretch_factor: float, # TODO(havogt): unused? data_provider: sb.IconSerialDataProvider, grid_savepoint: sb.IconGridSavepoint, metrics_savepoint: sb.MetricSavepoint, @@ -264,20 +268,14 @@ def test_diagnostic_update_after_saturation_adjustement( satad_init = data_provider.from_savepoint_satad_init(location=location, date=date) satad_exit = data_provider.from_savepoint_satad_exit(location=location, date=date) - dtime = 2.0 - vertical_config = v_grid.VerticalGridConfig(icon_grid.num_levels) vertical_params = v_grid.VerticalGrid( config=vertical_config, vct_a=grid_savepoint.vct_a(), vct_b=grid_savepoint.vct_b(), ) - virtual_temperature_tendency = data_alloc.zero_field( - icon_grid, dims.CellDim, dims.KDim, allocator=backend - ) - exner_tendency = data_alloc.zero_field(icon_grid, dims.CellDim, dims.KDim, allocator=backend) - tracer_state = tracers.TracerState( + tracers = tracer_state.TracerState( qv=satad_exit.qv(), qc=satad_exit.qc(), qr=satad_init.qr(), @@ -291,7 +289,7 @@ def test_diagnostic_update_after_saturation_adjustement( temperature=satad_exit.temperature(), virtual_temperature=satad_init.virtual_temperature(), pressure=satad_init.pressure(), - pressure_ifc=satad_init.pressure_ifc(), + pressure_at_half_levels=satad_init.pressure_ifc(), u=None, v=None, ) @@ -299,35 +297,11 @@ def test_diagnostic_update_after_saturation_adjustement( cell_domain = h_grid.domain(dims.CellDim) start_cell_nudging = icon_grid.start_index(cell_domain(h_grid.Zone.NUDGING)) end_cell_local = icon_grid.start_index(cell_domain(h_grid.Zone.END)) - calculate_tendency.calculate_virtual_temperature_tendency.with_backend(backend)( - dtime=dtime, - qv=tracer_state.qv, - qc=tracer_state.qc, - qi=tracer_state.qi, - qr=tracer_state.qr, - qs=tracer_state.qs, - qg=tracer_state.qg, - temperature=diagnostic_state.temperature, + diagnostic_stencils.diagnose_virtual_temperature_and_exner.with_backend(backend)( virtual_temperature=diagnostic_state.virtual_temperature, - virtual_temperature_tendency=virtual_temperature_tendency, - horizontal_start=start_cell_nudging, - horizontal_end=end_cell_local, - vertical_start=vertical_params.kstart_moist, - vertical_end=icon_grid.num_levels, - offset_provider={}, - ) - - updated_virtual_temperature = ( - diagnostic_state.virtual_temperature.asnumpy() - + virtual_temperature_tendency.asnumpy() * dtime - ) - - calculate_tendency.calculate_exner_tendency.with_backend(backend)( - dtime=dtime, - virtual_temperature=diagnostic_state.virtual_temperature, - virtual_temperature_tendency=virtual_temperature_tendency, exner=exner, - exner_tendency=exner_tendency, + tracers=tracers, + temperature=diagnostic_state.temperature, horizontal_start=start_cell_nudging, horizontal_end=end_cell_local, vertical_start=vertical_params.kstart_moist, @@ -335,13 +309,11 @@ def test_diagnostic_update_after_saturation_adjustement( offset_provider={}, ) - updated_exner = exner.asnumpy() + exner_tendency.asnumpy() * dtime - - diagnose_surface_pressure.diagnose_surface_pressure.with_backend(backend)( - gtx.as_field((dims.CellDim, dims.KDim), updated_exner, allocator=backend), - gtx.as_field((dims.CellDim, dims.KDim), updated_virtual_temperature, allocator=backend), - metrics_savepoint.ddqz_z_full(), - diagnostic_state.pressure_ifc, + diagnostic_stencils.diagnose_surface_pressure.with_backend(backend)( + surface_pressure=diagnostic_state.pressure_at_half_levels, + exner=exner, + virtual_temperature=diagnostic_state.virtual_temperature, + ddqz_z_full=metrics_savepoint.ddqz_z_full(), horizontal_start=start_cell_nudging, horizontal_end=end_cell_local, vertical_start=icon_grid.num_levels, @@ -349,12 +321,12 @@ def test_diagnostic_update_after_saturation_adjustement( offset_provider={"Koff": dims.KDim}, ) - diagnose_pressure.diagnose_pressure.with_backend(backend)( - metrics_savepoint.ddqz_z_full(), - gtx.as_field((dims.CellDim, dims.KDim), updated_virtual_temperature, allocator=backend), - diagnostic_state.surface_pressure, - diagnostic_state.pressure, - diagnostic_state.pressure_ifc, + diagnostic_stencils.diagnose_pressure.with_backend(backend)( + pressure=diagnostic_state.pressure, + pressure_at_half_levels=diagnostic_state.pressure_at_half_levels, + surface_pressure=diagnostic_state.surface_pressure, + virtual_temperature=diagnostic_state.virtual_temperature, + ddqz_z_full=metrics_savepoint.ddqz_z_full(), horizontal_start=start_cell_nudging, horizontal_end=end_cell_local, vertical_start=gtx.int32(0), @@ -363,12 +335,12 @@ def test_diagnostic_update_after_saturation_adjustement( ) assert test_utils.dallclose( - updated_virtual_temperature, + diagnostic_state.virtual_temperature.asnumpy(), satad_exit.virtual_temperature().asnumpy(), atol=1.0e-13, ) assert test_utils.dallclose( - updated_exner, + exner.asnumpy(), satad_exit.exner().asnumpy(), atol=1.0e-13, ) @@ -378,7 +350,7 @@ def test_diagnostic_update_after_saturation_adjustement( atol=1.0e-13, ) assert test_utils.dallclose( - diagnostic_state.pressure_ifc.asnumpy(), + diagnostic_state.pressure_at_half_levels.asnumpy(), satad_exit.pressure_ifc().asnumpy(), atol=1.0e-13, ) diff --git a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_pressure.py b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_pressure.py index 2f0471e4fe..c1c6276b80 100644 --- a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_pressure.py +++ b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_pressure.py @@ -12,17 +12,15 @@ import pytest from icon4py.model.common import constants as phy_const, dimension as dims, type_alias as ta -from icon4py.model.common.diagnostic_calculations.stencils.diagnose_pressure import ( - diagnose_pressure, -) +from icon4py.model.common.diagnostic_calculations import stencils as diagnostic_stencils from icon4py.model.common.grid import base from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.testing import stencil_tests class TestDiagnosePressure(stencil_tests.StencilTest): - PROGRAM = diagnose_pressure - OUTPUTS = ("pressure", "pressure_ifc") + PROGRAM = diagnostic_stencils.diagnose_pressure + OUTPUTS = ("pressure", "pressure_at_half_levels") @staticmethod def reference( @@ -32,24 +30,24 @@ def reference( ddqz_z_full: np.ndarray, **kwargs: Any, ) -> dict: - pressure_ifc = np.zeros_like(virtual_temperature) + pressure_at_half_levels = np.zeros_like(virtual_temperature) pressure = np.zeros_like(virtual_temperature) ground_level = virtual_temperature.shape[1] - 1 - pressure_ifc[:, ground_level] = surface_pressure * np.exp( + pressure_at_half_levels[:, ground_level] = surface_pressure * np.exp( -phy_const.GRAV_O_RD * ddqz_z_full[:, ground_level] / virtual_temperature[:, ground_level] ) - pressure[:, ground_level] = np.sqrt(pressure_ifc[:, ground_level] * surface_pressure) + pressure[:, ground_level] = np.sqrt(pressure_at_half_levels[:, ground_level] * surface_pressure) for k in range(ground_level - 1, -1, -1): - pressure_ifc[:, k] = pressure_ifc[:, k + 1] * np.exp( + pressure_at_half_levels[:, k] = pressure_at_half_levels[:, k + 1] * np.exp( -phy_const.GRAV_O_RD * ddqz_z_full[:, k] / virtual_temperature[:, k] ) - pressure[:, k] = np.sqrt(pressure_ifc[:, k] * pressure_ifc[:, k + 1]) + pressure[:, k] = np.sqrt(pressure_at_half_levels[:, k] * pressure_at_half_levels[:, k + 1]) return dict( pressure=pressure, - pressure_ifc=pressure_ifc, + pressure_at_half_levels=pressure_at_half_levels, ) @pytest.fixture @@ -62,14 +60,14 @@ def input_data(self, grid: base.Grid) -> dict: ) surface_pressure = data_alloc.random_field(grid, dims.CellDim, low=1.0, dtype=ta.wpfloat) pressure = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat) - pressure_ifc = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat) + pressure_at_half_levels = data_alloc.zero_field(grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat) return dict( ddqz_z_full=ddqz_z_full, virtual_temperature=virtual_temperature, surface_pressure=surface_pressure, pressure=pressure, - pressure_ifc=pressure_ifc, + pressure_at_half_levels=pressure_at_half_levels, horizontal_start=gtx.int32(0), horizontal_end=gtx.int32(grid.num_cells), vertical_start=gtx.int32(0), diff --git a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_surface_pressure.py b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_surface_pressure.py index 55e0967f6a..182de48ce7 100644 --- a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_surface_pressure.py +++ b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_surface_pressure.py @@ -12,16 +12,14 @@ import pytest from icon4py.model.common import constants as phy_const, dimension as dims, type_alias as ta -from icon4py.model.common.diagnostic_calculations.stencils.diagnose_surface_pressure import ( - diagnose_surface_pressure, -) +from icon4py.model.common.diagnostic_calculations import stencils as diagnostic_stencils from icon4py.model.common.grid import base from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.testing import stencil_tests class TestDiagnoseSurfacePressure(stencil_tests.StencilTest): - PROGRAM = diagnose_surface_pressure + PROGRAM = diagnostic_stencils.diagnose_surface_pressure OUTPUTS = ("surface_pressure",) @staticmethod diff --git a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_temperature.py b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_temperature.py index 1e03d6b648..c7e0d97237 100644 --- a/model/common/tests/common/interpolation/stencil_tests/test_diagnose_temperature.py +++ b/model/common/tests/common/interpolation/stencil_tests/test_diagnose_temperature.py @@ -12,16 +12,14 @@ import pytest from icon4py.model.common import constants as phy_const, dimension as dims, type_alias as ta -from icon4py.model.common.diagnostic_calculations.stencils.diagnose_temperature import ( - diagnose_virtual_temperature_and_temperature, -) +from icon4py.model.common.diagnostic_calculations import stencils as diagnostic_stencils from icon4py.model.common.grid import base from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.testing import stencil_tests class TestDiagnoseTemperature(stencil_tests.StencilTest): - PROGRAM = diagnose_virtual_temperature_and_temperature + PROGRAM = diagnostic_stencils.diagnose_virtual_temperature_and_temperature_from_exner OUTPUTS = ("virtual_temperature", "temperature") @staticmethod diff --git a/model/common/tests/common/interpolation/stencil_tests/test_edge_2_cell_vector_rbf_interpolation.py b/model/common/tests/common/interpolation/stencil_tests/test_edge_2_cell_vector_rbf_interpolation.py index a4dea61823..a3da04d76c 100644 --- a/model/common/tests/common/interpolation/stencil_tests/test_edge_2_cell_vector_rbf_interpolation.py +++ b/model/common/tests/common/interpolation/stencil_tests/test_edge_2_cell_vector_rbf_interpolation.py @@ -13,8 +13,8 @@ from icon4py.model.common import dimension as dims, type_alias as ta from icon4py.model.common.grid import base -from icon4py.model.common.interpolation.stencils.edge_2_cell_vector_rbf_interpolation import ( - edge_2_cell_vector_rbf_interpolation, +from icon4py.model.common.interpolation.stencils.compute_edge_2_cell_vector_interpolation import ( + compute_edge_2_cell_vector_interpolation, ) from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.testing import stencil_tests @@ -22,7 +22,7 @@ @pytest.mark.skip_value_error class TestEdge2CellVectorRBFInterpolation(stencil_tests.StencilTest): - PROGRAM = edge_2_cell_vector_rbf_interpolation + PROGRAM = compute_edge_2_cell_vector_interpolation OUTPUTS = ("p_u_out", "p_v_out") @staticmethod diff --git a/model/driver/src/icon4py/model/driver/testcases/gauss3d.py b/model/driver/src/icon4py/model/driver/testcases/gauss3d.py index d308283548..c585707af6 100644 --- a/model/driver/src/icon4py/model/driver/testcases/gauss3d.py +++ b/model/driver/src/icon4py/model/driver/testcases/gauss3d.py @@ -22,7 +22,7 @@ from icon4py.model.common.grid import horizontal as h_grid, icon as icon_grid, states as grid_states from icon4py.model.common.interpolation.stencils import ( cell_2_edge_interpolation, - edge_2_cell_vector_rbf_interpolation, + compute_edge_2_cell_vector_interpolation, ) from icon4py.model.common.states import ( diagnostic_state as diagnostics, @@ -213,7 +213,9 @@ def model_initialization_gauss3d( # noqa: PLR0915 [too-many-statements] allocator=allocator, ) - edge_2_cell_vector_rbf_interpolation.edge_2_cell_vector_rbf_interpolation.with_backend(backend)( + compute_edge_2_cell_vector_interpolation.compute_edge_2_cell_vector_interpolation.with_backend( + backend + )( vn, rbf_vec_coeff_c1, rbf_vec_coeff_c2, diff --git a/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py b/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py index 336b3a4d69..2ee1a13ce7 100644 --- a/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py +++ b/model/driver/src/icon4py/model/driver/testcases/jablonowski_williamson.py @@ -24,7 +24,7 @@ from icon4py.model.common.grid import horizontal as h_grid, icon as icon_grid, states as grid_states from icon4py.model.common.interpolation.stencils import ( cell_2_edge_interpolation, - edge_2_cell_vector_rbf_interpolation, + compute_edge_2_cell_vector_interpolation, ) from icon4py.model.common.states import ( diagnostic_state as diagnostics, @@ -299,7 +299,7 @@ def model_initialization_jabw( # noqa: PLR0915 [too-many-statements] allocator=allocator, ) - edge_2_cell_vector_rbf_interpolation.edge_2_cell_vector_rbf_interpolation.with_backend(backend)( + compute_edge_2_cell_vector_interpolation.compute_edge_2_cell_vector_interpolation.with_backend(backend)( vn, rbf_vec_coeff_c1, rbf_vec_coeff_c2, diff --git a/model/standalone_driver/src/icon4py/model/standalone_driver/driver_states.py b/model/standalone_driver/src/icon4py/model/standalone_driver/driver_states.py index 7d5ef5892a..6c863b8df6 100644 --- a/model/standalone_driver/src/icon4py/model/standalone_driver/driver_states.py +++ b/model/standalone_driver/src/icon4py/model/standalone_driver/driver_states.py @@ -24,10 +24,7 @@ from icon4py.model.common.grid import geometry as grid_geometry from icon4py.model.common.interpolation import interpolation_factory from icon4py.model.common.metrics import metrics_factory -from icon4py.model.common.states import ( - diagnostic_state as diagnostics, - prognostic_state as prognostics, -) +from icon4py.model.common.states import diagnostic_state, prognostic_state, tracer_state from icon4py.model.standalone_driver import config as driver_config @@ -66,8 +63,9 @@ class DriverStates(NamedTuple): tracer_advection_diagnostic: advection_states.AdvectionDiagnosticState prep_tracer_advection_prognostic: advection_states.AdvectionPrepAdvState diffusion_diagnostic: diffusion_states.DiffusionDiagnosticState - prognostics: common_utils.TimeStepPair[prognostics.PrognosticState] - diagnostic: diagnostics.DiagnosticState + prognostics: common_utils.TimeStepPair[prognostic_state.PrognosticState] + diagnostic: diagnostic_state.DiagnosticState + tracers: common_utils.TimeStepPair[tracer_state.TracerState] @dataclasses.dataclass diff --git a/model/standalone_driver/src/icon4py/model/standalone_driver/physics_driver.py b/model/standalone_driver/src/icon4py/model/standalone_driver/physics_driver.py new file mode 100644 index 0000000000..266db3b34f --- /dev/null +++ b/model/standalone_driver/src/icon4py/model/standalone_driver/physics_driver.py @@ -0,0 +1,374 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +import logging +import types + +import gt4py.next as gtx + +from icon4py.model.atmosphere.subgrid_scale_physics.microphysics import ( + saturation_adjustment as satad, + single_moment_six_class_gscp_graupel as graupel, +) +from icon4py.model.common import ( + dimension as dims, + field_type_aliases as fa, + model_backends, + type_alias as ta, +) +from icon4py.model.common.decomposition import definitions as decomposition +from icon4py.model.common.diagnostic_calculations import stencils as diagnostic_stencils +from icon4py.model.common.grid import ( + geometry_attributes as geo_attrs, + horizontal as h_grid, + icon as icon_grid, + vertical as v_grid, +) +from icon4py.model.common.interpolation import interpolation_attributes as interp_attrs +from icon4py.model.common.interpolation.stencils import compute_edge_2_cell_vector_interpolation +from icon4py.model.common.metrics import metrics_attributes as metric_attrs +from icon4py.model.common.states import diagnostic_state, prognostic_state, tracer_state +from icon4py.model.common.utils import data_allocation as data_alloc +from icon4py.model.standalone_driver import driver_states + + +log = logging.getLogger(__name__) + + +class PhysicsDriver: + def __init__( + self, + grid: icon_grid.IconGrid, + vertical_grid: v_grid.VerticalGrid, + static_field_factories: driver_states.StaticFieldFactories, + saturation_adjustment: satad.SaturationAdjustment, + microphysics: graupel.SingleMomentSixClassIconGraupel, + backend: model_backends.BackendLike, + exchange: decomposition.ExchangeRuntime = decomposition.single_node_default, + ): + self.vertical_grid = vertical_grid + self.grid = grid + self.static_field_factories = static_field_factories + self.saturation_adjustment = saturation_adjustment + self.microphysics = microphysics + self._exchange = exchange + self.backend = backend + + @functools.cached_property + def _allocator(self) -> gtx.typing.Backend: + return model_backends.get_allocator(self.backend) + + @functools.cached_property + def _xp(self) -> types.ModuleType: + return data_alloc.import_array_ns(self._allocator) + + def _local_fields(self) -> None: + self.temperature_tendency = data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ) + self.u_tendency = data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ) + self.v_tendency = data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ) + self.tracer_tendency = tracer_state.TracerStateTendency( + qv_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + qc_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + qi_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + qr_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + qs_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + qg_tendency=data_alloc.zero_field( + self.grid, dims.CellDim, dims.KDim, dtype=ta.wpfloat, allocator=self._allocator + ), + ) + + def __call__( + self, + prognostic: prognostic_state.PrognosticState, + diagnostic: diagnostic_state.DiagnosticState, + tracers: tracer_state.TracerState, + perturbed_exner: fa.CellKField[ta.wpfloat], + dtime: ta.wpfloat, + ) -> None: + saved_exner = data_alloc.as_field( + prognostic.exner, allocator=self._allocator + ) # saved_exner, 1: min_rlcell + + compute_edge_2_cell_vector_interpolation.compute_edge_2_cell_vector_interpolation.with_backend( + self.backend + )( + p_e_in=prognostic.vn, + ptr_coeff_1=self.static_field_factories.interpolation_field_source.get( + interp_attrs.RBF_VEC_COEFF_C1 + ), + ptr_coeff_2=self.static_field_factories.interpolation_field_source.get( + interp_attrs.RBF_VEC_COEFF_C2 + ), + p_u_out=diagnostic.u, + p_v_out=diagnostic.v, + horizontal_start=1, + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=0, + vertical_end=self.grid.num_levels, + offset_provider=self.grid.connectivities, + ) # 1: min_rlcell_int + + diagnostic_stencils.diagnose_virtual_temperature_and_temperature_from_exner.with_backend( + self.backend + )( + virtual_temperature=diagnostic.virtual_temperature, + temperature=diagnostic.temperature, + tracers=tracers, + theta_v=prognostic.theta_v, + exner=prognostic.exner, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + offset_provider={}, + ) # from kmoist, grf_bdywidth_c+1: min_rlcell_int + + self.saturation_adjustment.run( + temperature_tendency=self.temperature_tendency, + qv_tendency=self.tracer_tendency.qv_tendency, + qc_tendency=self.tracer_tendency.qc_tendency, + rho=prognostic.rho, + temperature=diagnostic.temperature, + qv=tracers.qv, + qc=tracers.qc, + dtime=dtime, + ) # from kmoist, grf_bdywidth_c+1: min_rlcell_int + diagnostic_stencils.update_satad_output_from_tendency.with_backend(self.backend)( + temperature=diagnostic.temperature, + qv=tracers.qv, + qc=tracers.qc, + temperature_tendency=self.temperature_tendency, + qv_tendency=self.tracer_tendency.qv_tendency, + qc_tendency=self.tracer_tendency.qc_tendency, + dtime=dtime, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + ) + + diagnostic_stencils.diagnose_virtual_temperature_and_exner.with_backend(self.backend)( + virtual_temperature=diagnostic.virtual_temperature, + exner=prognostic.exner, + tracers=tracers, + temperature=diagnostic.temperature, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + offset_provider={}, + ) # from kmoist, grf_bdywidth_c+1: min_rlcell_int + diagnostic_stencils.diagnose_surface_pressure.with_backend(self.backend)( + surface_pressure=diagnostic.pressure_at_half_levels, + exner=prognostic.exner, + virtual_temperature=diagnostic.virtual_temperature, + ddqz_z_full=self.static_field_factories.metrics_field_source.get( + metric_attrs.DDQZ_Z_FULL + ), + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.grid.num_levels, + vertical_end=self.grid.num_levels + 1, + offset_provider={"Koff": dims.KDim}, + ) + + diagnostic_stencils.diagnose_pressure.with_backend(self.backend)( + pressure=diagnostic.pressure, + pressure_at_half_levels=diagnostic.pressure_at_half_levels, + surface_pressure=diagnostic.surface_pressure, + virtual_temperature=diagnostic.virtual_temperature, + ddqz_z_full=self.static_field_factories.metrics_field_source.get( + metric_attrs.DDQZ_Z_FULL + ), + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + offset_provider={}, + ) # from kmoist, grf_bdywidth_c+1: min_rlcell_int + + # TODO (Chia Rui): simple_surface() # only for qv_s, grf_bdywidth_c+1: min_rlcell_int + # TODO (Chia Rui): turbulence() # grf_bdywidth_c+1: min_rlcell_int, NOT PORTED + + self.microphysics.run( + qv_tendency=self.tracer_tendency.qv_tendency, + qc_tendency=self.tracer_tendency.qc_tendency, + qi_tendency=self.tracer_tendency.qi_tendency, + qr_tendency=self.tracer_tendency.qr_tendency, + qs_tendency=self.tracer_tendency.qs_tendency, + qg_tendency=self.tracer_tendency.qg_tendency, + temperature_tendency=self.temperature_tendency, + qv=tracers.qv, + qc=tracers.qc, + qi=tracers.qi, + qr=tracers.qr, + qs=tracers.qs, + qg=tracers.qg, + rho=prognostic.rho, + temperature=diagnostic.temperature, + pressure=diagnostic.pressure, + dtime=dtime, + ) + diagnostic_stencils.update_microphysics_output_from_tendency.with_backend(self.backend)( + temperature=diagnostic.temperature, + tracers=tracers, + temperature_tendency=self.temperature_tendency, + tracer_tendency=self.tracer_tendency, + dtime=dtime, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + ) + + self.saturation_adjustment.run( + temperature_tendency=self.temperature_tendency, + qv_tendency=self.tracer_tendency.qv_tendency, + qc_tendency=self.tracer_tendency.qc_tendency, + rho=prognostic.rho, + temperature=diagnostic.temperature, + qv=tracers.qv, + qc=tracers.qc, + dtime=dtime, + ) + diagnostic_stencils.update_satad_output_from_tendency.with_backend(self.backend)( + temperature=diagnostic.temperature, + qv=tracers.qv, + qc=tracers.qc, + temperature_tendency=self.temperature_tendency, + qv_tendency=self.tracer_tendency.qv_tendency, + qc_tendency=self.tracer_tendency.qc_tendency, + dtime=dtime, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=self.vertical_grid.kstart_moist, + vertical_end=self.grid.num_levels, + ) + + diagnostic_stencils.diagnose_exner_and_theta_v_from_virtual_temperature.with_backend( + self.backend + )( + virtual_temperature=diagnostic.virtual_temperature, + exner=prognostic.exner, + perturbed_exner=perturbed_exner, + theta_v=prognostic.theta_v, + tracers=tracers, + temperature=diagnostic.temperature, + rho=prognostic.rho, + previous_exner=saved_exner, + horizontal_start=self.grid.cell_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.LOCAL], + vertical_start=0, + vertical_end=self.grid.num_levels, + offset_provider={}, + ) + + # TODO (Chia Rui): add diagnose_pressure() here when turbulence is ready # grf_bdywidth_c+1: min_rlcell_int + # TODO (Chia Rui): surface_transfer() # grf_bdywidth_c+1: min_rlcell_int, NOT PORTED YET + + # TODO (Chia Rui): (and w if diffusion is applied to w) + self._exchange.exchange( + dims.CellDim, + diagnostic.virtual_temperature, + perturbed_exner, + tracers.qv, + tracers.qc, + tracers.qr, + tracers.qi, + tracers.qs, + tracers.qg, + stream=decomposition.DEFAULT_STREAM, + ) + # TODO (Chia Rui): halo exchange, including ddt_u_turb and ddt_v_turb + diagnostic_stencils.update_exner_and_theta_v_from_virtual_temperature_in_halo.with_backend( + self.backend + )( + exner=prognostic.exner, + theta_v=prognostic.theta_v, + rho=prognostic.rho, + virtual_temperature=diagnostic.virtual_temperature, + mask_prog_halo_c=self.static_field_factories.metrics_field_source.get( + metric_attrs.MASK_PROG_HALO_C + ), + horizontal_start=self.grid.cell_start_index[h_grid.Zone.HALO], + horizontal_end=self.grid.cell_end_index[h_grid.Zone.END], + vertical_start=0, + vertical_end=self.grid.num_levels, + ) # min_rlcell_int-1: min_rlcell_int + diagnostic_stencils.update_vn_from_u_v_tendencies.with_backend(self.backend)( + vn=prognostic.vn, + u_tendency=self.u_tendency, + v_tendency=self.v_tendency, + dt=dtime, + c_lin_e=self.static_field_factories.interpolation_field_source.get( + interp_attrs.C_LIN_E + ), + primal_normal_cell_x=self.static_field_factories.geometry_field_source.get( + geo_attrs.EDGE_NORMAL_CELL_U + ), + primal_normal_cell_y=self.static_field_factories.geometry_field_source.get( + geo_attrs.EDGE_NORMAL_CELL_V + ), + horizontal_start=self.grid.edge_start_index[h_grid.Zone.NUDGING], + horizontal_end=self.grid.edge_end_index[h_grid.Zone.LOCAL], + vertical_start=0, + vertical_end=self.grid.num_levels, + offset_provider=self.grid.connectivities, + ) # grf_bdywidth_e+1: min_rledge_int + + +def initialize_physics_driver( + grid: icon_grid.IconGrid, + vertical_grid: v_grid.VerticalGrid, + static_field_factories: driver_states.StaticFieldFactories, + backend: model_backends.BackendLike, +) -> PhysicsDriver: + saturation_adjustment = satad.SaturationAdjustment( + config=satad.SaturationAdjustmentConfig(), + grid=grid, + vertical_params=vertical_grid, + metric_state=satad.MetricStateSaturationAdjustment( + ddqz_z_full=static_field_factories.metrics_field_source.get(metric_attrs.DDQZ_Z_FULL) + ), + backend=backend, + ) + microphysics = graupel.SingleMomentSixClassIconGraupel( + config=graupel.SingleMomentSixClassIconGraupelConfig(), + grid=grid, + vertical_params=vertical_grid, + metric_state=graupel.MetricStateIconGraupel( + ddqz_z_full=static_field_factories.metrics_field_source.get(metric_attrs.DDQZ_Z_FULL) + ), + backend=backend, + ) + physics_driver = PhysicsDriver( + grid=grid, + vertical_grid=vertical_grid, + static_field_factories=static_field_factories, + saturation_adjustment=saturation_adjustment, + microphysics=microphysics, + backend=backend, + ) + return physics_driver diff --git a/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py b/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py index 276be59f83..646d43625e 100644 --- a/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py +++ b/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py @@ -18,8 +18,8 @@ from gt4py.next.instrumentation import metrics as gtx_metrics import icon4py.model.common.utils as common_utils -from icon4py.model.atmosphere.advection import advection, advection_states -from icon4py.model.atmosphere.diffusion import diffusion, diffusion_states +from icon4py.model.atmosphere.advection import advection +from icon4py.model.atmosphere.diffusion import diffusion from icon4py.model.atmosphere.dycore import dycore_states, solve_nonhydro as solve_nh from icon4py.model.common import dimension as dims, model_backends, model_options, type_alias as ta from icon4py.model.common.decomposition import definitions as decomposition_defs @@ -34,6 +34,7 @@ driver_constants, driver_states, driver_utils, + physics_driver, ) @@ -49,8 +50,11 @@ def __init__( static_field_factories: driver_states.StaticFieldFactories, diffusion_granule: diffusion.Diffusion, solve_nonhydro_granule: solve_nh.SolveNonhydro, + physics: physics_driver.PhysicsDriver, vertical_grid_config: v_grid.VerticalGridConfig, tracer_advection_granule: advection.Advection, + global_reductions: decomposition_defs.Reductions = decomposition_defs.single_node_reductions, + exchange: decomposition_defs.ExchangeRuntime = decomposition_defs.single_node_default, ): self.config = config self.backend = backend @@ -58,9 +62,12 @@ def __init__( self.static_field_factories = static_field_factories self.diffusion = diffusion_granule self.solve_nonhydro = solve_nonhydro_granule + self.physics = physics self.vertical_grid_config = vertical_grid_config self.model_time_variables = driver_states.ModelTimeVariables(config=config) self.tracer_advection = tracer_advection_granule + self._global_reductions = global_reductions + self._exchange = exchange self.timer_collection = driver_states.TimerCollection( [timer.value for timer in driver_states.DriverTimers] ) @@ -94,13 +101,6 @@ def time_integration( ds: driver_states.DriverStates, do_prep_adv: bool, ) -> None: - diffusion_diagnostic_state = ds.diffusion_diagnostic - solve_nonhydro_diagnostic_state = ds.solve_nonhydro_diagnostic - tracer_advection_diagnostic_state = ds.tracer_advection_diagnostic - prognostic_states = ds.prognostics - prep_adv = ds.prep_advection_prognostic - tracer_prep_adv = ds.prep_tracer_advection_prognostic - log.debug( f"starting time loop for dtime = {self.model_time_variables.dtime_in_seconds} s, substep_timestep = {self.model_time_variables.substep_timestep} s, n_timesteps = {self.model_time_variables.n_time_steps}" ) @@ -125,23 +125,18 @@ def time_integration( self.model_time_variables.next_simulation_date() self._integrate_one_time_step( - diffusion_diagnostic_state, - solve_nonhydro_diagnostic_state, - tracer_advection_diagnostic_state, - prognostic_states, - prep_adv, + ds, do_prep_adv, - tracer_prep_adv, ) device_utils.sync(self.backend) self.model_time_variables.is_first_step_in_simulation = False - self._adjust_ndyn_substeps_var(solve_nonhydro_diagnostic_state) + self._adjust_ndyn_substeps_var(ds.solve_nonhydro_diagnostic) # TODO(OngChia): simple IO enough for JW test - self._compute_mean_at_final_time_step(prognostic_states.current) + self._compute_mean_at_final_time_step(ds.prognostics.current) self.timer_collection.show_timer_report() if ( @@ -153,19 +148,14 @@ def time_integration( def _integrate_one_time_step( self, - diffusion_diagnostic_state: diffusion_states.DiffusionDiagnosticState, - solve_nonhydro_diagnostic_state: dycore_states.DiagnosticStateNonHydro, - tracer_advection_diagnostic_state: advection_states.AdvectionDiagnosticState, - prognostic_states: common_utils.TimeStepPair[prognostics.PrognosticState], - prep_adv: dycore_states.PrepAdvection, + ds: driver_states.DriverStates, do_prep_adv: bool, - tracer_prep_adv: advection_states.AdvectionPrepAdvState, ) -> None: log.debug(f"Running {self.solve_nonhydro.__class__}") self._do_dyn_substepping( - solve_nonhydro_diagnostic_state, - prognostic_states, - prep_adv, + ds.solve_nonhydro_diagnostic, + ds.prognostics, + ds.prep_advection_prognostic, do_prep_adv, ) @@ -178,24 +168,33 @@ def _integrate_one_time_step( ) with timer_diffusion: self.diffusion.run( - diffusion_diagnostic_state, - prognostic_states.next, + ds.diffusion_diagnostic, + ds.prognostics.next, self.model_time_variables.dtime_in_seconds, ) - timer_diffusion.capture() # TODO(ricoh): [c34] optionally move the loop into the granule (for efficiency gains) # Precondition: passing data test with ntracer > 0 - for tracer_idx in range(self.config.ntracer): - self.tracer_advection.run( - diagnostic_state=tracer_advection_diagnostic_state, - prep_adv=tracer_prep_adv, - p_tracer_now=prognostic_states.current.tracer[tracer_idx], - p_tracer_new=prognostic_states.next.tracer[tracer_idx], - dtime=self.model_time_variables.dtime_in_seconds, - ) + if self.config.ntracer > 0: + for tracer_now, tracer_next in zip(ds.tracers.current, ds.tracers.next): + self.tracer_advection.run( + diagnostic_state=ds.tracer_advection_diagnostic, + prep_adv=ds.prep_tracer_advection_prognostic, + p_tracer_now=tracer_now, + p_tracer_new=tracer_next, + dtime=self.model_time_variables.dtime_in_seconds, + ) + ds.tracers.swap() + + ds.prognostics.swap() - prognostic_states.swap() + self.physics( + prognostic=ds.prognostics.current, + diagnostic=ds.diagnostic, + tracers=ds.tracers.current, + perturbed_exner=ds.solve_nonhydro_diagnostic.perturbed_exner_at_cells_on_model_levels, + dtime=self.model_time_variables.dtime_in_seconds, + ) def _update_time_levels_for_velocity_tendencies( self, @@ -259,20 +258,19 @@ def _do_dyn_substepping( at_initial_timestep=self.model_time_variables.is_first_step_in_simulation, ) - timer_solve_nh.start() - self.solve_nonhydro.time_step( - solve_nonhydro_diagnostic_state, - prognostic_states, - prep_adv=prep_adv, - second_order_divdamp_factor=self._update_spinup_second_order_divergence_damping(), - dtime=self.model_time_variables.substep_timestep, - ndyn_substeps_var=self.model_time_variables.ndyn_substeps_var, - at_initial_timestep=self.model_time_variables.is_first_step_in_simulation, - lprep_adv=do_prep_adv, - at_first_substep=self._is_first_substep(dyn_substep), - at_last_substep=self._is_last_substep(dyn_substep), - ) - timer_solve_nh.capture() + with timer_solve_nh: + self.solve_nonhydro.time_step( + solve_nonhydro_diagnostic_state, + prognostic_states, + prep_adv=prep_adv, + second_order_divdamp_factor=self._update_spinup_second_order_divergence_damping(), + dtime=self.model_time_variables.substep_timestep, + ndyn_substeps_var=self.model_time_variables.ndyn_substeps_var, + at_initial_timestep=self.model_time_variables.is_first_step_in_simulation, + lprep_adv=do_prep_adv, + at_first_substep=self._is_first_substep(dyn_substep), + at_last_substep=self._is_last_substep(dyn_substep), + ) if not self._is_last_substep(dyn_substep): prognostic_states.swap() @@ -288,6 +286,12 @@ def _adjust_ndyn_substeps_var( solve_nonhydro_diagnostic_state: dycore_states.DiagnosticStateNonHydro, ) -> None: # TODO (Chia Rui): perform a global max operation in multinode run + """ + # global_max_vertical_cfl = self._global_reductions.max( + # buffer=solve_nonhydro_diagnostic_state.max_vertical_cfl, + # array_ns=self._xp, + # ) + """ global_max_vertical_cfl = solve_nonhydro_diagnostic_state.max_vertical_cfl[()] if ( @@ -664,6 +668,12 @@ def initialize_driver( ), backend=backend, ) + physics = physics_driver.initialize_physics_driver( + grid=grid_manager.grid, + vertical_grid=vertical_grid, + static_field_factories=static_field_factories, + backend=backend, + ) icon4py_driver = Icon4pyDriver( config=driver_config, backend=backend, @@ -671,8 +681,10 @@ def initialize_driver( static_field_factories=static_field_factories, diffusion_granule=diffusion_granule, solve_nonhydro_granule=solve_nonhydro_granule, + physics=physics, vertical_grid_config=vertical_grid_config, tracer_advection_granule=tracer_advection_granule, + global_reductions=global_reductions, ) return icon4py_driver diff --git a/model/standalone_driver/src/icon4py/model/standalone_driver/testcases/initial_condition.py b/model/standalone_driver/src/icon4py/model/standalone_driver/testcases/initial_condition.py index 02899c8b9a..be20d54736 100644 --- a/model/standalone_driver/src/icon4py/model/standalone_driver/testcases/initial_condition.py +++ b/model/standalone_driver/src/icon4py/model/standalone_driver/testcases/initial_condition.py @@ -31,13 +31,14 @@ from icon4py.model.common.interpolation import interpolation_attributes, interpolation_factory from icon4py.model.common.interpolation.stencils import ( cell_2_edge_interpolation, - edge_2_cell_vector_rbf_interpolation, + compute_edge_2_cell_vector_interpolation, ) from icon4py.model.common.math.stencils import generic_math_operations as gt4py_math_op from icon4py.model.common.metrics import metrics_attributes, metrics_factory from icon4py.model.common.states import ( diagnostic_state as diagnostics, prognostic_state as prognostics, + tracer_state as tracers, ) from icon4py.model.common.utils import data_allocation as data_alloc from icon4py.model.standalone_driver import driver_states @@ -148,7 +149,7 @@ def jablonowski_williamson( # noqa: PLR0915 [too-many-statements] eta_v_ndarray = eta_v.ndarray # set surface pressure - diagnostic_state.pressure_ifc.ndarray[:, -1] = p_sfc + diagnostic_state.pressure_at_half_levels.ndarray[:, -1] = p_sfc sin_lat = xp.sin(cell_lat) cos_lat = xp.cos(cell_lat) @@ -305,7 +306,9 @@ def jablonowski_williamson( # noqa: PLR0915 [too-many-statements] ) prognostic_states = common_utils.TimeStepPair(prognostic_state_now, prognostic_state_next) - edge_2_cell_vector_rbf_interpolation.edge_2_cell_vector_rbf_interpolation.with_backend(backend)( + compute_edge_2_cell_vector_interpolation.compute_edge_2_cell_vector_interpolation.with_backend( + backend + )( p_e_in=prognostic_states.current.vn, ptr_coeff_1=rbf_vec_coeff_c1, ptr_coeff_2=rbf_vec_coeff_c2, @@ -350,6 +353,19 @@ def jablonowski_williamson( # noqa: PLR0915 [too-many-statements] mass_flx_me=data_alloc.zero_field(grid, dims.EdgeDim, dims.KDim, allocator=allocator), mass_flx_ic=data_alloc.zero_field(grid, dims.CellDim, dims.KDim, allocator=allocator), ) + tracer_state_now = tracers.initialize_tracer_state( + grid=grid, + allocator=allocator, + ) + tracer_state_next = tracers.TracerState( + qv=data_alloc.as_field(tracer_state_now.qv, allocator=allocator), + qc=data_alloc.as_field(tracer_state_now.qc, allocator=allocator), + qr=data_alloc.as_field(tracer_state_now.qr, allocator=allocator), + qi=data_alloc.as_field(tracer_state_now.qi, allocator=allocator), + qs=data_alloc.as_field(tracer_state_now.qs, allocator=allocator), + qg=data_alloc.as_field(tracer_state_now.qg, allocator=allocator), + ) + tracer_states = common_utils.TimeStepPair(tracer_state_now, tracer_state_next) log.info("Initialization completed.") ds = driver_states.DriverStates( @@ -360,6 +376,7 @@ def jablonowski_williamson( # noqa: PLR0915 [too-many-statements] diffusion_diagnostic=diffusion_diagnostic_state, prognostics=prognostic_states, diagnostic=diagnostic_state, + tracers=tracer_states, ) return ds