Skip to content

Commit 76fc293

Browse files
authored
ICON stencil from sir (#858)
Rewrites the ICON laplacian stencil in SIR (with dawn4py) and adds it to dawn4py tests.
1 parent c6ff042 commit 76fc293

4 files changed

Lines changed: 324 additions & 1 deletion

File tree

dawn/examples/python/unstructured_stencil.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
def main(args: argparse.Namespace):
3737
interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
3838

39-
# create the out = in[i+1] statement
4039
body_ast = sir_utils.make_ast(
4140
[
4241
sir_utils.make_assignment_stmt(

dawn/test/integration-test/dawn4py-tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ endfunction()
4040
add_python_example(NAME copy_stencil VERIFY)
4141
add_python_example(NAME hori_diff_stencil VERIFY)
4242
add_python_example(NAME tridiagonal_solve_stencil VERIFY)
43+
add_python_example(NAME ICON_laplacian_stencil VERIFY)
4344
add_python_example(NAME tridiagonal_solve_unstructured)
4445
add_python_example(NAME global_index_stencil)
4546
add_python_example(NAME unstructured_stencil)
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
#!/usr/bin/env python
2+
3+
##===-----------------------------------------------------------------------------*- Python -*-===##
4+
## _
5+
## | |
6+
## __| | __ ___ ___ ___
7+
## / _` |/ _` \ \ /\ / / '_ |
8+
## | (_| | (_| |\ V V /| | | |
9+
## \__,_|\__,_| \_/\_/ |_| |_| - Compiler Toolchain
10+
##
11+
##
12+
## This file is distributed under the MIT License (MIT).
13+
## See LICENSE.txt for details.
14+
##
15+
##===------------------------------------------------------------------------------------------===##
16+
17+
"""Generate input for the ICON Laplacian stencil test"""
18+
19+
import os
20+
21+
import dawn4py
22+
from dawn4py.serialization import SIR
23+
from dawn4py.serialization import utils as sir_utils
24+
from google.protobuf.json_format import MessageToJson, Parse
25+
26+
27+
28+
def main():
29+
stencil_name = "ICON_laplacian_stencil"
30+
gen_outputfile = f"{stencil_name}.cpp"
31+
sir_outputfile = f"{stencil_name}.sir"
32+
33+
interval = sir_utils.make_interval(SIR.Interval.Start, SIR.Interval.End, 0, 0)
34+
35+
body_ast = sir_utils.make_ast(
36+
[
37+
sir_utils.make_assignment_stmt(
38+
sir_utils.make_field_access_expr("rot_vec"),
39+
sir_utils.make_reduction_over_neighbor_expr(
40+
op="+",
41+
init=sir_utils.make_literal_access_expr("0.0", SIR.BuiltinType.Double),
42+
rhs=sir_utils.make_binary_operator(
43+
sir_utils.make_field_access_expr("vec"),
44+
"*",
45+
sir_utils.make_field_access_expr("geofac_rot")),
46+
lhs_location=SIR.LocationType.Value("Vertex"),
47+
rhs_location=SIR.LocationType.Value("Edge"),
48+
),
49+
"=",
50+
),
51+
sir_utils.make_assignment_stmt(
52+
sir_utils.make_field_access_expr("div_vec"),
53+
sir_utils.make_reduction_over_neighbor_expr(
54+
op="+",
55+
init=sir_utils.make_literal_access_expr("0.0", SIR.BuiltinType.Double),
56+
rhs=sir_utils.make_binary_operator(
57+
sir_utils.make_field_access_expr("vec"),
58+
"*",
59+
sir_utils.make_field_access_expr("geofac_div")),
60+
lhs_location=SIR.LocationType.Value("Cell"),
61+
rhs_location=SIR.LocationType.Value("Edge"),
62+
),
63+
"=",
64+
),
65+
sir_utils.make_assignment_stmt(
66+
sir_utils.make_field_access_expr("nabla2t1_vec"),
67+
sir_utils.make_reduction_over_neighbor_expr(
68+
op="+",
69+
init=sir_utils.make_literal_access_expr("0.0", SIR.BuiltinType.Double),
70+
rhs=sir_utils.make_field_access_expr("rot_vec"),
71+
lhs_location=SIR.LocationType.Value("Edge"),
72+
rhs_location=SIR.LocationType.Value("Vertex"),
73+
weights=sir_utils.make_weights([-1.0, 1.0])
74+
),
75+
"=",
76+
),
77+
sir_utils.make_assignment_stmt(
78+
sir_utils.make_field_access_expr("nabla2t1_vec"),
79+
sir_utils.make_binary_operator(
80+
sir_utils.make_binary_operator(
81+
sir_utils.make_field_access_expr("tangent_orientation"),
82+
"*",
83+
sir_utils.make_field_access_expr("nabla2t1_vec")),
84+
"/",
85+
sir_utils.make_field_access_expr("primal_edge_length")),
86+
"=",
87+
),
88+
sir_utils.make_assignment_stmt(
89+
sir_utils.make_field_access_expr("nabla2t2_vec"),
90+
sir_utils.make_reduction_over_neighbor_expr(
91+
op="+",
92+
init=sir_utils.make_literal_access_expr("0.0", SIR.BuiltinType.Double),
93+
rhs=sir_utils.make_field_access_expr("div_vec"),
94+
lhs_location=SIR.LocationType.Value("Edge"),
95+
rhs_location=SIR.LocationType.Value("Cell"),
96+
weights=sir_utils.make_weights([-1.0, 1.0])
97+
),
98+
"=",
99+
),
100+
sir_utils.make_assignment_stmt(
101+
sir_utils.make_field_access_expr("nabla2t2_vec"),
102+
sir_utils.make_binary_operator(
103+
sir_utils.make_field_access_expr("nabla2t2_vec"),
104+
"/",
105+
sir_utils.make_field_access_expr("dual_edge_length")),
106+
"=",
107+
),
108+
sir_utils.make_assignment_stmt(
109+
sir_utils.make_field_access_expr("nabla2_vec"),
110+
sir_utils.make_binary_operator(
111+
sir_utils.make_field_access_expr("nabla2t2_vec"),
112+
"-",
113+
sir_utils.make_field_access_expr("nabla2t1_vec")),
114+
"=",
115+
),
116+
]
117+
)
118+
119+
vertical_region_stmt = sir_utils.make_vertical_region_decl_stmt(
120+
body_ast, interval, SIR.VerticalRegion.Forward
121+
)
122+
123+
sir = sir_utils.make_sir(
124+
gen_outputfile,
125+
SIR.GridType.Value("Unstructured"),
126+
[
127+
sir_utils.make_stencil(
128+
stencil_name,
129+
sir_utils.make_ast([vertical_region_stmt]),
130+
[
131+
sir_utils.make_field(
132+
"vec",
133+
sir_utils.make_field_dimensions_unstructured(
134+
[SIR.LocationType.Value("Edge")], 1
135+
),
136+
),
137+
sir_utils.make_field(
138+
"div_vec",
139+
sir_utils.make_field_dimensions_unstructured(
140+
[SIR.LocationType.Value("Cell")], 1
141+
),
142+
),
143+
sir_utils.make_field(
144+
"rot_vec",
145+
sir_utils.make_field_dimensions_unstructured(
146+
[SIR.LocationType.Value("Vertex")], 1
147+
),
148+
),
149+
sir_utils.make_field(
150+
"nabla2t1_vec",
151+
sir_utils.make_field_dimensions_unstructured(
152+
[SIR.LocationType.Value("Edge")], 1
153+
),
154+
),
155+
sir_utils.make_field(
156+
"nabla2t2_vec",
157+
sir_utils.make_field_dimensions_unstructured(
158+
[SIR.LocationType.Value("Edge")], 1
159+
),
160+
),
161+
sir_utils.make_field(
162+
"nabla2_vec",
163+
sir_utils.make_field_dimensions_unstructured(
164+
[SIR.LocationType.Value("Edge")], 1
165+
),
166+
),
167+
sir_utils.make_field(
168+
"primal_edge_length",
169+
sir_utils.make_field_dimensions_unstructured(
170+
[SIR.LocationType.Value("Edge")], 1
171+
),
172+
),
173+
sir_utils.make_field(
174+
"dual_edge_length",
175+
sir_utils.make_field_dimensions_unstructured(
176+
[SIR.LocationType.Value("Edge")], 1
177+
),
178+
),
179+
sir_utils.make_field(
180+
"tangent_orientation",
181+
sir_utils.make_field_dimensions_unstructured(
182+
[SIR.LocationType.Value("Edge")], 1
183+
),
184+
),
185+
sir_utils.make_field(
186+
"geofac_rot",
187+
sir_utils.make_field_dimensions_unstructured(
188+
[SIR.LocationType.Value("Vertex"), SIR.LocationType.Value("Edge")], 1
189+
),
190+
),
191+
sir_utils.make_field(
192+
"geofac_div",
193+
sir_utils.make_field_dimensions_unstructured(
194+
[SIR.LocationType.Value("Cell"), SIR.LocationType.Value("Edge")], 1
195+
),
196+
),
197+
],
198+
),
199+
],
200+
)
201+
202+
# write SIR to file (for debugging purposes)
203+
f = open(sir_outputfile, "w")
204+
f.write(MessageToJson(sir))
205+
f.close()
206+
207+
# compile
208+
code = dawn4py.compile(sir, backend="c++-naive-ico")
209+
210+
# write to file
211+
print(f"Writing generated code to '{gen_outputfile}'")
212+
with open(gen_outputfile, "w") as f:
213+
f.write(code)
214+
215+
216+
if __name__ == "__main__":
217+
main()
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
//---- Preprocessor defines ----
2+
#define DAWN_GENERATED 1
3+
#undef DAWN_BACKEND_T
4+
#define DAWN_BACKEND_T CXXNAIVEICO
5+
#include <driver-includes/unstructured_interface.hpp>
6+
7+
//---- Globals ----
8+
9+
//---- Stencils ----
10+
namespace dawn_generated{
11+
namespace cxxnaiveico{
12+
template<typename LibTag>
13+
class ICON_laplacian_stencil {
14+
private:
15+
16+
struct stencil_64 {
17+
dawn::mesh_t<LibTag> const& m_mesh;
18+
int m_k_size;
19+
dawn::edge_field_t<LibTag, double>& m_vec;
20+
dawn::cell_field_t<LibTag, double>& m_div_vec;
21+
dawn::vertex_field_t<LibTag, double>& m_rot_vec;
22+
dawn::edge_field_t<LibTag, double>& m_nabla2t1_vec;
23+
dawn::edge_field_t<LibTag, double>& m_nabla2t2_vec;
24+
dawn::edge_field_t<LibTag, double>& m_nabla2_vec;
25+
dawn::edge_field_t<LibTag, double>& m_primal_edge_length;
26+
dawn::edge_field_t<LibTag, double>& m_dual_edge_length;
27+
dawn::edge_field_t<LibTag, double>& m_tangent_orientation;
28+
dawn::sparse_vertex_field_t<LibTag, double>& m_geofac_rot;
29+
dawn::sparse_cell_field_t<LibTag, double>& m_geofac_div;
30+
public:
31+
32+
stencil_64(dawn::mesh_t<LibTag> const &mesh, int k_size, dawn::edge_field_t<LibTag, double>&vec, dawn::cell_field_t<LibTag, double>&div_vec, dawn::vertex_field_t<LibTag, double>&rot_vec, dawn::edge_field_t<LibTag, double>&nabla2t1_vec, dawn::edge_field_t<LibTag, double>&nabla2t2_vec, dawn::edge_field_t<LibTag, double>&nabla2_vec, dawn::edge_field_t<LibTag, double>&primal_edge_length, dawn::edge_field_t<LibTag, double>&dual_edge_length, dawn::edge_field_t<LibTag, double>&tangent_orientation, dawn::sparse_vertex_field_t<LibTag, double>&geofac_rot, dawn::sparse_cell_field_t<LibTag, double>&geofac_div) : m_mesh(mesh), m_k_size(k_size), m_vec(vec), m_div_vec(div_vec), m_rot_vec(rot_vec), m_nabla2t1_vec(nabla2t1_vec), m_nabla2t2_vec(nabla2t2_vec), m_nabla2_vec(nabla2_vec), m_primal_edge_length(primal_edge_length), m_dual_edge_length(dual_edge_length), m_tangent_orientation(tangent_orientation), m_geofac_rot(geofac_rot), m_geofac_div(geofac_div){}
33+
34+
~stencil_64() {
35+
}
36+
37+
void sync_storages() {
38+
}
39+
static constexpr dawn::driver::unstructured_extent vec_extent = {false, 0,0};
40+
static constexpr dawn::driver::unstructured_extent div_vec_extent = {false, 0,0};
41+
static constexpr dawn::driver::unstructured_extent rot_vec_extent = {false, 0,0};
42+
static constexpr dawn::driver::unstructured_extent nabla2t1_vec_extent = {false, 0,0};
43+
static constexpr dawn::driver::unstructured_extent nabla2t2_vec_extent = {false, 0,0};
44+
static constexpr dawn::driver::unstructured_extent nabla2_vec_extent = {false, 0,0};
45+
static constexpr dawn::driver::unstructured_extent primal_edge_length_extent = {false, 0,0};
46+
static constexpr dawn::driver::unstructured_extent dual_edge_length_extent = {false, 0,0};
47+
static constexpr dawn::driver::unstructured_extent tangent_orientation_extent = {false, 0,0};
48+
static constexpr dawn::driver::unstructured_extent geofac_rot_extent = {false, 0,0};
49+
static constexpr dawn::driver::unstructured_extent geofac_div_extent = {false, 0,0};
50+
51+
void run() {
52+
using dawn::deref;
53+
{
54+
for(int k = 0+0; k <= ( m_k_size == 0 ? 0 : (m_k_size - 1)) + 0+0; ++k) {
55+
for(auto const& loc : getCells(LibTag{}, m_mesh)) {
56+
int m_sparse_dimension_idx = 0;
57+
m_div_vec(deref(LibTag{}, loc),k+0) = reduceEdgeToCell(LibTag{}, m_mesh,loc, (::dawn::float_type) 0.0, [&](auto& lhs, auto const& red_loc) { lhs += (m_vec(deref(LibTag{}, red_loc),k+0) * m_geofac_div(deref(LibTag{}, loc),m_sparse_dimension_idx, k+0));
58+
m_sparse_dimension_idx++;
59+
return lhs;
60+
});
61+
} for(auto const& loc : getEdges(LibTag{}, m_mesh)) {
62+
int m_sparse_dimension_idx = 0;
63+
m_nabla2t2_vec(deref(LibTag{}, loc),k+0) = reduceCellToEdge(LibTag{}, m_mesh,loc, (::dawn::float_type) 0.0, [&](auto& lhs, auto const& red_loc, auto const& weight) {
64+
lhs += weight * m_div_vec(deref(LibTag{}, red_loc),k+0);
65+
m_sparse_dimension_idx++;
66+
return lhs;
67+
}, std::vector<double>({-1.000000, 1.000000}));
68+
} for(auto const& loc : getEdges(LibTag{}, m_mesh)) {
69+
m_nabla2t2_vec(deref(LibTag{}, loc),k+0) = (m_nabla2t2_vec(deref(LibTag{}, loc),k+0) / m_dual_edge_length(deref(LibTag{}, loc),k+0));
70+
} for(auto const& loc : getVertices(LibTag{}, m_mesh)) {
71+
int m_sparse_dimension_idx = 0;
72+
m_rot_vec(deref(LibTag{}, loc),k+0) = reduceEdgeToVertex(LibTag{}, m_mesh,loc, (::dawn::float_type) 0.0, [&](auto& lhs, auto const& red_loc) { lhs += (m_vec(deref(LibTag{}, red_loc),k+0) * m_geofac_rot(deref(LibTag{}, loc),m_sparse_dimension_idx, k+0));
73+
m_sparse_dimension_idx++;
74+
return lhs;
75+
});
76+
} for(auto const& loc : getEdges(LibTag{}, m_mesh)) {
77+
int m_sparse_dimension_idx = 0;
78+
m_nabla2t1_vec(deref(LibTag{}, loc),k+0) = reduceVertexToEdge(LibTag{}, m_mesh,loc, (::dawn::float_type) 0.0, [&](auto& lhs, auto const& red_loc, auto const& weight) {
79+
lhs += weight * m_rot_vec(deref(LibTag{}, red_loc),k+0);
80+
m_sparse_dimension_idx++;
81+
return lhs;
82+
}, std::vector<double>({-1.000000, 1.000000}));
83+
} for(auto const& loc : getEdges(LibTag{}, m_mesh)) {
84+
m_nabla2t1_vec(deref(LibTag{}, loc),k+0) = ((m_tangent_orientation(deref(LibTag{}, loc),k+0) * m_nabla2t1_vec(deref(LibTag{}, loc),k+0)) / m_primal_edge_length(deref(LibTag{}, loc),k+0));
85+
} for(auto const& loc : getEdges(LibTag{}, m_mesh)) {
86+
m_nabla2_vec(deref(LibTag{}, loc),k+0) = (m_nabla2t2_vec(deref(LibTag{}, loc),k+0) - m_nabla2t1_vec(deref(LibTag{}, loc),k+0));
87+
} }} sync_storages();
88+
}
89+
};
90+
static constexpr const char* s_name = "ICON_laplacian_stencil";
91+
stencil_64 m_stencil_64;
92+
public:
93+
94+
ICON_laplacian_stencil(const ICON_laplacian_stencil&) = delete;
95+
96+
// Members
97+
98+
ICON_laplacian_stencil(const dawn::mesh_t<LibTag> &mesh, int k_size, dawn::edge_field_t<LibTag, double>& vec, dawn::cell_field_t<LibTag, double>& div_vec, dawn::vertex_field_t<LibTag, double>& rot_vec, dawn::edge_field_t<LibTag, double>& nabla2t1_vec, dawn::edge_field_t<LibTag, double>& nabla2t2_vec, dawn::edge_field_t<LibTag, double>& nabla2_vec, dawn::edge_field_t<LibTag, double>& primal_edge_length, dawn::edge_field_t<LibTag, double>& dual_edge_length, dawn::edge_field_t<LibTag, double>& tangent_orientation, dawn::sparse_vertex_field_t<LibTag, double>& geofac_rot, dawn::sparse_cell_field_t<LibTag, double>& geofac_div) : m_stencil_64(mesh, k_size,vec,div_vec,rot_vec,nabla2t1_vec,nabla2t2_vec,nabla2_vec,primal_edge_length,dual_edge_length,tangent_orientation,geofac_rot,geofac_div){}
99+
100+
void run() {
101+
m_stencil_64.run();
102+
;
103+
}
104+
};
105+
} // namespace cxxnaiveico
106+
} // namespace dawn_generated

0 commit comments

Comments
 (0)