diff --git a/model/common/src/icon4py/model/common/states/factory.py b/model/common/src/icon4py/model/common/states/factory.py index 9d8412ee4a..17284a14e8 100644 --- a/model/common/src/icon4py/model/common/states/factory.py +++ b/model/common/src/icon4py/model/common/states/factory.py @@ -450,28 +450,10 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension: field_domain = {_map_dim(dim): (0, _map_size(dim, grid)) for dim in self._dims} return {k: allocate(field_domain, dtype=dtype[k]) for k in self._fields} - # TODO(halungge): this can be simplified when completely disentangling vertical and horizontal grid. - # the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid - def _get_offset_providers(self, grid: icon_grid.IconGrid) -> dict[str, gtx.FieldOffset]: - offset_providers = {} - for dim in self._compute_domain: - if dim.kind == gtx.DimensionKind.HORIZONTAL: - horizontal_offsets = { - k: v - for k, v in grid.connectivities.items() - # TODO(halungge): review this workaround, as the fix should be available in the gt4py baseline - if isinstance(v, gtx.Connectivity) - and v.domain.dims[0].kind == gtx.DimensionKind.HORIZONTAL - } - offset_providers.update(horizontal_offsets) - if dim.kind == gtx.DimensionKind.VERTICAL: - vertical_offsets = { - k: v - for k, v in grid.connectivities.items() - if isinstance(v, gtx.Dimension) and v.kind == gtx.DimensionKind.VERTICAL - } - offset_providers.update(vertical_offsets) - return offset_providers + def _grid_connectivities( + self, grid: icon_grid.IconGrid + ) -> dict[str, gtx.Connectivity | gtx.Dimension]: + return grid.connectivities def _domain_args( self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid @@ -524,9 +506,9 @@ def _compute( deps = {k: factory.get(v) for k, v in self._dependencies.items()} deps.update(self._params) deps.update({k: self._fields[v] for k, v in self._output.items()}) - dims = self._domain_args(grid_provider.grid, grid_provider.vertical_grid) - offset_providers = self._get_offset_providers(grid_provider.grid) - deps.update(dims) + domain_bounds = self._domain_args(grid_provider.grid, grid_provider.vertical_grid) + deps.update(domain_bounds) + offset_providers = self._grid_connectivities(grid_provider.grid) self._func.with_backend(backend)(**deps, offset_provider=offset_providers) @property