Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,7 @@ def __init__(
interpolation_state: diffusion_states.DiffusionInterpolationState,
edge_params: grid_states.EdgeParams,
cell_params: grid_states.CellParams,
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
orchestration: bool = False,
exchange: decomposition.ExchangeRuntime | None = decomposition.single_node_default,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,7 @@ def __init__(
edge_geometry: grid_states.EdgeParams,
cell_geometry: grid_states.CellParams,
owner_mask: fa.CellField[bool],
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(),
):
self._exchange = exchange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ def __init__(
vertical_params: v_grid.VerticalGrid,
edge_params: grid_states.EdgeParams,
owner_mask: fa.CellField[bool],
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
):
self.grid: icon_grid.IconGrid = grid
self.metric_state: dycore_states.MetricStateNonHydro = metric_state
Expand Down
18 changes: 17 additions & 1 deletion model/common/src/icon4py/model/common/model_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

"""
This module defines the available backends for ICON4Py and provides utilities to work with them.

Backends ('BackendLike') are one of the following
- concrete GT4Py backends (instances of 'gtx_typing.Backend')
- 'None', which indicates embedded execution
- 'BackendDescriptor's, which are dictionaries containing the configuration parameters to construct a concrete backend
- 'DeviceType's, which indicate the target device for which the default backend should be constructed.
See 'model_options.customize_backend()' for details.

Note: A concrete GT4Py backend (an instance of 'gtx_typing.Backend') is also a 'gtx_typing.Allocator' as it implements
the 'FieldBufferAllocatorFactoryProtocol', while other 'BackendLike's do not.
Use 'get_allocator' to get the allocator for any 'BackendLike'.
"""

from typing import Any, Final, TypeAlias, TypeGuard

import gt4py.next as gtx
Expand Down Expand Up @@ -35,7 +51,7 @@ def is_backend_descriptor(

def get_allocator(
backend: BackendLike,
) -> gtx_typing.Backend:
) -> gtx_typing.Allocator:
if isinstance(backend, gtx_backend.Backend):
return backend
if backend is None:
Expand Down
23 changes: 14 additions & 9 deletions model/common/src/icon4py/model/common/model_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,18 @@ def get_options(program_name: str, **backend_descriptor: Any) -> model_backends.

def customize_backend(
program: gtx_typing.Program | gtx.typing.FieldOperator | None,
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
) -> gtx_typing.Backend | None:
"""
Customizes the backend according to the provided 'BackendLike' and program name.

In case a certain option is already provided in the 'BackendDesriptor', the customization
should not override it. Any option that doesn't apply to the backend factory should be ignored.

Note: The current customization mechanism is an adhoc solution that needs better design,
e.g. program specific options should not be gathered in this general module.
"""

program_name = program.__name__ if program is not None else ""
if backend is None or isinstance(backend, gtx_backend.Backend):
backend_name = backend.name if backend is not None else "embedded"
Expand All @@ -114,10 +121,7 @@ def customize_backend(

def setup_program(
program: gtx_typing.Program,
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
constant_args: dict[str, gtx.Field | gtx_typing.Scalar] | None = None,
variants: dict[str, list[gtx_typing.Scalar]] | None = None,
horizontal_sizes: dict[str, gtx.int32] | None = None,
Expand All @@ -127,7 +131,8 @@ def setup_program(
"""
This function processes arguments to the GT4Py program. It
- binds arguments that don't change during model run ('constant_args', 'horizontal_sizes', "vertical_sizes');
- inlines scalar arguments into the GT4Py program at compile-time (via GT4Py's 'compile').
- inlines scalar arguments into the GT4Py program at compile-time (via GT4Py's 'compile');
- instantiates a concrete backend according to the provided 'BackendLike'.
Args:
- backend: GT4Py backend,
- program: GT4Py program,
Expand Down
10 changes: 2 additions & 8 deletions model/common/tests/common/test_model_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ def test_custom_backend_device() -> None:
],
)
def test_setup_program_defaults(
backend: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend: model_backends.BackendLike,
) -> None:
testee = setup_program(backend=backend, program=program_return_field)
expected_backend = model_backends.make_custom_dace_backend(device=model_backends.CPU)
Expand Down Expand Up @@ -109,10 +106,7 @@ def test_setup_program_defaults(
],
)
def test_setup_program_specify_inputs(
backend_params: gtx_typing.Backend
| model_backends.DeviceType
| model_backends.BackendDescriptor
| None,
backend_params: model_backends.BackendLike,
expected_backend: gtx_typing.Backend | None,
) -> None:
testee = setup_program(backend=backend_params, program=program_return_field)
Expand Down
Loading