diff --git a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py index 2b742bbbc4..83c7fd59bb 100644 --- a/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py +++ b/model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py @@ -397,10 +397,7 @@ def __init__( interpolation_state: diffusion_states.DiffusionInterpolationState, edge_params: grid_states.EdgeParams, cell_params: grid_states.CellParams, - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, orchestration: bool = False, exchange: decomposition.ExchangeRuntime | None = decomposition.single_node_default, ): diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py index 21a177e86f..cfa5fbe0f1 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/solve_nonhydro.py @@ -366,10 +366,7 @@ def __init__( edge_geometry: grid_states.EdgeParams, cell_geometry: grid_states.CellParams, owner_mask: fa.CellField[bool], - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, exchange: decomposition.ExchangeRuntime = decomposition.SingleNodeExchange(), ): self._exchange = exchange diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py index 29cc3cb661..ca14dc7936 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/velocity_advection.py @@ -48,10 +48,7 @@ def __init__( vertical_params: v_grid.VerticalGrid, edge_params: grid_states.EdgeParams, owner_mask: fa.CellField[bool], - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, ): self.grid: icon_grid.IconGrid = grid self.metric_state: dycore_states.MetricStateNonHydro = metric_state diff --git a/model/common/src/icon4py/model/common/model_backends.py b/model/common/src/icon4py/model/common/model_backends.py index 793bf97e3f..c4c6efaa2d 100644 --- a/model/common/src/icon4py/model/common/model_backends.py +++ b/model/common/src/icon4py/model/common/model_backends.py @@ -5,6 +5,22 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause + +""" +This module defines the available backends for ICON4Py and provides utilities to work with them. + +Backends ('BackendLike') are one of the following +- concrete GT4Py backends (instances of 'gtx_typing.Backend') +- 'None', which indicates embedded execution +- 'BackendDescriptor's, which are dictionaries containing the configuration parameters to construct a concrete backend +- 'DeviceType's, which indicate the target device for which the default backend should be constructed. +See 'model_options.customize_backend()' for details. + +Note: A concrete GT4Py backend (an instance of 'gtx_typing.Backend') is also a 'gtx_typing.Allocator' as it implements +the 'FieldBufferAllocatorFactoryProtocol', while other 'BackendLike's do not. +Use 'get_allocator' to get the allocator for any 'BackendLike'. +""" + from typing import Any, Final, TypeAlias, TypeGuard import gt4py.next as gtx @@ -35,7 +51,7 @@ def is_backend_descriptor( def get_allocator( backend: BackendLike, -) -> gtx_typing.Backend: +) -> gtx_typing.Allocator: if isinstance(backend, gtx_backend.Backend): return backend if backend is None: diff --git a/model/common/src/icon4py/model/common/model_options.py b/model/common/src/icon4py/model/common/model_options.py index f08a154667..9a4a7760ad 100644 --- a/model/common/src/icon4py/model/common/model_options.py +++ b/model/common/src/icon4py/model/common/model_options.py @@ -85,11 +85,18 @@ def get_options(program_name: str, **backend_descriptor: Any) -> model_backends. def customize_backend( program: gtx_typing.Program | gtx.typing.FieldOperator | None, - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, ) -> gtx_typing.Backend | None: + """ + Customizes the backend according to the provided 'BackendLike' and program name. + + In case a certain option is already provided in the 'BackendDesriptor', the customization + should not override it. Any option that doesn't apply to the backend factory should be ignored. + + Note: The current customization mechanism is an adhoc solution that needs better design, + e.g. program specific options should not be gathered in this general module. + """ + program_name = program.__name__ if program is not None else "" if backend is None or isinstance(backend, gtx_backend.Backend): backend_name = backend.name if backend is not None else "embedded" @@ -115,10 +122,7 @@ def customize_backend( def setup_program( program: gtx_typing.Program, - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, constant_args: dict[str, gtx.Field | gtx_typing.Scalar] | None = None, variants: dict[str, list[gtx_typing.Scalar]] | None = None, horizontal_sizes: dict[str, gtx.int32] | None = None, @@ -128,7 +132,8 @@ def setup_program( """ This function processes arguments to the GT4Py program. It - binds arguments that don't change during model run ('constant_args', 'horizontal_sizes', "vertical_sizes'); - - inlines scalar arguments into the GT4Py program at compile-time (via GT4Py's 'compile'). + - inlines scalar arguments into the GT4Py program at compile-time (via GT4Py's 'compile'); + - instantiates a concrete backend according to the provided 'BackendLike'. Args: - backend: GT4Py backend, - program: GT4Py program, diff --git a/model/common/tests/common/test_model_options.py b/model/common/tests/common/test_model_options.py index ab5ec10232..38def01ae6 100644 --- a/model/common/tests/common/test_model_options.py +++ b/model/common/tests/common/test_model_options.py @@ -68,10 +68,7 @@ def test_custom_backend_device() -> None: ], ) def test_setup_program_defaults( - backend: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend: model_backends.BackendLike, ) -> None: testee = setup_program(backend=backend, program=program_return_field) expected_backend = model_backends.make_custom_dace_backend(device=model_backends.CPU) @@ -109,10 +106,7 @@ def test_setup_program_defaults( ], ) def test_setup_program_specify_inputs( - backend_params: gtx_typing.Backend - | model_backends.DeviceType - | model_backends.BackendDescriptor - | None, + backend_params: model_backends.BackendLike, expected_backend: gtx_typing.Backend | None, ) -> None: testee = setup_program(backend=backend_params, program=program_return_field) diff --git a/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py b/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py index 276be59f83..7c99f85da5 100644 --- a/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py +++ b/model/standalone_driver/src/icon4py/model/standalone_driver/standalone_driver.py @@ -72,7 +72,7 @@ def __init__( ) @functools.cached_property - def _allocator(self) -> gtx.typing.Backend: + def _allocator(self) -> gtx.typing.Allocator: return model_backends.get_allocator(self.backend) @functools.cached_property