Skip to content
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
dc52a12
bindings: move swap from Fortran to Python
havogt Mar 18, 2026
073f425
Apply suggestions from code review
havogt Mar 18, 2026
37c21b2
update references
havogt Mar 18, 2026
66c1bf1
Move vertoffset_gradp to Python
havogt Mar 19, 2026
6f91160
cleanup flip
havogt Mar 19, 2026
47e9c7a
fix references
havogt Mar 19, 2026
b085cb3
add tests and fix bugs
havogt Mar 19, 2026
6bfed5d
fixes
havogt Mar 20, 2026
491344c
fix test
havogt Mar 20, 2026
1a54450
Merge branch 'main' into blueline_swap_rbf_vec_coeff_e_and_v
havogt Mar 20, 2026
bc8a73f
fix test
havogt Mar 20, 2026
cd4b75c
fix tests
havogt Mar 20, 2026
fa8def5
Merge branch 'blueline_swap_rbf_vec_coeff_e_and_v' into tmp_vertoffse…
havogt Mar 20, 2026
2a23643
switch to serialized data v3
havogt Mar 20, 2026
742336f
Update base.yml
havogt Mar 20, 2026
64d3e59
resolve module name conflict
havogt Mar 20, 2026
96a90bc
Merge branch 'main' into tmp_vertoffset_gradp_and_dim_swap3
havogt Mar 20, 2026
9d865bd
limit test
havogt Mar 20, 2026
27106a3
Merge branch 'tmp_vertoffset_gradp_and_dim_swap3' of github.com:C2SM/…
havogt Mar 20, 2026
c6a8ac9
Merge remote-tracking branch 'upstream/main' into tmp_vertoffset_grad…
havogt Mar 23, 2026
f6ba342
switch on/off download
havogt Mar 23, 2026
2af2071
Merge remote-tracking branch 'upstream/main' into tmp_vertoffset_grad…
havogt Mar 23, 2026
7bfbc6e
add comments and disable data download
havogt Mar 23, 2026
92b6460
fix verification range
havogt Mar 23, 2026
8162aef
fix accidental tuple
havogt Mar 23, 2026
6ff540c
add missing exchange
havogt Mar 23, 2026
5992e38
hack: index_field conversion
havogt Mar 23, 2026
ef3aa29
cleanup and document index2offset
havogt Mar 24, 2026
b34fe50
Merge branch 'main' into tmp_vertoffset_gradp_and_dim_swap3
havogt Mar 24, 2026
5382941
cleanup exchange in zdiff_gradp
havogt Mar 24, 2026
3d81bbd
renames
havogt Mar 24, 2026
404bf7d
Apply suggestions from code review
havogt Mar 24, 2026
de1c4aa
address review comments
havogt Mar 24, 2026
378234f
slightly improve dummy_exchange
havogt Mar 24, 2026
dff89b7
add doctest to index2offset
havogt Mar 24, 2026
0057ef2
Merge branch 'tmp_vertoffset_gradp_and_dim_swap3' of github.com:C2SM/…
havogt Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def test_compute_perturbed_quantities_and_interpolation(
reference_theta_at_cells_on_half_levels = metrics_savepoint.theta_ref_ic()
d2dexdz2_fac1_mc = metrics_savepoint.d2dexdz2_fac1_mc()
d2dexdz2_fac2_mc = metrics_savepoint.d2dexdz2_fac2_mc()
wgtfacq_c = metrics_savepoint.wgtfacq_c_dsl()
wgtfacq_c = metrics_savepoint.wgtfacq_c()
wgtfac_c = metrics_savepoint.wgtfac_c()
exner_w_explicit_weight_parameter = metrics_savepoint.vwind_expl_wgt()
ddz_of_reference_exner_at_cells_on_half_levels = metrics_savepoint.d_exner_dz_ref_ic()
Expand Down Expand Up @@ -1167,16 +1167,21 @@ def test_compute_perturbed_quantities_and_interpolation(
perturbed_theta_v_at_cells_on_model_levels.asnumpy(), z_rth_pr_2_ref.asnumpy()
)
assert test_utils.dallclose(
temporal_extrapolation_of_perturbed_exner.asnumpy(), z_exner_ex_pr_ref.asnumpy()
temporal_extrapolation_of_perturbed_exner.asnumpy()[
start_cell_lateral_boundary_level_3:end_cell_halo, :
],
z_exner_ex_pr_ref.asnumpy()[start_cell_lateral_boundary_level_3:end_cell_halo, :],
)
assert test_utils.dallclose(
perturbed_exner_at_cells_on_model_levels.asnumpy(), exner_pr_ref.asnumpy()
)
assert test_utils.dallclose(rho_at_cells_on_half_levels.asnumpy(), rho_ic_ref.asnumpy())

assert test_utils.dallclose(
exner_at_cells_on_half_levels.asnumpy()[:, nflatlev:],
z_exner_ic_ref.asnumpy()[:, nflatlev:],
exner_at_cells_on_half_levels.asnumpy()[
start_cell_lateral_boundary_level_3:end_cell_halo, nflatlev:
],
z_exner_ic_ref.asnumpy()[start_cell_lateral_boundary_level_3:end_cell_halo, nflatlev:],
rtol=1e-11,
)

Expand Down Expand Up @@ -1768,7 +1773,7 @@ def test_compute_horizontal_velocity_quantities_and_fluxes(
ddxn_z_full = metrics_savepoint.ddxn_z_full()
ddxt_z_full = metrics_savepoint.ddxt_z_full()
wgtfac_e = metrics_savepoint.wgtfac_e()
wgtfacq_e = metrics_savepoint.wgtfacq_e_dsl()
wgtfacq_e = metrics_savepoint.wgtfacq_e()
rbf_vec_coeff_e = interpolation_savepoint.rbf_vec_coeff_e()
geofac_grdiv = interpolation_savepoint.geofac_grdiv()
nflatlev = vertical_params.nflatlev
Expand Down Expand Up @@ -2152,7 +2157,7 @@ def test_vertically_implicit_solver_at_predictor_step(
reference_exner_at_cells_on_model_levels=metrics_savepoint.exner_ref_mc(),
e_bln_c_s=interpolation_savepoint.e_bln_c_s(),
wgtfac_c=metrics_savepoint.wgtfac_c(),
wgtfacq_c=metrics_savepoint.wgtfacq_c_dsl(),
wgtfacq_c=metrics_savepoint.wgtfacq_c(),
iau_wgt_dyn=iau_wgt_dyn,
dtime=savepoint_nonhydro_init.get_metadata("dtime").get("dtime"),
is_iau_active=is_iau_active,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_compute_diagnostics_from_normal_wind(
ddxn_z_full = metrics_savepoint.ddxn_z_full()
ddxt_z_full = metrics_savepoint.ddxt_z_full()
contravariant_correction_at_edges_on_model_levels = savepoint_velocity_init.z_w_concorr_me()
wgtfacq_e = metrics_savepoint.wgtfacq_e_dsl()
wgtfacq_e = metrics_savepoint.wgtfacq_e()
nflatlev = grid_savepoint.nflatlev()
c_intp = interpolation_savepoint.c_intp()
inv_dual_edge_length = grid_savepoint.inv_dual_edge_length()
Expand Down
4 changes: 2 additions & 2 deletions model/atmosphere/dycore/tests/dycore/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def construct_metric_state(
time_extrapolation_parameter_for_exner=metrics_savepoint.exner_exfac(),
reference_exner_at_cells_on_model_levels=metrics_savepoint.exner_ref_mc(),
wgtfac_c=metrics_savepoint.wgtfac_c(),
wgtfacq_c=metrics_savepoint.wgtfacq_c_dsl(),
wgtfacq_c=metrics_savepoint.wgtfacq_c(),
inv_ddqz_z_full=metrics_savepoint.inv_ddqz_z_full(),
reference_rho_at_cells_on_model_levels=metrics_savepoint.rho_ref_mc(),
reference_theta_at_cells_on_model_levels=metrics_savepoint.theta_ref_mc(),
Expand All @@ -71,7 +71,7 @@ def construct_metric_state(
ddqz_z_full_e=metrics_savepoint.ddqz_z_full_e(),
ddxt_z_full=metrics_savepoint.ddxt_z_full(),
wgtfac_e=metrics_savepoint.wgtfac_e(),
wgtfacq_e=metrics_savepoint.wgtfacq_e_dsl(),
wgtfacq_e=metrics_savepoint.wgtfacq_e(),
exner_w_implicit_weight_parameter=metrics_savepoint.vwind_impl_wgt(),
horizontal_mask_for_3d_divdamp=metrics_savepoint.hmask_dd3d(),
scaling_factor_for_3d_divdamp=metrics_savepoint.scalfac_dd3d(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def compute_zdiff_gradp_dsl( # noqa: PLR0912 [too-many-branches]
break

vertoffset_gradp = vertidx_gradp - vertoffset_gradp
vertoffset_gradp[:horizontal_start_1, :, :] = 0.0

# TODO: how do we handle these exchanges?
shape = zdiff_gradp.shape
exchange(zdiff_gradp.reshape(shape[0], -1))

return zdiff_gradp, vertoffset_gradp
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,6 @@ def list2field(
return gtx.as_field(domain, arr, allocator=allocator)


def kflip_wgtfacq(
arr: NDArray,
domain: gtx.Domain,
allocator: gtx_typing.Allocator,
) -> gtx.Field:
return gtx.as_field(domain, arr[:, ::-1], allocator=allocator) # type: ignore [arg-type] # type "ndarray[Any, dtype[Any] | Any"; expected "NDArrayObject"


def adjust_fortran_indices(inp: NDArray) -> NDArray:
"""For some Fortran arrays we need to subtract 1 to be compatible with Python indexing."""
return inp - 1
61 changes: 61 additions & 0 deletions model/common/src/icon4py/model/common/utils/field_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ICON4Py - ICON inspired code in Python and GT4Py
#
# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import array_api_compat
from gt4py import next as gtx
from gt4py.next import typing as gtx_typing


def flip(field: gtx.Field, dim: gtx.Dimension, allocator: gtx_typing.Allocator) -> gtx.Field:
"""Flip a field along a given dimension.

Args:
field: The field to flip.
dim: The dimension along which to flip the field.
allocator: Allocator to use for the output field.
"""
# Note: `allocator` needs to be passed explicitly since GT4Py fields currently don't persist how they were allocated.
xp = array_api_compat.array_namespace(field.ndarray)
flipped_array = xp.flip(field.ndarray, axis=field.domain.dims.index(dim))
return gtx.as_field(field.domain, flipped_array, allocator=allocator)


def index2offset(
index_field: gtx.Field, dim: gtx.Dimension, allocator: gtx_typing.Allocator
) -> gtx.Field:
"""Convert an index field to an offset field.

Note: Additionally clips negative indices to become zero offset as Fortran initializes some indices with `0` (which corresponds to `-1` in Python) to indicate that they are not used.
As GT4Py's unstructured domain inference is incomplete and runs over the whole domain we might useout-of-bounds offsets in intermediate computations.

Args:
index_field: Index field in Python indexing (0-based).
dim: The dimension along which to convert indices to offsets.
allocator: Allocator to use for the output field.
"""
# Note: `allocator` needs to be passed explicitly since GT4Py fields currently don't persist how they were allocated.
xp = array_api_compat.array_namespace(index_field.ndarray)

current_index = gtx.as_field(
gtx.Domain(index_field.domain[dim]),
xp.arange(
index_field.domain[dim].unit_range.start,
index_field.domain[dim].unit_range.stop,
dtype=index_field.ndarray.dtype,
),
allocator=allocator,
)
# use GT4Py's broadcasting and field arithmetic (includes clipping)
offset_field = gtx.where( # type: ignore[attr-defined]
index_field >= 0, # type: ignore[operator]
index_field - current_index,
0,
)

# if GT4Py embedded would propagate the allocator, we could avoid this extra conversion.
return gtx.as_field(offset_field.domain, offset_field.ndarray, allocator=allocator)
4 changes: 2 additions & 2 deletions model/common/tests/common/grid/unit_tests/test_topography.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from icon4py.model.testing import definitions, test_utils
from icon4py.model.testing.fixtures import * # noqa: F403

from ... import utils
from ... import utils_test


if TYPE_CHECKING:
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_topography_smoothing_with_serialized_data(
c2e2co=icon_grid.get_connectivity("C2E2CO").ndarray,
num_iterations=num_iterations,
array_ns=xp,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
)

assert test_utils.dallclose(topography_smoothed_ref, topography_smoothed, atol=1.0e-14)
4 changes: 2 additions & 2 deletions model/common/tests/common/grid/unit_tests/test_vertical.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
topography_savepoint,
)

from ... import utils
from ... import utils_test


if TYPE_CHECKING:
Expand Down Expand Up @@ -430,7 +430,7 @@ def test_compute_vertical_coordinate(
SLEVE_minimum_relative_layer_thickness_2=0.5,
lowest_layer_thickness=vertical_config.lowest_layer_thickness,
array_ns=xp,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
)

assert test_utils.dallclose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
processor_props,
)

from ... import utils
from ... import utils_test


cell_domain = h_grid.domain(dims.CellDim)
Expand All @@ -61,7 +61,7 @@ def test_compute_c_lin_e(
backend: gtx_typing.Backend,
) -> None:
xp = data_alloc.import_array_ns(backend)
func = functools.partial(compute_c_lin_e, array_ns=xp, exchange=utils.dummy_exchange)
func = functools.partial(compute_c_lin_e, array_ns=xp, exchange=utils_test.dummy_exchange)
inv_dual_edge_length = grid_savepoint.inv_dual_edge_length()
edge_cell_length = grid_savepoint.edge_cell_length()
edge_owner_mask = grid_savepoint.e_owner_mask()
Expand All @@ -74,7 +74,7 @@ def test_compute_c_lin_e(
inv_dual_edge_length.asnumpy(),
edge_owner_mask.asnumpy(),
horizontal_start,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=xp,
)
assert test_helpers.dallclose(c_lin_e, c_lin_e_ref.asnumpy())
Expand Down Expand Up @@ -162,7 +162,9 @@ def test_compute_geofac_n2s(
e2c = icon_grid.get_connectivity(dims.E2C).ndarray
c2e2c = icon_grid.get_connectivity(dims.C2E2C).ndarray
horizontal_start = icon_grid.start_index(cell_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2))
geofac_n2s = functools.partial(compute_geofac_n2s, array_ns=xp, exchange=utils.dummy_exchange)(
geofac_n2s = functools.partial(
compute_geofac_n2s, array_ns=xp, exchange=utils_test.dummy_exchange
)(
dual_edge_length.ndarray,
geofac_div.ndarray,
c2e,
Expand Down Expand Up @@ -194,7 +196,7 @@ def test_compute_geofac_grg(
horizontal_start = icon_grid.start_index(cell_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2))

geofac_grg_0, geofac_grg_1 = functools.partial(
compute_geofac_grg, array_ns=xp, exchange=utils.dummy_exchange
compute_geofac_grg, array_ns=xp, exchange=utils_test.dummy_exchange
)(
primal_normal_cell_x,
primal_normal_cell_y,
Expand Down Expand Up @@ -238,7 +240,7 @@ def test_compute_geofac_grdiv(
e2c2e = icon_grid.get_connectivity(dims.E2C2E).ndarray
horizontal_start = icon_grid.start_index(edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_2))
geofac_grdiv = functools.partial(
compute_geofac_grdiv, array_ns=xp, exchange=utils.dummy_exchange
compute_geofac_grdiv, array_ns=xp, exchange=utils_test.dummy_exchange
)(
geofac_div.ndarray,
inv_dual_edge_length.ndarray,
Expand Down Expand Up @@ -284,7 +286,7 @@ def test_compute_c_bln_avg(
divergence_averaging_central_cell_weight,
horizontal_start,
horizontal_start_p2,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=xp,
)
case base_grid.GeometryType.TORUS:
Expand All @@ -295,7 +297,7 @@ def test_compute_c_bln_avg(
divergence_averaging_central_cell_weight,
horizontal_start,
horizontal_start_p2,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=xp,
)

Expand Down Expand Up @@ -325,7 +327,9 @@ def test_compute_e_flx_avg(
horizontal_start_1 = icon_grid.start_index(edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_4))
horizontal_start_2 = icon_grid.start_index(edge_domain(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_5))

e_flx_avg = functools.partial(compute_e_flx_avg, array_ns=xp, exchange=utils.dummy_exchange)(
e_flx_avg = functools.partial(
compute_e_flx_avg, array_ns=xp, exchange=utils_test.dummy_exchange
)(
c_bln_avg,
geofac_div,
owner_mask,
Expand Down Expand Up @@ -364,7 +368,7 @@ def test_compute_cells_aw_verts(
)

cells_aw_verts = functools.partial(
compute_cells_aw_verts, array_ns=xp, exchange=utils.dummy_exchange
compute_cells_aw_verts, array_ns=xp, exchange=utils_test.dummy_exchange
)(
dual_area=dual_area,
edge_vert_length=edge_vert_length,
Expand Down Expand Up @@ -446,14 +450,14 @@ def test_compute_pos_on_tplane_e(
owner_mask,
e2c,
horizontal_start,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=xp,
)
case base_grid.GeometryType.TORUS:
pos_on_tplane_e_x, pos_on_tplane_e_y = compute_pos_on_tplane_e_x_y_torus(
dual_edge_length,
e2c,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=xp,
)
assert test_helpers.dallclose(pos_on_tplane_e_x, pos_on_tplane_e_x_ref, atol=1e-8, rtol=1e-9)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
processor_props,
)

from ... import utils
from ... import utils_test


if TYPE_CHECKING:
Expand Down Expand Up @@ -205,7 +205,7 @@ def test_rbf_interpolation_coeffs_cell(
horizontal_end,
grid.global_properties.domain_length,
grid.global_properties.domain_height,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=data_alloc.import_array_ns(backend),
)

Expand Down Expand Up @@ -282,7 +282,7 @@ def test_rbf_interpolation_coeffs_vertex(
horizontal_end,
grid.global_properties.domain_length,
grid.global_properties.domain_height,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=data_alloc.import_array_ns(backend),
)

Expand Down Expand Up @@ -361,7 +361,7 @@ def test_rbf_interpolation_coeffs_edge(
horizontal_end,
grid.global_properties.domain_length,
grid.global_properties.domain_height,
exchange=utils.dummy_exchange,
exchange=utils_test.dummy_exchange,
array_ns=data_alloc.import_array_ns(backend),
)

Expand Down
Loading
Loading