Skip to content

Commit dd41784

Browse files
committed
test_apply_diffusion_to_vn (was made single precision ready before batch allocations)
1 parent 9a739bb commit dd41784

1 file changed

Lines changed: 25 additions & 14 deletions

File tree

model/atmosphere/diffusion/tests/diffusion/stencil_tests/test_apply_diffusion_to_vn.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_vn import apply_diffusion_to_vn
1515
from icon4py.model.common import dimension as dims
1616
from icon4py.model.common.grid import base, horizontal as h_grid
17+
from icon4py.model.common.type_alias import vpfloat, wpfloat
1718
from icon4py.model.common.utils import data_allocation as data_alloc
1819
from icon4py.model.testing.stencil_tests import StandardStaticVariants, StencilTest
1920

@@ -25,6 +26,7 @@
2526
from .test_calculate_nabla4 import calculate_nabla4_numpy
2627

2728

29+
@pytest.mark.single_precision_ready
2830
@pytest.mark.uses_concat_where
2931
@pytest.mark.continuous_benchmarking
3032
class TestApplyDiffusionToVn(StencilTest):
@@ -117,25 +119,29 @@ def reference(
117119

118120
@pytest.fixture
119121
def input_data(self, grid: base.Grid) -> dict:
120-
u_vert = data_alloc.random_field(grid, dims.VertexDim, dims.KDim)
121-
v_vert = data_alloc.random_field(grid, dims.VertexDim, dims.KDim)
122+
u_vert = data_alloc.random_field(grid, dims.VertexDim, dims.KDim, dtype=vpfloat)
123+
v_vert = data_alloc.random_field(grid, dims.VertexDim, dims.KDim, dtype=vpfloat)
122124

123-
primal_normal_vert_v1 = data_alloc.random_field(grid, dims.EdgeDim, dims.E2C2VDim)
124-
primal_normal_vert_v2 = data_alloc.random_field(grid, dims.EdgeDim, dims.E2C2VDim)
125+
primal_normal_vert_v1 = data_alloc.random_field(
126+
grid, dims.EdgeDim, dims.E2C2VDim, dtype=wpfloat
127+
)
128+
primal_normal_vert_v2 = data_alloc.random_field(
129+
grid, dims.EdgeDim, dims.E2C2VDim, dtype=wpfloat
130+
)
125131

126-
inv_vert_vert_length = data_alloc.random_field(grid, dims.EdgeDim)
127-
inv_primal_edge_length = data_alloc.random_field(grid, dims.EdgeDim)
132+
inv_vert_vert_length = data_alloc.random_field(grid, dims.EdgeDim, dtype=wpfloat)
133+
inv_primal_edge_length = data_alloc.random_field(grid, dims.EdgeDim, dtype=wpfloat)
128134

129-
area_edge = data_alloc.random_field(grid, dims.EdgeDim)
130-
kh_smag_e = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim)
131-
z_nabla2_e = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim)
132-
diff_multfac_vn = data_alloc.random_field(grid, dims.KDim)
133-
vn = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim)
134-
nudgecoeff_e = data_alloc.random_field(grid, dims.EdgeDim)
135+
area_edge = data_alloc.random_field(grid, dims.EdgeDim, dtype=wpfloat)
136+
kh_smag_e = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim, dtype=vpfloat)
137+
z_nabla2_e = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim, dtype=wpfloat)
138+
diff_multfac_vn = data_alloc.random_field(grid, dims.KDim, dtype=wpfloat)
139+
vn = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim, dtype=wpfloat)
140+
nudgecoeff_e = data_alloc.random_field(grid, dims.EdgeDim, dtype=wpfloat)
135141

136142
limited_area = grid.limited_area if hasattr(grid, "limited_area") else True
137-
fac_bdydiff_v = 5.0
138-
nudgezone_diff = 9.0
143+
fac_bdydiff_v = wpfloat(5.0)
144+
nudgezone_diff = vpfloat(9.0)
139145

140146
edge_domain = h_grid.domain(dims.EdgeDim)
141147
start_2nd_nudge_line_idx_e = grid.start_index(edge_domain(h_grid.Zone.NUDGING_LEVEL_2))
@@ -164,3 +170,8 @@ def input_data(self, grid: base.Grid) -> dict:
164170
vertical_start=0,
165171
vertical_end=grid.num_levels,
166172
)
173+
174+
175+
@pytest.mark.continuous_benchmarking
176+
class TestApplyDiffusionToVnContinuousBenchmarking(TestApplyDiffusionToVn):
177+
pass

0 commit comments

Comments
 (0)