Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions model/common/src/icon4py/model/common/states/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,28 +450,10 @@ def _map_dim(dim: gtx.Dimension) -> gtx.Dimension:
field_domain = {_map_dim(dim): (0, _map_size(dim, grid)) for dim in self._dims}
return {k: allocate(field_domain, dtype=dtype[k]) for k in self._fields}

# TODO(halungge): this can be simplified when completely disentangling vertical and horizontal grid.
# the IconGrid should then only contain horizontal connectivities and no longer any Koff which should be moved to the VerticalGrid
def _get_offset_providers(self, grid: icon_grid.IconGrid) -> dict[str, gtx.FieldOffset]:
offset_providers = {}
for dim in self._compute_domain:
if dim.kind == gtx.DimensionKind.HORIZONTAL:
horizontal_offsets = {
k: v
for k, v in grid.connectivities.items()
# TODO(halungge): review this workaround, as the fix should be available in the gt4py baseline
if isinstance(v, gtx.Connectivity)
and v.domain.dims[0].kind == gtx.DimensionKind.HORIZONTAL
}
offset_providers.update(horizontal_offsets)
if dim.kind == gtx.DimensionKind.VERTICAL:
vertical_offsets = {
k: v
for k, v in grid.connectivities.items()
if isinstance(v, gtx.Dimension) and v.kind == gtx.DimensionKind.VERTICAL
}
offset_providers.update(vertical_offsets)
return offset_providers
def _grid_connectivities(
self, grid: icon_grid.IconGrid
) -> dict[str, gtx.Connectivity | gtx.Dimension]:
return grid.connectivities
Comment on lines +453 to +456
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this function should disappear in the end?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes.


def _domain_args(
self, grid: icon_grid.IconGrid, vertical_grid: v_grid.VerticalGrid
Expand Down Expand Up @@ -524,9 +506,9 @@ def _compute(
deps = {k: factory.get(v) for k, v in self._dependencies.items()}
deps.update(self._params)
deps.update({k: self._fields[v] for k, v in self._output.items()})
dims = self._domain_args(grid_provider.grid, grid_provider.vertical_grid)
offset_providers = self._get_offset_providers(grid_provider.grid)
deps.update(dims)
domain_bounds = self._domain_args(grid_provider.grid, grid_provider.vertical_grid)
deps.update(domain_bounds)
offset_providers = self._grid_connectivities(grid_provider.grid)
self._func.with_backend(backend)(**deps, offset_provider=offset_providers)

@property
Expand Down