Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 41 additions & 28 deletions model/common/src/icon4py/model/common/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import functools
import logging
from collections.abc import Callable, Mapping, Sequence
from typing import Any, Literal, TypeAlias, TypeVar
from typing import Any

from gt4py import next as gtx
from gt4py.next import backend as gtx_backend
Expand All @@ -34,8 +34,6 @@
from icon4py.model.common.utils import data_allocation as data_alloc, device_utils


InputGeometryFieldType: TypeAlias = Literal[attrs.CELL_AREA, attrs.TANGENT_ORIENTATION]

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -83,7 +81,7 @@ def __init__(
decomposition_info: definitions.DecompositionInfo,
backend: gtx_backend.Backend | None,
coordinates: gm.CoordinateDict,
extra_fields: dict[InputGeometryFieldType, gtx.Field],
extra_fields: gm.GeometryDict,
metadata: dict[str, model.FieldMetaData],
):
"""
Expand All @@ -103,7 +101,7 @@ def __init__(
self._grid = grid
self._decomposition_info = decomposition_info
self._attrs = metadata
self._geometry_type: base.GeometryType = grid.global_properties.geometry_type
self._geometry_type: base.GeometryType | None = grid.global_properties.geometry_type
self._edge_domain = h_grid.domain(dims.EdgeDim)
log.info(
f"initialized geometry for backend = '{self._backend_name()}' and grid = '{self._grid}'"
Expand Down Expand Up @@ -134,8 +132,8 @@ def __init__(
"latitude_of_edge_cell_neighbor_1": edge_orientation1_lat,
"longitude_of_edge_cell_neighbor_1": edge_orientation1_lon,
}
coodinate_provider = factory.PrecomputedFieldProvider(coordinates_)
self.register_provider(coodinate_provider)
coordinate_provider = factory.PrecomputedFieldProvider(coordinates_)
self.register_provider(coordinate_provider)

input_fields_provider = factory.PrecomputedFieldProvider(
{
Expand Down Expand Up @@ -176,7 +174,7 @@ def __init__(
self.register_provider(input_fields_provider)
self._register_computed_fields()

def _register_computed_fields(self):
def _register_computed_fields(self) -> None:
edge_length_provider = factory.ProgramFieldProvider(
func=stencils.compute_edge_length,
domain={
Expand Down Expand Up @@ -483,7 +481,12 @@ def _register_computed_fields(self):
self.register_provider(tangent_cell_wrapper)
cartesian_vertices = factory.EmbeddedFieldOperatorProvider(
func=math_helpers.geographical_to_cartesian_on_vertices.with_backend(self.backend),
domain=(dims.VertexDim,),
domain={
dims.VertexDim: (
h_grid.vertex_domain(h_grid.Zone.LOCAL),
h_grid.vertex_domain(h_grid.Zone.END),
)
},
fields={
attrs.VERTEX_X: attrs.VERTEX_X,
attrs.VERTEX_Y: attrs.VERTEX_Y,
Expand All @@ -497,7 +500,12 @@ def _register_computed_fields(self):
self.register_provider(cartesian_vertices)
cartesian_edge_centers = factory.EmbeddedFieldOperatorProvider(
func=math_helpers.geographical_to_cartesian_on_edges.with_backend(self.backend),
domain=(dims.EdgeDim,),
domain={
dims.EdgeDim: (
h_grid.edge_domain(h_grid.Zone.LOCAL),
h_grid.edge_domain(h_grid.Zone.END),
)
},
fields={
attrs.EDGE_CENTER_X: attrs.EDGE_CENTER_X,
attrs.EDGE_CENTER_Y: attrs.EDGE_CENTER_Y,
Expand All @@ -511,7 +519,12 @@ def _register_computed_fields(self):
self.register_provider(cartesian_edge_centers)
cartesian_cell_centers = factory.EmbeddedFieldOperatorProvider(
func=math_helpers.geographical_to_cartesian_on_cells.with_backend(self.backend),
domain=(dims.CellDim,),
domain={
dims.CellDim: (
h_grid.cell_domain(h_grid.Zone.LOCAL),
h_grid.cell_domain(h_grid.Zone.END),
)
},
fields={
attrs.CELL_CENTER_X: attrs.CELL_CENTER_X,
attrs.CELL_CENTER_Y: attrs.CELL_CENTER_Y,
Expand All @@ -524,7 +537,7 @@ def _register_computed_fields(self):
)
self.register_provider(cartesian_cell_centers)

def _inverse_field_provider(self, field_name: str):
def _inverse_field_provider(self, field_name: str) -> factory.FieldProvider:
meta = attrs.metadata_for_inverse(attrs.attrs[field_name])
name = meta["standard_name"]
self._attrs.update({name: meta})
Expand All @@ -541,38 +554,39 @@ def _inverse_field_provider(self, field_name: str):
)
return provider

def __repr__(self):
return f"{self.__class__.__name__} for geometry_type={self._geometry_type._name_} (grid={self._grid.id!r})"
def __repr__(self) -> str:
geometry_name = self._geometry_type._name_ if self._geometry_type else ""
return (
f"{self.__class__.__name__} for geometry_type={geometry_name} (grid={self._grid.id!r})"
)

@property
def metadata(self) -> dict[str, model.FieldMetaData]:
return self._attrs

@property
def backend(self) -> gtx_backend.Backend:
def backend(self) -> gtx_backend.Backend | None:
return self._backend

@property
def grid(self):
def grid(self) -> icon.IconGrid:
return self._grid

@property
def vertical_grid(self):
def vertical_grid(self) -> None:
return None


HorizontalD = TypeVar("HorizontalD", bound=gtx.Dimension)
SparseD = TypeVar("SparseD", bound=gtx.Dimension)


class SparseFieldProviderWrapper(factory.FieldProvider):
def __init__(
self,
field_provider: factory.ProgramFieldProvider,
target_dims: tuple[HorizontalD, SparseD],
target_dims: Sequence[gtx.Dimension],
fields: Sequence[str],
pairs: Sequence[tuple[str, ...]],
):
assert len(target_dims) == 2
assert target_dims[1].kind == gtx.DimensionKind.LOCAL
self._wrapped_provider = field_provider
self._fields = {name: None for name in fields}
self._func = functools.partial(as_sparse_field, target_dims)
Expand All @@ -583,11 +597,10 @@ def __call__(
field_name: str,
field_src: factory.FieldSource | None,
backend: gtx_backend.Backend | None,
grid: factory.GridProvider | None,
):
grid: factory.GridProvider,
) -> state_utils.GTXFieldType | None:
if not self._fields.get(field_name):
# get the fields from the wrapped provider

input_fields = []
for p in self._pairs:
t = tuple([self._wrapped_provider(name, field_src, backend, grid) for name in p])
Expand All @@ -610,10 +623,10 @@ def func(self) -> Callable:


def as_sparse_field(
target_dims: tuple[HorizontalD, SparseD],
data: Sequence[tuple[gtx.Field[gtx.Dims[HorizontalD], state_utils.ScalarType], ...]],
target_dims: tuple[gtx.Dimension, gtx.Dimension],
data: Sequence[tuple[gtx.Field[gtx.Dims[gtx.Dimension], state_utils.ScalarType], ...]],
backend: gtx_backend.Backend | None = None,
):
) -> Sequence[state_utils.GTXFieldType]:
assert len(target_dims) == 2
assert target_dims[0].kind == gtx.DimensionKind.HORIZONTAL
assert target_dims[1].kind == gtx.DimensionKind.LOCAL
Expand Down
1 change: 1 addition & 0 deletions model/common/src/icon4py/model/common/grid/grid_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __call__(self, array: data_alloc.NDArray):


CoordinateDict: TypeAlias = dict[gtx.Dimension, dict[Literal["lat", "lon"], gtx.Field]]
# TODO (halungge): use a TypeDict for that
GeometryDict: TypeAlias = dict[gridfile.GeometryName, gtx.Field]


Expand Down
5 changes: 5 additions & 0 deletions model/common/src/icon4py/model/common/grid/horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,11 @@ def _domain(marker: Zone) -> Domain:
return _domain


vertex_domain = domain(dims.VertexDim)
edge_domain = domain(dims.EdgeDim)
cell_domain = domain(dims.CellDim)


def _validate(dim: gtx.Dimension, marker: Zone) -> bool:
return marker in _get_zones_for_dim(dim)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ def __init__(

self._register_computed_fields()

def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__} on (grid={self._grid!r}) providing fields f{self.metadata.keys()}"

@property
def _sources(self) -> factory.FieldSource:
return factory.CompositeSource(self, (self._geometry,))

def _register_computed_fields(self):
def _register_computed_fields(self) -> None:
nudging_coefficients_for_edges = factory.ProgramFieldProvider(
func=nudgecoeffs.compute_nudgecoeffs.with_backend(None),
domain={
Expand Down Expand Up @@ -146,7 +146,7 @@ def _register_computed_fields(self):
)
self.register_provider(geofac_rot)

geofac_n2s = factory.NumpyFieldsProvider(
geofac_n2s = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_geofac_n2s, array_ns=self._xp),
fields=(attrs.GEOFAC_N2S,),
domain=(dims.CellDim, dims.C2E2CODim),
Expand All @@ -163,7 +163,7 @@ def _register_computed_fields(self):
)
self.register_provider(geofac_n2s)

geofac_grdiv = factory.NumpyFieldsProvider(
geofac_grdiv = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_geofac_grdiv, array_ns=self._xp),
fields=(attrs.GEOFAC_GRDIV,),
domain=(dims.EdgeDim, dims.E2C2EODim),
Expand All @@ -182,7 +182,7 @@ def _register_computed_fields(self):

self.register_provider(geofac_grdiv)

cell_average_weight = factory.NumpyFieldsProvider(
cell_average_weight = factory.NumpyDataProvider(
func=functools.partial(
interpolation_fields.compute_mass_conserving_bilinear_cell_average_weight,
array_ns=self._xp,
Expand All @@ -208,7 +208,7 @@ def _register_computed_fields(self):
)
self.register_provider(cell_average_weight)

c_lin_e = factory.NumpyFieldsProvider(
c_lin_e = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_c_lin_e, array_ns=self._xp),
fields=(attrs.C_LIN_E,),
domain=(dims.EdgeDim, dims.E2CDim),
Expand All @@ -225,7 +225,7 @@ def _register_computed_fields(self):
)
self.register_provider(c_lin_e)

geofac_grg = factory.NumpyFieldsProvider(
geofac_grg = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_geofac_grg, array_ns=self._xp),
fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y),
domain=(dims.CellDim, dims.C2E2CODim),
Expand All @@ -245,7 +245,7 @@ def _register_computed_fields(self):
)
self.register_provider(geofac_grg)

e_flx_avg = factory.NumpyFieldsProvider(
e_flx_avg = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_e_flx_avg, array_ns=self._xp),
fields=(attrs.E_FLX_AVG,),
domain=(dims.EdgeDim, dims.E2C2EODim),
Expand Down Expand Up @@ -274,7 +274,7 @@ def _register_computed_fields(self):
)
self.register_provider(e_flx_avg)

e_bln_c_s = factory.NumpyFieldsProvider(
e_bln_c_s = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_e_bln_c_s, array_ns=self._xp),
fields=(attrs.E_BLN_C_S,),
domain=(dims.CellDim, dims.C2EDim),
Expand All @@ -289,7 +289,7 @@ def _register_computed_fields(self):
)
self.register_provider(e_bln_c_s)

pos_on_tplane_e_x_y = factory.NumpyFieldsProvider(
pos_on_tplane_e_x_y = factory.NumpyDataProvider(
func=functools.partial(
interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp
),
Expand Down Expand Up @@ -318,7 +318,7 @@ def _register_computed_fields(self):
)
self.register_provider(pos_on_tplane_e_x_y)

cells_aw_verts = factory.NumpyFieldsProvider(
cells_aw_verts = factory.NumpyDataProvider(
func=functools.partial(interpolation_fields.compute_cells_aw_verts, array_ns=self._xp),
fields=(attrs.CELL_AW_VERTS,),
domain=(dims.VertexDim, dims.V2CDim),
Expand All @@ -341,7 +341,7 @@ def _register_computed_fields(self):
)
self.register_provider(cells_aw_verts)

rbf_vec_coeff_c = factory.NumpyFieldsProvider(
rbf_vec_coeff_c = factory.NumpyDataProvider(
func=functools.partial(rbf.compute_rbf_interpolation_coeffs_cell, array_ns=self._xp),
fields=(attrs.RBF_VEC_COEFF_C1, attrs.RBF_VEC_COEFF_C2),
domain=(dims.CellDim, dims.C2E2C2EDim),
Expand Down Expand Up @@ -369,7 +369,7 @@ def _register_computed_fields(self):
)
self.register_provider(rbf_vec_coeff_c)

rbf_vec_coeff_e = factory.NumpyFieldsProvider(
rbf_vec_coeff_e = factory.NumpyDataProvider(
func=functools.partial(rbf.compute_rbf_interpolation_coeffs_edge, array_ns=self._xp),
fields=(attrs.RBF_VEC_COEFF_E,),
domain=(dims.CellDim, dims.E2C2EDim),
Expand All @@ -396,7 +396,7 @@ def _register_computed_fields(self):
)
self.register_provider(rbf_vec_coeff_e)

rbf_vec_coeff_v = factory.NumpyFieldsProvider(
rbf_vec_coeff_v = factory.NumpyDataProvider(
func=functools.partial(rbf.compute_rbf_interpolation_coeffs_vertex, array_ns=self._xp),
fields=(attrs.RBF_VEC_COEFF_V1, attrs.RBF_VEC_COEFF_V2),
domain=(dims.VertexDim, dims.V2EDim),
Expand Down Expand Up @@ -429,13 +429,13 @@ def metadata(self) -> dict[str, model.FieldMetaData]:
return self._attrs

@property
def backend(self) -> gtx_backend.Backend:
def backend(self) -> gtx_backend.Backend | None:
return self._backend

@property
def grid(self):
def grid(self) -> icon.IconGrid:
return self._grid

@property
def vertical_grid(self):
def vertical_grid(self) -> None:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _compute_rbf_interpolation_coeffs(
scale_factor: ta.wpfloat,
horizontal_start: gtx.int32,
array_ns: ModuleType = np,
):
) -> tuple[data_alloc.NDArray, data_alloc.NDArray]:
rbf_offset_shape_full = rbf_offset.shape
rbf_offset = rbf_offset[horizontal_start:]
num_elements = rbf_offset.shape[0]
Expand Down Expand Up @@ -346,7 +346,7 @@ def index_offset(f):
rbf_vec_coeff_np[j][i + horizontal_start, valid_neighbors] = sla.cho_solve(
z_diag_np, rhs_np[j][i, valid_neighbors]
)
rbf_vec_coeff = [array_ns.asarray(x) for x in rbf_vec_coeff_np]
rbf_vec_coeff = tuple([array_ns.asarray(x) for x in rbf_vec_coeff_np])

# Normalize coefficients
for j in range(num_zonal_meridional_components):
Expand Down
Loading
Loading