Skip to content

Commit f0d747e

Browse files
committed
Merge branch 'main' into add_single_precision_dycore_part1
2 parents d031445 + 74ac45e commit f0d747e

42 files changed

Lines changed: 1464 additions & 553 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

model/atmosphere/diffusion/tests/diffusion/stencil_tests/test_apply_diffusion_to_vn.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,3 @@ def input_data(self, grid: base.Grid) -> dict:
170170
vertical_start=0,
171171
vertical_end=grid.num_levels,
172172
)
173-
174-
175-
@pytest.mark.continuous_benchmarking
176-
class TestApplyDiffusionToVnContinuousBenchmarking(TestApplyDiffusionToVn):
177-
pass

model/common/src/icon4py/model/common/decomposition/definitions.py

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
from __future__ import annotations
1010

11+
import dataclasses
1112
import functools
1213
import logging
1314
from collections.abc import Sequence
14-
from dataclasses import dataclass
15-
from enum import IntEnum
15+
from enum import Enum
1616
from typing import Any, Literal, Protocol, overload, runtime_checkable
1717

1818
import dace # type: ignore[import-untyped]
19+
import gt4py.next as gtx
1920
import numpy as np
20-
from gt4py.next import Dimension, Field
2121

2222
from icon4py.model.common import utils
2323
from icon4py.model.common.orchestration.halo_exchange import DummyNestedSDFG
@@ -34,7 +34,7 @@ class ProcessProperties(Protocol):
3434
comm_size: int
3535

3636

37-
@dataclass(frozen=True, init=False)
37+
@dataclasses.dataclass(frozen=True, init=False)
3838
class SingleNodeProcessProperties(ProcessProperties):
3939
comm: None
4040
rank: int
@@ -69,14 +69,14 @@ def __call__(self) -> int:
6969

7070

7171
class DecompositionInfo:
72-
class EntryType(IntEnum):
72+
class EntryType(int, Enum):
7373
ALL = 0
7474
OWNED = 1
7575
HALO = 2
7676

7777
@utils.chainable
7878
def with_dimension(
79-
self, dim: Dimension, global_index: data_alloc.NDArray, owner_mask: data_alloc.NDArray
79+
self, dim: gtx.Dimension, global_index: data_alloc.NDArray, owner_mask: data_alloc.NDArray
8080
) -> None:
8181
self._global_index[dim] = global_index
8282
self._owner_mask[dim] = owner_mask
@@ -87,8 +87,8 @@ def __init__(
8787
num_edges: int | None = None,
8888
num_vertices: int | None = None,
8989
):
90-
self._global_index: dict[Dimension, data_alloc.NDArray] = {}
91-
self._owner_mask: dict[Dimension, data_alloc.NDArray] = {}
90+
self._global_index: dict[gtx.Dimension, data_alloc.NDArray] = {}
91+
self._owner_mask: dict[gtx.Dimension, data_alloc.NDArray] = {}
9292
self._num_vertices = num_vertices
9393
self._num_cells = num_cells
9494
self._num_edges = num_edges
@@ -106,7 +106,7 @@ def num_vertices(self) -> int | None:
106106
return self._num_vertices
107107

108108
def local_index(
109-
self, dim: Dimension, entry_type: EntryType = EntryType.ALL
109+
self, dim: gtx.Dimension, entry_type: EntryType = EntryType.ALL
110110
) -> data_alloc.NDArray:
111111
match entry_type:
112112
case DecompositionInfo.EntryType.ALL:
@@ -120,7 +120,7 @@ def local_index(
120120
mask = self._owner_mask[dim]
121121
return index[mask]
122122

123-
def _to_local_index(self, dim: Dimension) -> data_alloc.NDArray:
123+
def _to_local_index(self, dim: gtx.Dimension) -> data_alloc.NDArray:
124124
data = self._global_index[dim]
125125
assert data.ndim == 1
126126
if isinstance(data, np.ndarray):
@@ -131,11 +131,11 @@ def _to_local_index(self, dim: Dimension) -> data_alloc.NDArray:
131131
xp.arange(data.shape[0])
132132
return xp.arange(data.shape[0])
133133

134-
def owner_mask(self, dim: Dimension) -> data_alloc.NDArray:
134+
def owner_mask(self, dim: gtx.Dimension) -> data_alloc.NDArray:
135135
return self._owner_mask[dim]
136136

137137
def global_index(
138-
self, dim: Dimension, entry_type: EntryType = EntryType.ALL
138+
self, dim: gtx.Dimension, entry_type: EntryType = EntryType.ALL
139139
) -> data_alloc.NDArray:
140140
match entry_type:
141141
case DecompositionInfo.EntryType.ALL:
@@ -156,30 +156,45 @@ def is_ready(self) -> bool: ...
156156

157157
@runtime_checkable
158158
class ExchangeRuntime(Protocol):
159-
def exchange(self, dim: Dimension, *fields: Field) -> ExchangeResult: ...
159+
@overload
160+
def exchange(self, dim: gtx.Dimension, *fields: gtx.Field) -> ExchangeResult: ...
160161

161-
def exchange_and_wait(self, dim: Dimension, *fields: Field) -> None: ...
162+
@overload
163+
def exchange(self, dim: gtx.Dimension, *buffers: data_alloc.NDArray) -> ExchangeResult: ...
164+
165+
@overload
166+
def exchange_and_wait(self, dim: gtx.Dimension, *fields: gtx.Field) -> None: ...
167+
168+
@overload
169+
def exchange_and_wait(self, dim: gtx.Dimension, *buffers: data_alloc.NDArray) -> None: ...
162170

163171
def get_size(self) -> int: ...
164172

165173
def my_rank(self) -> int: ...
166174

175+
def __str__(self) -> str:
176+
return f"{self.__class__} (rank = {self.my_rank()} / {self.get_size()})"
177+
167178

168-
@dataclass
179+
@dataclasses.dataclass
169180
class SingleNodeExchange:
170-
def exchange(self, dim: Dimension, *fields: Field) -> ExchangeResult:
181+
def exchange(
182+
self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray
183+
) -> ExchangeResult:
171184
return SingleNodeResult()
172185

173-
def exchange_and_wait(self, dim: Dimension, *fields: Field) -> None:
174-
return
186+
def exchange_and_wait(
187+
self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray
188+
) -> None:
189+
return None
175190

176191
def my_rank(self) -> int:
177192
return 0
178193

179194
def get_size(self) -> int:
180195
return 1
181196

182-
def __call__(self, *args: Any, dim: Dimension, wait: bool = True) -> ExchangeResult | None: # type: ignore[return] # return statment in else condition
197+
def __call__(self, *args: Any, dim: gtx.Dimension, wait: bool = True) -> ExchangeResult | None: # type: ignore[return] # return statment in else condition
183198
"""Perform a halo exchange operation.
184199
185200
Args:
@@ -198,7 +213,9 @@ def __call__(self, *args: Any, dim: Dimension, wait: bool = True) -> ExchangeRes
198213

199214
# Implementation of DaCe SDFGConvertible interface
200215
# For more see [dace repo]/dace/frontend/python/common.py#[class SDFGConvertible]
201-
def dace__sdfg__(self, *args: Any, dim: Dimension, wait: bool = True) -> dace.sdfg.sdfg.SDFG:
216+
def dace__sdfg__(
217+
self, *args: Any, dim: gtx.Dimension, wait: bool = True
218+
) -> dace.sdfg.sdfg.SDFG:
202219
sdfg = DummyNestedSDFG().__sdfg__()
203220
sdfg.name = "_halo_exchange_"
204221
return sdfg
@@ -234,15 +251,17 @@ def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
234251
...
235252

236253

237-
@dataclass
254+
@dataclasses.dataclass
238255
class HaloExchangeWait:
239256
exchange_object: SingleNodeExchange # maintain the same interface with the MPI counterpart
240257

241258
def __call__(self, communication_handle: SingleNodeResult) -> None:
242259
communication_handle.wait()
243260

244261
# Implementation of DaCe SDFGConvertible interface
245-
def dace__sdfg__(self, *args: Any, dim: Dimension, wait: bool = True) -> dace.sdfg.sdfg.SDFG:
262+
def dace__sdfg__(
263+
self, *args: Any, dim: gtx.Dimension, wait: bool = True
264+
) -> dace.sdfg.sdfg.SDFG:
246265
sdfg = DummyNestedSDFG().__sdfg__()
247266
sdfg.name = "_halo_exchange_wait_"
248267
return sdfg
@@ -344,3 +363,6 @@ def create_single_node_exchange(
344363
props: SingleNodeProcessProperties, decomp_info: DecompositionInfo
345364
) -> ExchangeRuntime:
346365
return SingleNodeExchange()
366+
367+
368+
single_node_default = SingleNodeExchange()

model/common/src/icon4py/model/common/decomposition/mpi_decomposition.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
ghex = None
4747
unstructured = None
4848

49-
5049
if TYPE_CHECKING:
5150
import mpi4py.MPI # type: ignore [import-not-found]
5251

@@ -203,50 +202,65 @@ def _create_pattern(self, horizontal_dim: gtx.Dimension) -> DomainDescriptor:
203202
def _slice_field_based_on_dim(self, field: gtx.Field, dim: gtx.Dimension) -> data_alloc.NDArray:
204203
"""
205204
Slices the field based on the dimension passed in.
205+
206+
This operation is *necessary* for the use inside FORTRAN as there fields are larger than the grid (nproma size). where it does not do anything in a purely Python setup.
207+
the granule context where fields otherwise have length nproma.
206208
"""
207209
if dim == dims.VertexDim:
208-
return field.ndarray[: self._decomposition_info.num_vertices, :]
210+
return field.ndarray[: self._decomposition_info.num_vertices]
209211
elif dim == dims.EdgeDim:
210-
return field.ndarray[: self._decomposition_info.num_edges, :]
212+
return field.ndarray[: self._decomposition_info.num_edges]
211213
elif dim == dims.CellDim:
212-
return field.ndarray[: self._decomposition_info.num_cells, :]
214+
return field.ndarray[: self._decomposition_info.num_cells]
213215
else:
214216
raise ValueError(f"Unknown dimension {dim}")
215217

216-
def _get_applied_pattern(self, dim: gtx.Dimension, f: gtx.Field) -> str:
217-
# TODO(havogt): the cache is never cleared, consider using functools.lru_cache in a bigger refactoring.
218-
assert hasattr(f, "__gt_buffer_info__")
219-
# dimension and buffer_info uniquely identifies the exchange pattern
220-
key = (dim, f.__gt_buffer_info__.hash_key)
221-
try:
222-
return self._applied_patterns_cache[key]
223-
except KeyError:
224-
assert dim in f.domain.dims
225-
array = self._slice_field_based_on_dim(f, dim)
226-
self._applied_patterns_cache[key] = self._patterns[dim](
227-
make_field_descriptor(
228-
self._domain_descriptors[dim],
229-
array,
230-
arch=Architecture.CPU if isinstance(f, np.ndarray) else Architecture.GPU,
218+
def _make_field_descriptor(self, dim: gtx.Dimension, array: data_alloc.NDArray) -> Any:
219+
return make_field_descriptor(
220+
self._domain_descriptors[dim],
221+
array,
222+
arch=Architecture.CPU if isinstance(array, np.ndarray) else Architecture.GPU,
223+
)
224+
225+
def _get_applied_pattern(self, dim: gtx.Dimension, f: gtx.Field | data_alloc.NDArray) -> str:
226+
if isinstance(f, gtx.Field):
227+
assert hasattr(f, "__gt_buffer_info__")
228+
# dimension and buffer_info uniquely identifies the exchange pattern
229+
# TODO(havogt): the cache is never cleared, consider using functools.lru_cache in a bigger refactoring.
230+
key = (dim, f.__gt_buffer_info__.hash_key)
231+
try:
232+
return self._applied_patterns_cache[key]
233+
except KeyError:
234+
assert dim in f.domain.dims
235+
array = self._slice_field_based_on_dim(f, dim)
236+
self._applied_patterns_cache[key] = self._patterns[dim](
237+
self._make_field_descriptor(dim, array)
231238
)
232-
)
233-
return self._applied_patterns_cache[key]
239+
return self._applied_patterns_cache[key]
240+
else:
241+
assert f.ndim in (1, 2), "Buffers must be 1d or 2d"
242+
return self._patterns[dim](self._make_field_descriptor(dim, f))
234243

235-
def exchange(self, dim: gtx.Dimension, *fields: gtx.Field) -> MultiNodeResult:
244+
def exchange(
245+
self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray
246+
) -> MultiNodeResult:
236247
"""
237248
Exchange method that slices the fields based on the dimension and then performs halo exchange.
238-
239-
This operation is *necessary* for the use inside FORTRAN as there fields are larger than the grid (nproma size). where it does not do anything in a purely Python setup.
240-
the granule context where fields otherwise have length nproma.
241249
"""
250+
assert (
251+
dim in dims.MAIN_HORIZONTAL_DIMENSIONS.values()
252+
), f"first dimension must be one of ({dims.MAIN_HORIZONTAL_DIMENSIONS.values()})"
253+
242254
applied_patterns = [self._get_applied_pattern(dim, f) for f in fields]
243255
# With https://github.com/ghex-org/GHEX/pull/186, ghex will schedule/sync work on the default stream,
244256
# otherwise we need an explicit device synchronize here.
245257
handle = self._comm.exchange(applied_patterns)
246258
log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' initiated.")
247259
return MultiNodeResult(handle, applied_patterns)
248260

249-
def exchange_and_wait(self, dim: gtx.Dimension, *fields: gtx.Field) -> None:
261+
def exchange_and_wait(
262+
self, dim: gtx.Dimension, *fields: gtx.Field | data_alloc.NDArray
263+
) -> None:
250264
res = self.exchange(dim, *fields)
251265
res.wait()
252266
log.debug(f"exchange for {len(fields)} fields of dimension ='{dim.value}' done.")

0 commit comments

Comments
 (0)