88
99from __future__ import annotations
1010
11+ import dataclasses
1112import functools
1213import logging
1314from collections .abc import Sequence
14- from dataclasses import dataclass
15- from enum import IntEnum
15+ from enum import Enum
1616from typing import Any , Literal , Protocol , overload , runtime_checkable
1717
1818import dace # type: ignore[import-untyped]
19+ import gt4py .next as gtx
1920import numpy as np
20- from gt4py .next import Dimension , Field
2121
2222from icon4py .model .common import utils
2323from icon4py .model .common .orchestration .halo_exchange import DummyNestedSDFG
@@ -34,7 +34,7 @@ class ProcessProperties(Protocol):
3434 comm_size : int
3535
3636
37- @dataclass (frozen = True , init = False )
37+ @dataclasses . dataclass (frozen = True , init = False )
3838class SingleNodeProcessProperties (ProcessProperties ):
3939 comm : None
4040 rank : int
@@ -69,14 +69,14 @@ def __call__(self) -> int:
6969
7070
7171class DecompositionInfo :
72- class EntryType (IntEnum ):
72+ class EntryType (int , Enum ):
7373 ALL = 0
7474 OWNED = 1
7575 HALO = 2
7676
7777 @utils .chainable
7878 def with_dimension (
79- self , dim : Dimension , global_index : data_alloc .NDArray , owner_mask : data_alloc .NDArray
79+ self , dim : gtx . Dimension , global_index : data_alloc .NDArray , owner_mask : data_alloc .NDArray
8080 ) -> None :
8181 self ._global_index [dim ] = global_index
8282 self ._owner_mask [dim ] = owner_mask
@@ -87,8 +87,8 @@ def __init__(
8787 num_edges : int | None = None ,
8888 num_vertices : int | None = None ,
8989 ):
90- self ._global_index : dict [Dimension , data_alloc .NDArray ] = {}
91- self ._owner_mask : dict [Dimension , data_alloc .NDArray ] = {}
90+ self ._global_index : dict [gtx . Dimension , data_alloc .NDArray ] = {}
91+ self ._owner_mask : dict [gtx . Dimension , data_alloc .NDArray ] = {}
9292 self ._num_vertices = num_vertices
9393 self ._num_cells = num_cells
9494 self ._num_edges = num_edges
@@ -106,7 +106,7 @@ def num_vertices(self) -> int | None:
106106 return self ._num_vertices
107107
108108 def local_index (
109- self , dim : Dimension , entry_type : EntryType = EntryType .ALL
109+ self , dim : gtx . Dimension , entry_type : EntryType = EntryType .ALL
110110 ) -> data_alloc .NDArray :
111111 match entry_type :
112112 case DecompositionInfo .EntryType .ALL :
@@ -120,7 +120,7 @@ def local_index(
120120 mask = self ._owner_mask [dim ]
121121 return index [mask ]
122122
123- def _to_local_index (self , dim : Dimension ) -> data_alloc .NDArray :
123+ def _to_local_index (self , dim : gtx . Dimension ) -> data_alloc .NDArray :
124124 data = self ._global_index [dim ]
125125 assert data .ndim == 1
126126 if isinstance (data , np .ndarray ):
@@ -131,11 +131,11 @@ def _to_local_index(self, dim: Dimension) -> data_alloc.NDArray:
131131 xp .arange (data .shape [0 ])
132132 return xp .arange (data .shape [0 ])
133133
134- def owner_mask (self , dim : Dimension ) -> data_alloc .NDArray :
134+ def owner_mask (self , dim : gtx . Dimension ) -> data_alloc .NDArray :
135135 return self ._owner_mask [dim ]
136136
137137 def global_index (
138- self , dim : Dimension , entry_type : EntryType = EntryType .ALL
138+ self , dim : gtx . Dimension , entry_type : EntryType = EntryType .ALL
139139 ) -> data_alloc .NDArray :
140140 match entry_type :
141141 case DecompositionInfo .EntryType .ALL :
@@ -156,30 +156,45 @@ def is_ready(self) -> bool: ...
156156
157157@runtime_checkable
158158class ExchangeRuntime (Protocol ):
159- def exchange (self , dim : Dimension , * fields : Field ) -> ExchangeResult : ...
159+ @overload
160+ def exchange (self , dim : gtx .Dimension , * fields : gtx .Field ) -> ExchangeResult : ...
160161
161- def exchange_and_wait (self , dim : Dimension , * fields : Field ) -> None : ...
162+ @overload
163+ def exchange (self , dim : gtx .Dimension , * buffers : data_alloc .NDArray ) -> ExchangeResult : ...
164+
165+ @overload
166+ def exchange_and_wait (self , dim : gtx .Dimension , * fields : gtx .Field ) -> None : ...
167+
168+ @overload
169+ def exchange_and_wait (self , dim : gtx .Dimension , * buffers : data_alloc .NDArray ) -> None : ...
162170
163171 def get_size (self ) -> int : ...
164172
165173 def my_rank (self ) -> int : ...
166174
175+ def __str__ (self ) -> str :
176+ return f"{ self .__class__ } (rank = { self .my_rank ()} / { self .get_size ()} )"
177+
167178
168- @dataclass
179+ @dataclasses . dataclass
169180class SingleNodeExchange :
170- def exchange (self , dim : Dimension , * fields : Field ) -> ExchangeResult :
181+ def exchange (
182+ self , dim : gtx .Dimension , * fields : gtx .Field | data_alloc .NDArray
183+ ) -> ExchangeResult :
171184 return SingleNodeResult ()
172185
173- def exchange_and_wait (self , dim : Dimension , * fields : Field ) -> None :
174- return
186+ def exchange_and_wait (
187+ self , dim : gtx .Dimension , * fields : gtx .Field | data_alloc .NDArray
188+ ) -> None :
189+ return None
175190
176191 def my_rank (self ) -> int :
177192 return 0
178193
179194 def get_size (self ) -> int :
180195 return 1
181196
182- def __call__ (self , * args : Any , dim : Dimension , wait : bool = True ) -> ExchangeResult | None : # type: ignore[return] # return statment in else condition
197+ def __call__ (self , * args : Any , dim : gtx . Dimension , wait : bool = True ) -> ExchangeResult | None : # type: ignore[return] # return statment in else condition
183198 """Perform a halo exchange operation.
184199
185200 Args:
@@ -198,7 +213,9 @@ def __call__(self, *args: Any, dim: Dimension, wait: bool = True) -> ExchangeRes
198213
199214 # Implementation of DaCe SDFGConvertible interface
200215 # For more see [dace repo]/dace/frontend/python/common.py#[class SDFGConvertible]
201- def dace__sdfg__ (self , * args : Any , dim : Dimension , wait : bool = True ) -> dace .sdfg .sdfg .SDFG :
216+ def dace__sdfg__ (
217+ self , * args : Any , dim : gtx .Dimension , wait : bool = True
218+ ) -> dace .sdfg .sdfg .SDFG :
202219 sdfg = DummyNestedSDFG ().__sdfg__ ()
203220 sdfg .name = "_halo_exchange_"
204221 return sdfg
@@ -234,15 +251,17 @@ def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]:
234251 ...
235252
236253
237- @dataclass
254+ @dataclasses . dataclass
238255class HaloExchangeWait :
239256 exchange_object : SingleNodeExchange # maintain the same interface with the MPI counterpart
240257
241258 def __call__ (self , communication_handle : SingleNodeResult ) -> None :
242259 communication_handle .wait ()
243260
244261 # Implementation of DaCe SDFGConvertible interface
245- def dace__sdfg__ (self , * args : Any , dim : Dimension , wait : bool = True ) -> dace .sdfg .sdfg .SDFG :
262+ def dace__sdfg__ (
263+ self , * args : Any , dim : gtx .Dimension , wait : bool = True
264+ ) -> dace .sdfg .sdfg .SDFG :
246265 sdfg = DummyNestedSDFG ().__sdfg__ ()
247266 sdfg .name = "_halo_exchange_wait_"
248267 return sdfg
@@ -344,3 +363,6 @@ def create_single_node_exchange(
344363 props : SingleNodeProcessProperties , decomp_info : DecompositionInfo
345364) -> ExchangeRuntime :
346365 return SingleNodeExchange ()
366+
367+
368+ single_node_default = SingleNodeExchange ()
0 commit comments