diff --git a/model/common/src/icon4py/model/common/grid/geometry.py b/model/common/src/icon4py/model/common/grid/geometry.py index b416f51ca8..b43bfb7184 100644 --- a/model/common/src/icon4py/model/common/grid/geometry.py +++ b/model/common/src/icon4py/model/common/grid/geometry.py @@ -8,7 +8,7 @@ import functools import logging from collections.abc import Callable, Mapping, Sequence -from typing import Any, Literal, TypeAlias, TypeVar +from typing import Any from gt4py import next as gtx from gt4py.next import backend as gtx_backend @@ -34,8 +34,6 @@ from icon4py.model.common.utils import data_allocation as data_alloc, device_utils -InputGeometryFieldType: TypeAlias = Literal[attrs.CELL_AREA, attrs.TANGENT_ORIENTATION] - log = logging.getLogger(__name__) @@ -83,7 +81,7 @@ def __init__( decomposition_info: definitions.DecompositionInfo, backend: gtx_backend.Backend | None, coordinates: gm.CoordinateDict, - extra_fields: dict[InputGeometryFieldType, gtx.Field], + extra_fields: gm.GeometryDict, metadata: dict[str, model.FieldMetaData], ): """ @@ -103,7 +101,7 @@ def __init__( self._grid = grid self._decomposition_info = decomposition_info self._attrs = metadata - self._geometry_type: base.GeometryType = grid.global_properties.geometry_type + self._geometry_type: base.GeometryType | None = grid.global_properties.geometry_type self._edge_domain = h_grid.domain(dims.EdgeDim) log.info( f"initialized geometry for backend = '{self._backend_name()}' and grid = '{self._grid}'" @@ -134,8 +132,8 @@ def __init__( "latitude_of_edge_cell_neighbor_1": edge_orientation1_lat, "longitude_of_edge_cell_neighbor_1": edge_orientation1_lon, } - coodinate_provider = factory.PrecomputedFieldProvider(coordinates_) - self.register_provider(coodinate_provider) + coordinate_provider = factory.PrecomputedFieldProvider(coordinates_) + self.register_provider(coordinate_provider) input_fields_provider = factory.PrecomputedFieldProvider( { @@ -176,7 +174,7 @@ def __init__( self.register_provider(input_fields_provider) self._register_computed_fields() - def _register_computed_fields(self): + def _register_computed_fields(self) -> None: edge_length_provider = factory.ProgramFieldProvider( func=stencils.compute_edge_length, domain={ @@ -483,7 +481,12 @@ def _register_computed_fields(self): self.register_provider(tangent_cell_wrapper) cartesian_vertices = factory.EmbeddedFieldOperatorProvider( func=math_helpers.geographical_to_cartesian_on_vertices.with_backend(self.backend), - domain=(dims.VertexDim,), + domain={ + dims.VertexDim: ( + h_grid.vertex_domain(h_grid.Zone.LOCAL), + h_grid.vertex_domain(h_grid.Zone.END), + ) + }, fields={ attrs.VERTEX_X: attrs.VERTEX_X, attrs.VERTEX_Y: attrs.VERTEX_Y, @@ -497,7 +500,12 @@ def _register_computed_fields(self): self.register_provider(cartesian_vertices) cartesian_edge_centers = factory.EmbeddedFieldOperatorProvider( func=math_helpers.geographical_to_cartesian_on_edges.with_backend(self.backend), - domain=(dims.EdgeDim,), + domain={ + dims.EdgeDim: ( + h_grid.edge_domain(h_grid.Zone.LOCAL), + h_grid.edge_domain(h_grid.Zone.END), + ) + }, fields={ attrs.EDGE_CENTER_X: attrs.EDGE_CENTER_X, attrs.EDGE_CENTER_Y: attrs.EDGE_CENTER_Y, @@ -511,7 +519,12 @@ def _register_computed_fields(self): self.register_provider(cartesian_edge_centers) cartesian_cell_centers = factory.EmbeddedFieldOperatorProvider( func=math_helpers.geographical_to_cartesian_on_cells.with_backend(self.backend), - domain=(dims.CellDim,), + domain={ + dims.CellDim: ( + h_grid.cell_domain(h_grid.Zone.LOCAL), + h_grid.cell_domain(h_grid.Zone.END), + ) + }, fields={ attrs.CELL_CENTER_X: attrs.CELL_CENTER_X, attrs.CELL_CENTER_Y: attrs.CELL_CENTER_Y, @@ -524,7 +537,7 @@ def _register_computed_fields(self): ) self.register_provider(cartesian_cell_centers) - def _inverse_field_provider(self, field_name: str): + def _inverse_field_provider(self, field_name: str) -> factory.FieldProvider: meta = attrs.metadata_for_inverse(attrs.attrs[field_name]) name = meta["standard_name"] self._attrs.update({name: meta}) @@ -541,38 +554,39 @@ def _inverse_field_provider(self, field_name: str): ) return provider - def __repr__(self): - return f"{self.__class__.__name__} for geometry_type={self._geometry_type._name_} (grid={self._grid.id!r})" + def __repr__(self) -> str: + geometry_name = self._geometry_type._name_ if self._geometry_type else "" + return ( + f"{self.__class__.__name__} for geometry_type={geometry_name} (grid={self._grid.id!r})" + ) @property def metadata(self) -> dict[str, model.FieldMetaData]: return self._attrs @property - def backend(self) -> gtx_backend.Backend: + def backend(self) -> gtx_backend.Backend | None: return self._backend @property - def grid(self): + def grid(self) -> icon.IconGrid: return self._grid @property - def vertical_grid(self): + def vertical_grid(self) -> None: return None -HorizontalD = TypeVar("HorizontalD", bound=gtx.Dimension) -SparseD = TypeVar("SparseD", bound=gtx.Dimension) - - class SparseFieldProviderWrapper(factory.FieldProvider): def __init__( self, field_provider: factory.ProgramFieldProvider, - target_dims: tuple[HorizontalD, SparseD], + target_dims: Sequence[gtx.Dimension], fields: Sequence[str], pairs: Sequence[tuple[str, ...]], ): + assert len(target_dims) == 2 + assert target_dims[1].kind == gtx.DimensionKind.LOCAL self._wrapped_provider = field_provider self._fields = {name: None for name in fields} self._func = functools.partial(as_sparse_field, target_dims) @@ -583,11 +597,10 @@ def __call__( field_name: str, field_src: factory.FieldSource | None, backend: gtx_backend.Backend | None, - grid: factory.GridProvider | None, - ): + grid: factory.GridProvider, + ) -> state_utils.GTXFieldType | None: if not self._fields.get(field_name): # get the fields from the wrapped provider - input_fields = [] for p in self._pairs: t = tuple([self._wrapped_provider(name, field_src, backend, grid) for name in p]) @@ -610,10 +623,10 @@ def func(self) -> Callable: def as_sparse_field( - target_dims: tuple[HorizontalD, SparseD], - data: Sequence[tuple[gtx.Field[gtx.Dims[HorizontalD], state_utils.ScalarType], ...]], + target_dims: tuple[gtx.Dimension, gtx.Dimension], + data: Sequence[tuple[gtx.Field[gtx.Dims[gtx.Dimension], state_utils.ScalarType], ...]], backend: gtx_backend.Backend | None = None, -): +) -> Sequence[state_utils.GTXFieldType]: assert len(target_dims) == 2 assert target_dims[0].kind == gtx.DimensionKind.HORIZONTAL assert target_dims[1].kind == gtx.DimensionKind.LOCAL diff --git a/model/common/src/icon4py/model/common/grid/grid_manager.py b/model/common/src/icon4py/model/common/grid/grid_manager.py index 924427b4b2..a5730a65d4 100644 --- a/model/common/src/icon4py/model/common/grid/grid_manager.py +++ b/model/common/src/icon4py/model/common/grid/grid_manager.py @@ -64,6 +64,7 @@ def __call__(self, array: data_alloc.NDArray): CoordinateDict: TypeAlias = dict[gtx.Dimension, dict[Literal["lat", "lon"], gtx.Field]] +# TODO (halungge): use a TypeDict for that GeometryDict: TypeAlias = dict[gridfile.GeometryName, gtx.Field] diff --git a/model/common/src/icon4py/model/common/grid/horizontal.py b/model/common/src/icon4py/model/common/grid/horizontal.py index 122e7f032b..50d919166d 100644 --- a/model/common/src/icon4py/model/common/grid/horizontal.py +++ b/model/common/src/icon4py/model/common/grid/horizontal.py @@ -372,6 +372,11 @@ def _domain(marker: Zone) -> Domain: return _domain +vertex_domain = domain(dims.VertexDim) +edge_domain = domain(dims.EdgeDim) +cell_domain = domain(dims.CellDim) + + def _validate(dim: gtx.Dimension, marker: Zone) -> bool: return marker in _get_zones_for_dim(dim) diff --git a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py index 6ede0d6a1f..b00ab301bd 100644 --- a/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py +++ b/model/common/src/icon4py/model/common/interpolation/interpolation_factory.py @@ -90,14 +90,14 @@ def __init__( self._register_computed_fields() - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__} on (grid={self._grid!r}) providing fields f{self.metadata.keys()}" @property def _sources(self) -> factory.FieldSource: return factory.CompositeSource(self, (self._geometry,)) - def _register_computed_fields(self): + def _register_computed_fields(self) -> None: nudging_coefficients_for_edges = factory.ProgramFieldProvider( func=nudgecoeffs.compute_nudgecoeffs.with_backend(None), domain={ @@ -146,7 +146,7 @@ def _register_computed_fields(self): ) self.register_provider(geofac_rot) - geofac_n2s = factory.NumpyFieldsProvider( + geofac_n2s = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_geofac_n2s, array_ns=self._xp), fields=(attrs.GEOFAC_N2S,), domain=(dims.CellDim, dims.C2E2CODim), @@ -163,7 +163,7 @@ def _register_computed_fields(self): ) self.register_provider(geofac_n2s) - geofac_grdiv = factory.NumpyFieldsProvider( + geofac_grdiv = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_geofac_grdiv, array_ns=self._xp), fields=(attrs.GEOFAC_GRDIV,), domain=(dims.EdgeDim, dims.E2C2EODim), @@ -182,7 +182,7 @@ def _register_computed_fields(self): self.register_provider(geofac_grdiv) - cell_average_weight = factory.NumpyFieldsProvider( + cell_average_weight = factory.NumpyDataProvider( func=functools.partial( interpolation_fields.compute_mass_conserving_bilinear_cell_average_weight, array_ns=self._xp, @@ -208,7 +208,7 @@ def _register_computed_fields(self): ) self.register_provider(cell_average_weight) - c_lin_e = factory.NumpyFieldsProvider( + c_lin_e = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_c_lin_e, array_ns=self._xp), fields=(attrs.C_LIN_E,), domain=(dims.EdgeDim, dims.E2CDim), @@ -225,7 +225,7 @@ def _register_computed_fields(self): ) self.register_provider(c_lin_e) - geofac_grg = factory.NumpyFieldsProvider( + geofac_grg = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_geofac_grg, array_ns=self._xp), fields=(attrs.GEOFAC_GRG_X, attrs.GEOFAC_GRG_Y), domain=(dims.CellDim, dims.C2E2CODim), @@ -245,7 +245,7 @@ def _register_computed_fields(self): ) self.register_provider(geofac_grg) - e_flx_avg = factory.NumpyFieldsProvider( + e_flx_avg = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_e_flx_avg, array_ns=self._xp), fields=(attrs.E_FLX_AVG,), domain=(dims.EdgeDim, dims.E2C2EODim), @@ -274,7 +274,7 @@ def _register_computed_fields(self): ) self.register_provider(e_flx_avg) - e_bln_c_s = factory.NumpyFieldsProvider( + e_bln_c_s = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_e_bln_c_s, array_ns=self._xp), fields=(attrs.E_BLN_C_S,), domain=(dims.CellDim, dims.C2EDim), @@ -289,7 +289,7 @@ def _register_computed_fields(self): ) self.register_provider(e_bln_c_s) - pos_on_tplane_e_x_y = factory.NumpyFieldsProvider( + pos_on_tplane_e_x_y = factory.NumpyDataProvider( func=functools.partial( interpolation_fields.compute_pos_on_tplane_e_x_y, array_ns=self._xp ), @@ -318,7 +318,7 @@ def _register_computed_fields(self): ) self.register_provider(pos_on_tplane_e_x_y) - cells_aw_verts = factory.NumpyFieldsProvider( + cells_aw_verts = factory.NumpyDataProvider( func=functools.partial(interpolation_fields.compute_cells_aw_verts, array_ns=self._xp), fields=(attrs.CELL_AW_VERTS,), domain=(dims.VertexDim, dims.V2CDim), @@ -341,7 +341,7 @@ def _register_computed_fields(self): ) self.register_provider(cells_aw_verts) - rbf_vec_coeff_c = factory.NumpyFieldsProvider( + rbf_vec_coeff_c = factory.NumpyDataProvider( func=functools.partial(rbf.compute_rbf_interpolation_coeffs_cell, array_ns=self._xp), fields=(attrs.RBF_VEC_COEFF_C1, attrs.RBF_VEC_COEFF_C2), domain=(dims.CellDim, dims.C2E2C2EDim), @@ -369,7 +369,7 @@ def _register_computed_fields(self): ) self.register_provider(rbf_vec_coeff_c) - rbf_vec_coeff_e = factory.NumpyFieldsProvider( + rbf_vec_coeff_e = factory.NumpyDataProvider( func=functools.partial(rbf.compute_rbf_interpolation_coeffs_edge, array_ns=self._xp), fields=(attrs.RBF_VEC_COEFF_E,), domain=(dims.CellDim, dims.E2C2EDim), @@ -396,7 +396,7 @@ def _register_computed_fields(self): ) self.register_provider(rbf_vec_coeff_e) - rbf_vec_coeff_v = factory.NumpyFieldsProvider( + rbf_vec_coeff_v = factory.NumpyDataProvider( func=functools.partial(rbf.compute_rbf_interpolation_coeffs_vertex, array_ns=self._xp), fields=(attrs.RBF_VEC_COEFF_V1, attrs.RBF_VEC_COEFF_V2), domain=(dims.VertexDim, dims.V2EDim), @@ -429,13 +429,13 @@ def metadata(self) -> dict[str, model.FieldMetaData]: return self._attrs @property - def backend(self) -> gtx_backend.Backend: + def backend(self) -> gtx_backend.Backend | None: return self._backend @property - def grid(self): + def grid(self) -> icon.IconGrid: return self._grid @property - def vertical_grid(self): + def vertical_grid(self) -> None: return None diff --git a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py index c0bd42ad89..b04fc38616 100644 --- a/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py +++ b/model/common/src/icon4py/model/common/interpolation/rbf_interpolation.py @@ -227,7 +227,7 @@ def _compute_rbf_interpolation_coeffs( scale_factor: ta.wpfloat, horizontal_start: gtx.int32, array_ns: ModuleType = np, -): +) -> tuple[data_alloc.NDArray, data_alloc.NDArray]: rbf_offset_shape_full = rbf_offset.shape rbf_offset = rbf_offset[horizontal_start:] num_elements = rbf_offset.shape[0] @@ -346,7 +346,7 @@ def index_offset(f): rbf_vec_coeff_np[j][i + horizontal_start, valid_neighbors] = sla.cho_solve( z_diag_np, rhs_np[j][i, valid_neighbors] ) - rbf_vec_coeff = [array_ns.asarray(x) for x in rbf_vec_coeff_np] + rbf_vec_coeff = tuple([array_ns.asarray(x) for x in rbf_vec_coeff_np]) # Normalize coefficients for j in range(num_zonal_meridional_components): diff --git a/model/common/src/icon4py/model/common/metrics/metrics_factory.py b/model/common/src/icon4py/model/common/metrics/metrics_factory.py index 96831758ee..d6504b9825 100644 --- a/model/common/src/icon4py/model/common/metrics/metrics_factory.py +++ b/model/common/src/icon4py/model/common/metrics/metrics_factory.py @@ -81,8 +81,7 @@ def __init__( f"initialized metrics factory for backend = '{self._backend_name()}' and grid = '{self._grid}'" ) log.debug(f"using array_ns {self._xp} ") - vct_a = self._vertical_grid.vct_a - vct_a_1 = vct_a.asnumpy()[0] + vct_a_1 = self._vertical_grid.interface_physical_height.ndarray[0].item() self._config = { "divdamp_trans_start": 12500.0, "divdamp_trans_end": 17500.0, @@ -114,7 +113,7 @@ def __init__( factory.PrecomputedFieldProvider( { "topography": topography, - "vct_a": vct_a, + "vct_a": self._vertical_grid.interface_physical_height, "c_refin_ctrl": c_refin_ctrl, "e_refin_ctrl": e_refin_ctrl, "e_owner_mask": e_owner_mask, @@ -126,24 +125,21 @@ def __init__( ) self._register_computed_fields() - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__} on (grid={self._grid!r}) providing fields f{self.metadata.keys()}" @property def _sources(self) -> factory.FieldSource: return factory.CompositeSource(self, (self._geometry, self._interpolation_source)) - def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] - vertical_coordinates_on_half_levels = factory.NumpyFieldsProvider( + def _register_computed_fields(self) -> None: # noqa: PLR0915 [too-many-statements] + vertical_coordinates_on_half_levels = factory.NumpyDataProvider( func=functools.partial( v_grid.compute_vertical_coordinate, array_ns=self._xp, ), fields=(attrs.CELL_HEIGHT_ON_HALF_LEVEL,), - domain={ - dims.CellDim: (0, cell_domain(h_grid.Zone.END)), - dims.KDim: (vertical_domain(v_grid.Zone.TOP), vertical_domain(v_grid.Zone.BOTTOM)), - }, + domain=(dims.CellDim, dims.KDim), deps={ "vct_a": "vct_a", "topography": "topography", @@ -458,7 +454,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] ) self.register_provider(compute_ddxt_z_full) - compute_exner_w_implicit_weight_parameter_np = factory.NumpyFieldsProvider( + compute_exner_w_implicit_weight_parameter_np = factory.NumpyDataProvider( func=functools.partial(mf.compute_exner_w_implicit_weight_parameter, array_ns=self._xp), domain=(dims.CellDim,), connectivities={"c2e": dims.C2EDim}, @@ -578,7 +574,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] fields={"flat_idx": attrs.FLAT_EDGE_INDEX}, ) self.register_provider(compute_flat_edge_idx) - max_flat_index_provider = factory.NumpyFieldsProvider( + max_flat_index_provider = factory.NumpyDataProvider( func=functools.partial(mf.compute_max_index, array_ns=self._xp), domain=(dims.EdgeDim,), fields=(attrs.FLAT_IDX_MAX,), @@ -650,13 +646,13 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] }, fields={attrs.HORIZONTAL_MASK_FOR_3D_DIVDAMP: attrs.HORIZONTAL_MASK_FOR_3D_DIVDAMP}, params={ - "grf_nudge_start_e": refinement.get_nudging_refinement_value(dims.EdgeDim), - "grf_nudgezone_width": gtx.int32(refinement.DEFAULT_GRF_NUDGEZONE_WIDTH), + "grf_nudge_start_e": refinement.get_nudging_refinement_value(dims.EdgeDim), # type: ignore [attr-defined] + "grf_nudgezone_width": gtx.int32(refinement.DEFAULT_GRF_NUDGEZONE_WIDTH), # type: ignore [attr-defined] }, ) self.register_provider(compute_horizontal_mask_for_3d_divdamp) - compute_zdiff_gradp_dsl_np = factory.NumpyFieldsProvider( + compute_zdiff_gradp_dsl_np = factory.NumpyDataProvider( func=functools.partial( compute_zdiff_gradp_dsl.compute_zdiff_gradp_dsl, array_ns=self._xp ), @@ -682,7 +678,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] ) self.register_provider(compute_zdiff_gradp_dsl_np) - coeff_gradekin = factory.NumpyFieldsProvider( + coeff_gradekin = factory.NumpyDataProvider( func=functools.partial( compute_coeff_gradekin.compute_coeff_gradekin, array_ns=self._xp ), @@ -701,7 +697,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] ) self.register_provider(coeff_gradekin) - compute_wgtfacq_c = factory.NumpyFieldsProvider( + compute_wgtfacq_c = factory.NumpyDataProvider( func=functools.partial(weight_factors.compute_wgtfacq_c_dsl, array_ns=self._xp), domain=(dims.CellDim, dims.KDim), fields=(attrs.WGTFACQ_C,), @@ -711,7 +707,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] self.register_provider(compute_wgtfacq_c) - compute_wgtfacq_e = factory.NumpyFieldsProvider( + compute_wgtfacq_e = factory.NumpyDataProvider( func=functools.partial(weight_factors.compute_wgtfacq_e_dsl, array_ns=self._xp), deps={ "z_ifc": attrs.CELL_HEIGHT_ON_HALF_LEVEL, @@ -767,7 +763,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] ) self.register_provider(compute_weighted_cell_neighbor_sum) - compute_max_nbhgt = factory.NumpyFieldsProvider( + compute_max_nbhgt = factory.NumpyDataProvider( func=functools.partial( compute_diffusion_metrics.compute_max_nbhgt_array_ns, array_ns=self._xp ), @@ -783,7 +779,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] ) self.register_provider(compute_max_nbhgt) - compute_diffusion_mask_and_coef = factory.NumpyFieldsProvider( + compute_diffusion_mask_and_coef = factory.NumpyDataProvider( func=functools.partial( compute_diffusion_metrics.compute_diffusion_mask_and_coef, array_ns=self._xp ), @@ -812,7 +808,7 @@ def _register_computed_fields(self): # noqa: PLR0915 [too-many-statements] self.register_provider(compute_diffusion_mask_and_coef) - compute_diffusion_intcoef_and_vertoffset = factory.NumpyFieldsProvider( + compute_diffusion_intcoef_and_vertoffset = factory.NumpyDataProvider( func=functools.partial( compute_diffusion_metrics.compute_diffusion_intcoef_and_vertoffset, array_ns=self._xp, @@ -847,13 +843,13 @@ def metadata(self) -> dict[str, model.FieldMetaData]: return self._attrs @property - def backend(self) -> gtx_backend.Backend: + def backend(self) -> gtx_backend.Backend | None: return self._backend @property - def grid(self): + def grid(self) -> icon.IconGrid: return self._grid @property - def vertical_grid(self): + def vertical_grid(self) -> v_grid.VerticalGrid: return self._vertical_grid diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 21ae36b41b..f9e9a3c3f4 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -39,6 +39,8 @@ """ +from __future__ import annotations + import collections import enum import functools @@ -47,7 +49,7 @@ import typing from collections.abc import Callable, Mapping, MutableMapping, Sequence from types import ModuleType -from typing import Any, Literal, Optional, Protocol, TypeVar, overload +from typing import Any, Literal, Protocol, TypeVar, overload import gt4py.next as gtx import gt4py.next.typing as gtx_typing @@ -70,7 +72,7 @@ class GridProvider(Protocol): @property - def grid(self) -> icon_grid.IconGrid | None: ... + def grid(self) -> icon_grid.IconGrid: ... @property def vertical_grid(self) -> v_grid.VerticalGrid | None: ... @@ -93,16 +95,18 @@ class FieldProvider(Protocol): def __call__( self, field_name: str, - field_src: Optional["FieldSource"], + field_src: FieldSource, backend: gtx_typing.Backend | None, - grid: GridProvider | None, - ) -> state_utils.FieldType: ... + grid: GridProvider, + ) -> state_utils.GTXFieldType | state_utils.ScalarType: ... @property def dependencies(self) -> Sequence[str]: ... @property - def fields(self) -> Mapping[str, Any]: ... + def fields( + self, + ) -> Mapping[str, state_utils.FieldType | state_utils.ScalarType]: ... @property def func(self) -> Callable: ... @@ -124,7 +128,7 @@ class FieldSource(GridProvider, Protocol): _providers: MutableMapping[str, FieldProvider] = {} # noqa: RUF012 instance variable @property - def _sources(self) -> "FieldSource": + def _sources(self) -> FieldSource: return self @property @@ -132,11 +136,13 @@ def metadata(self) -> MutableMapping[str, model.FieldMetaData]: """Returns metadata for the fields that this field source provides.""" ... - # TODO @halungge: this is the target Backend: not necessarily the one that the field is computed and - # there are fields which need to be computed on a specific backend, which can be different from the - # general run backend @property - def backend(self) -> gtx_typing.Backend | None: ... + def backend(self) -> gtx_typing.Backend | None: + """Target backend: this is the backend that the field should be produced for when requested from the source. + The field computation might + be done on a different backend, as there are FieldOperators that require a specific backend (mostly embedded) + to be used.""" + ... def _backend_name(self) -> str: return "embedded" if self.backend is None else self.backend.name @@ -156,7 +162,7 @@ def get( def get( self, field_name: str, type_: RetrievalType = RetrievalType.FIELD - ) -> state_utils.GTXFieldType | xa.DataArray | model.FieldMetaData: + ) -> state_utils.GTXFieldType | xa.DataArray | model.FieldMetaData | state_utils.ScalarType: """ Get a field or its metadata from the factory. @@ -191,10 +197,10 @@ def get( case _: raise ValueError(f"Invalid retrieval type {type_}") - def _provided_by_source(self, name): + def _provided_by_source(self, name) -> str: return name in self._sources._providers or name in self._sources.metadata - def register_provider(self, provider: FieldProvider): + def register_provider(self, provider: FieldProvider) -> None: # dependencies must be provider by this field source or registered in sources for dependency in provider.dependencies: if not (dependency in self._providers or self._provided_by_source(dependency)): @@ -219,7 +225,7 @@ def metadata(self) -> MutableMapping[str, model.FieldMetaData]: return self._metadata @property - def backend(self) -> gtx_typing.Backend: + def backend(self) -> gtx_typing.Backend | None: return self._backend @property @@ -235,7 +241,7 @@ class PrecomputedFieldProvider(FieldProvider): """Simple FieldProvider that does not do any computation but gets its fields at construction and returns it upon provider.get(field_name).""" - def __init__(self, fields: dict[str, state_utils.FieldType]): + def __init__(self, fields: dict[str, state_utils.GTXFieldType]): self._fields = fields @property @@ -244,11 +250,11 @@ def dependencies(self) -> Sequence[str]: def __call__( self, field_name: str, field_src=None, backend=None, grid=None - ) -> state_utils.FieldType: + ) -> state_utils.GTXFieldType: return self.fields[field_name] @property - def fields(self) -> Mapping[str, state_utils.FieldType]: + def fields(self) -> Mapping[str, state_utils.GTXFieldType]: return self._fields @property @@ -270,14 +276,16 @@ class EmbeddedFieldOperatorProvider(FieldProvider): def __init__( self, func: gtx_typing.FieldOperator, - domain: dict[gtx.Dimension, tuple[h_grid.Domain, h_grid.Domain]], + domain: dict[gtx.Dimension, tuple[DomainType, DomainType]] | tuple[gtx.Dimension, ...], fields: dict[str, str], # keyword arg to (field_operator, field_name) deps: dict[str, str], # keyword arg to (field_operator, field_name) need: src params: dict[str, state_utils.ScalarType] | None = None, # keyword arg to (field_operator, field_name) ): self._func = func - self._dims = domain + self._dims: ( + dict[gtx.Dimension, tuple[DomainType, DomainType]] | tuple[gtx.Dimension, ...] + ) = domain self._dependencies = deps self._output = fields self._params = {} if params is None else params @@ -512,7 +520,7 @@ def _domain_args( def __call__( self, field_name: str, - factory: FieldSource, + factory: FieldSource | None, backend: gtx_typing.Backend | None, grid_provider: GridProvider, ): @@ -554,7 +562,7 @@ def dependencies(self) -> Sequence[str]: return list(self._dependencies.values()) -class NumpyFieldsProvider(FieldProvider): +class NumpyDataProvider(FieldProvider): """ Computes a field defined by a numpy function. @@ -580,7 +588,9 @@ def __init__( ): self._func = func self._dims = domain - self._fields: dict[str, state_utils.FieldType | None] = {name: None for name in fields} + self._fields: dict[str, state_utils.ScalarType | state_utils.FieldType | None] = { + name: None for name in fields + } self._dependencies = deps self._connectivities = connectivities if connectivities is not None else {} self._params = params if params is not None else {} @@ -611,13 +621,18 @@ def _compute( args.update(offsets) args.update(self._params) results = self._func(**args) - # TODO(): can the order of return values be checked? - results = (results,) if isinstance(results, data_alloc.NDArray) else results + # convert to tuple + results = (results,) if not isinstance(results, tuple) else results self._fields = { - k: gtx.as_field(tuple(self._dims), results[i], allocator=backend) + k: self._as_field(backend, results[i]) if self._dims else results[i] for i, k in enumerate(self.fields) } + def _as_field( + self, backend: gtx_typing.Backend | None, value: data_alloc.NDArray + ) -> state_utils.GTXFieldType: + return gtx.as_field(tuple(self._dims), value, allocator=backend) + def _validate_dependencies(self) -> None: # TODO(egparedes): dealing with type annotations at run-time is error prone # and requires robust utility functions. This snippet should use a better diff --git a/model/common/src/icon4py/model/common/states/model.py b/model/common/src/icon4py/model/common/states/model.py index 5258c0dea7..508dd48c78 100644 --- a/model/common/src/icon4py/model/common/states/model.py +++ b/model/common/src/icon4py/model/common/states/model.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses import functools +from collections.abc import Sequence from typing import Literal, Protocol, TypeAlias, TypedDict, runtime_checkable import gt4py._core.definitions as gt_coredefs @@ -19,7 +20,6 @@ """Contains type definitions used for the model`s state representation.""" DimensionNames: TypeAlias = Literal["cell", "edge", "vertex"] -DimensionT: TypeAlias = gtx.Dimension | DimensionNames BufferT: TypeAlias = np_t.ArrayLike | gtx.Field DTypeT: TypeAlias = ta.wpfloat | ta.vpfloat | gtx.int32 | gtx.int64 | gtx.float32 | gtx.float64 @@ -30,7 +30,7 @@ class OptionalMetaData(TypedDict, total=False): #: we might not have this one for all fields. But it is useful to have it for tractability with ICON icon_var_name: str # TODO(halungge): dims should probably be required? - dims: tuple[DimensionT, ...] + dims: Sequence[gtx.Dimension] dtype: ta.wpfloat | ta.vpfloat | gtx.int32 | gtx.int64 | gtx.float32 | gtx.float64 diff --git a/model/common/src/icon4py/model/common/states/utils.py b/model/common/src/icon4py/model/common/states/utils.py index 065a6f409d..5a6b8e1b0f 100644 --- a/model/common/src/icon4py/model/common/states/utils.py +++ b/model/common/src/icon4py/model/common/states/utils.py @@ -5,18 +5,17 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import MutableMapping from typing import TypeAlias, TypeVar import gt4py.next as gtx import xarray as xa from gt4py.next.common import DimsT -from icon4py.model.common import dimension as dims, type_alias as ta +from icon4py.model.common import type_alias as ta from icon4py.model.common.utils import data_allocation as data_alloc -DimT = TypeVar("DimT", dims.KDim, dims.KHalfDim, dims.CellDim, dims.EdgeDim, dims.VertexDim) - FloatType: TypeAlias = ta.wpfloat | ta.vpfloat | float IntegerType: TypeAlias = gtx.int32 | gtx.int64 | int ScalarType: TypeAlias = FloatType | bool | IntegerType @@ -27,6 +26,6 @@ FieldType: TypeAlias = gtx.Field[DimsT, T] | data_alloc.NDArray -def to_data_array(field: FieldType, attrs: dict): +def to_data_array(field: FieldType, attrs: MutableMapping[str, ...]): data = data_alloc.as_numpy(field) return xa.DataArray(data, attrs=attrs) diff --git a/model/common/tests/common/states/unit_tests/test_factory.py b/model/common/tests/common/states/unit_tests/test_factory.py index 89b33829d1..1bdc21e955 100644 --- a/model/common/tests/common/states/unit_tests/test_factory.py +++ b/model/common/tests/common/states/unit_tests/test_factory.py @@ -7,9 +7,12 @@ # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations +import functools +from types import ModuleType from typing import TYPE_CHECKING import gt4py.next as gtx +import numpy as np import pytest from icon4py.model.common import dimension as dims, utils as common_utils @@ -17,7 +20,7 @@ from icon4py.model.common.math import helpers as math_helpers from icon4py.model.common.states import factory, model, utils as state_utils from icon4py.model.common.utils import data_allocation as data_alloc -from icon4py.model.testing import definitions +from icon4py.model.testing import definitions, serialbox from icon4py.model.testing.fixtures.datatest import ( backend, data_provider, @@ -44,7 +47,7 @@ class SimpleFieldSource(factory.FieldSource): def __init__( self, - data_: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]], + data_: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]], backend: gtx_typing.Backend | None, grid: icon.IconGrid, vertical_grid: v_grid.VerticalGrid | None = None, @@ -103,7 +106,7 @@ def cell_coordinate_source( grid = grid_savepoint.construct_icon_grid(backend=backend) lat = grid_savepoint.lat(dims.CellDim) lon = grid_savepoint.lon(dims.CellDim) - data: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]] = { + data: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]] = { "lat": (lat, {"standard_name": "lat", "units": ""}), "lon": (lon, {"standard_name": "lon", "units": ""}), "x": ( @@ -136,7 +139,7 @@ def height_coordinate_source( z_ifc = metrics_savepoint.z_ifc() vct_a = grid_savepoint.vct_a() vct_b = grid_savepoint.vct_b() - data: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]] = { + data: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]] = { "height_coordinate": (z_ifc, {"standard_name": "height_coordinate", "units": ""}) } vertical_grid = v_grid.VerticalGrid( @@ -212,7 +215,7 @@ def test_composite_field_source_contains_all_metadata( grid = cell_coordinate_source.grid foo = data_alloc.random_field(grid, dims.CellDim, dims.KDim) bar = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim) - data: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]] = { + data: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]] = { "foo": (foo, {"standard_name": "foo", "units": ""}), "bar": (bar, {"standard_name": "bar", "units": ""}), } @@ -237,7 +240,7 @@ def test_composite_field_source_get_all_fields( grid = cell_coordinate_source.grid foo = data_alloc.random_field(grid, dims.CellDim, dims.KDim) bar = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim) - data: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]] = { + data: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]] = { "foo": (foo, {"standard_name": "foo", "units": ""}), "bar": (bar, {"standard_name": "bar", "units": ""}), } @@ -274,7 +277,7 @@ def test_composite_field_source_raises_upon_get_unknown_field( grid = cell_coordinate_source.grid foo = data_alloc.random_field(grid, dims.CellDim, dims.KDim) bar = data_alloc.random_field(grid, dims.EdgeDim, dims.KDim) - data: dict[str, tuple[state_utils.FieldType, model.FieldMetaData]] = { + data: dict[str, tuple[state_utils.GTXFieldType, model.FieldMetaData]] = { "foo": (foo, {"standard_name": "foo", "units": ""}), "bar": (bar, {"standard_name": "bar", "units": ""}), } @@ -286,3 +289,27 @@ def test_composite_field_source_raises_upon_get_unknown_field( with pytest.raises(ValueError) as err: composite.get("alice") assert "not provided by source " in err.value # type: ignore[operator] + + +def reduce_scalar_min(ar: data_alloc.NDArray, xp: ModuleType) -> gtx.float: + return xp.min(ar).item() + + +@pytest.mark.datatest +def test_compute_scalar_value_from_numpy_provider( + height_coordinate_source: factory.FieldSource, + metrics_savepoint: serialbox.MetricSavepoint, + backend: gtx_typing.Backend, +) -> None: + value_ref = np.min(metrics_savepoint.z_ifc()) + sample_func = functools.partial(reduce_scalar_min, xp=data_alloc.import_array_ns(backend)) + provider = factory.NumpyDataProvider( + func=sample_func, + deps={"ar": "height_coordinate"}, + domain=(), + fields=("minimal_height",), + ) + height_coordinate_source.register_provider(provider) + value = height_coordinate_source.get("minimal_height", factory.RetrievalType.FIELD) + assert np.isscalar(value) + assert value_ref == value