diff --git a/examples/notebook/test_physics.ipynb b/examples/notebook/test_physics.ipynb index a233699..e69dac9 100644 --- a/examples/notebook/test_physics.ipynb +++ b/examples/notebook/test_physics.ipynb @@ -122,6 +122,7 @@ " layout=layout,\n", " tile_partitioner=partitioner.tile,\n", " tile_rank=cs_communicator.tile.rank,\n", + " backend=backend,\n", ")\n", "\n", "# useful for easily allocating distributed data storages (fields)\n", diff --git a/examples/notebook/test_rad.ipynb b/examples/notebook/test_rad.ipynb index 1f31ace..0fa4797 100644 --- a/examples/notebook/test_rad.ipynb +++ b/examples/notebook/test_rad.ipynb @@ -88,6 +88,7 @@ "outputs": [], "source": [ "rank = 0\n", + "backend = \"numpy\"\n", "\n", "comm = NullComm(rank, 1)\n", "communicator = TileCommunicator.from_layout(comm=comm, layout=(1,1))\n", @@ -101,8 +102,9 @@ " layout=(1,1),\n", " tile_partitioner=communicator.partitioner.tile,\n", " tile_rank=communicator.tile.rank,\n", + " backend=backend\n", ")\n", - "quantity_factory = QuantityFactory(sizer, backend=\"numpy\")" + "quantity_factory = QuantityFactory(sizer, backend=backend)" ] }, { diff --git a/examples/notebook/test_raddriver.ipynb b/examples/notebook/test_raddriver.ipynb index a05dc52..9caa47e 100644 --- a/examples/notebook/test_raddriver.ipynb +++ b/examples/notebook/test_raddriver.ipynb @@ -74,6 +74,7 @@ "outputs": [], "source": [ "rank = 0\n", + "backend = \"numpy\"\n", "\n", "comm = NullComm(rank, 1)\n", "communicator = TileCommunicator.from_layout(comm=comm, layout=(1,1))\n", @@ -87,8 +88,9 @@ " layout=(1,1),\n", " tile_partitioner=communicator.partitioner.tile,\n", " tile_rank=communicator.tile.rank,\n", + " backend=backend\n", ")\n", - "quantity_factory = QuantityFactory(sizer, backend=\"numpy\")" + "quantity_factory = QuantityFactory(sizer, backend=backend)" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 0137f59..1243405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ extras = [ "pyshield[pyfv3]", "pyshield[dev]" ] -ndsl = ["ndsl @ git+https://github.com/NOAA-GFDL/NDSL.git@develop"] +ndsl = ["ndsl @ git+https://github.com/NOAA-GFDL/NDSL.git@2026.01.00"] pyfv3 = ["pyfv3 @ git+https://github.com/NOAA-GFDL/PyFV3.git@develop"] test = [ "pytest", diff --git a/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_mp_driver.py b/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_mp_driver.py index a2ac5c9..0071718 100644 --- a/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_mp_driver.py +++ b/pyshield/stencils/gfdl_cld_microphysics/gfdl_cld_mp_driver.py @@ -899,7 +899,7 @@ def make_quantity2d(**kwargs): self._convert_mm_day = 86400.0 * constants.RGRAV / self.config.dt_split self._copy_stencil = stencil_factory.from_origin_domain( - basic.copy_defn, + basic.copy, origin=self._idx.origin_compute(), domain=self._idx.domain_compute(), ) diff --git a/pyshield/stencils/gfdl_cld_microphysics/terminal_fall.py b/pyshield/stencils/gfdl_cld_microphysics/terminal_fall.py index eb288e0..cc0fa1c 100644 --- a/pyshield/stencils/gfdl_cld_microphysics/terminal_fall.py +++ b/pyshield/stencils/gfdl_cld_microphysics/terminal_fall.py @@ -7,7 +7,7 @@ from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.gt4py import BACKWARD, FORWARD, PARALLEL, computation, interval from ndsl.dsl.typing import FloatField, FloatFieldIJ, IntFieldIJ -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy from pyfv3.stencils.remap_profile import RemapProfile from pyshield.stencils.gfdl_cld_microphysics._config import GFDLCloudMPConfig @@ -407,7 +407,7 @@ def __init__( # compile stencils self._copy_stencil = stencil_factory.from_dims_halo( - copy_defn, + copy, compute_dims=dims, ) diff --git a/pyshield/stencils/physics.py b/pyshield/stencils/physics.py index fd2a95b..185cfc4 100644 --- a/pyshield/stencils/physics.py +++ b/pyshield/stencils/physics.py @@ -21,7 +21,7 @@ ) from ndsl.grid import GridData from ndsl.logging import ndsl_log -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy from pyshield._config import ( PHYSICS_PACKAGES, TRACER_DIM, @@ -1242,7 +1242,7 @@ def make_quantity_2d(): self._sfcvisdfd = make_quantity_2d() self._copy_stencil = stencil_factory.from_origin_domain( - func=copy_defn, + func=copy, origin=grid_indexing.origin_full(), domain=grid_indexing.domain_full(add=(0, 0, 1)), ) diff --git a/tests/integration/test_sfc.py b/tests/integration/test_sfc.py index a2bc105..43005a9 100644 --- a/tests/integration/test_sfc.py +++ b/tests/integration/test_sfc.py @@ -112,8 +112,8 @@ def states_from_fortran_restarts( def setup_infrastructure(nx: Int, ny: Int, nz: Int, nzsoil: Int, etafile: Path): n_halo = 3 - rank = 0 + backend = "numpy" comm = NullComm(rank, 1) communicator = TileCommunicator.from_layout(comm=comm, layout=(1, 1)) @@ -127,8 +127,9 @@ def setup_infrastructure(nx: Int, ny: Int, nz: Int, nzsoil: Int, etafile: Path): layout=(1, 1), tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, + backend=backend, ) - quantity_factory = QuantityFactory(sizer, backend="numpy") + quantity_factory = QuantityFactory(sizer, backend=backend) soil_sizer = SubtileGridSizer.from_tile_params( nx_tile=nx, @@ -139,6 +140,7 @@ def setup_infrastructure(nx: Int, ny: Int, nz: Int, nzsoil: Int, etafile: Path): layout=(1, 1), tile_partitioner=communicator.partitioner.tile, tile_rank=communicator.tile.rank, + backend=backend, ) qf_soil = QuantityFactory(soil_sizer, backend="numpy") diff --git a/tests/savepoint/translate/translate_atmos_phy_statein.py b/tests/savepoint/translate/translate_atmos_phy_statein.py index cb03106..d8f1aa5 100644 --- a/tests/savepoint/translate/translate_atmos_phy_statein.py +++ b/tests/savepoint/translate/translate_atmos_phy_statein.py @@ -58,6 +58,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_cloud_frac.py b/tests/savepoint/translate/translate_cloud_frac.py index 4dbcb43..41c30f2 100644 --- a/tests/savepoint/translate/translate_cloud_frac.py +++ b/tests/savepoint/translate/translate_cloud_frac.py @@ -344,6 +344,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_cumulative_shalconv.py b/tests/savepoint/translate/translate_cumulative_shalconv.py index 3f6aa54..2c3ad8d 100644 --- a/tests/savepoint/translate/translate_cumulative_shalconv.py +++ b/tests/savepoint/translate/translate_cumulative_shalconv.py @@ -3403,6 +3403,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -3516,6 +3517,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -3621,6 +3623,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -3738,6 +3741,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -3851,6 +3855,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -4018,6 +4023,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -4183,6 +4189,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -4318,6 +4325,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -4379,6 +4387,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -4533,6 +4542,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_final_mp.py b/tests/savepoint/translate/translate_final_mp.py index 38b2d8d..e7614aa 100644 --- a/tests/savepoint/translate/translate_final_mp.py +++ b/tests/savepoint/translate/translate_final_mp.py @@ -1034,6 +1034,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -1491,6 +1492,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_gfdl_cld_microphysics.py b/tests/savepoint/translate/translate_gfdl_cld_microphysics.py index cdaa9ec..7715851 100644 --- a/tests/savepoint/translate/translate_gfdl_cld_microphysics.py +++ b/tests/savepoint/translate/translate_gfdl_cld_microphysics.py @@ -328,6 +328,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_gfs_physics_driver.py b/tests/savepoint/translate/translate_gfs_physics_driver.py index e5f62d0..4adac34 100644 --- a/tests/savepoint/translate/translate_gfs_physics_driver.py +++ b/tests/savepoint/translate/translate_gfs_physics_driver.py @@ -127,6 +127,7 @@ def compute(self, inputs): nz=self.config.npz, n_halo=3, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_mfpblt.py b/tests/savepoint/translate/translate_mfpblt.py index 81b38b5..b1d2260 100644 --- a/tests/savepoint/translate/translate_mfpblt.py +++ b/tests/savepoint/translate/translate_mfpblt.py @@ -62,6 +62,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_mfscu.py b/tests/savepoint/translate/translate_mfscu.py index d040de8..27ce972 100644 --- a/tests/savepoint/translate/translate_mfscu.py +++ b/tests/savepoint/translate/translate_mfscu.py @@ -65,6 +65,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_microphysics.py b/tests/savepoint/translate/translate_microphysics.py index ffa8601..922d15b 100644 --- a/tests/savepoint/translate/translate_microphysics.py +++ b/tests/savepoint/translate/translate_microphysics.py @@ -79,6 +79,7 @@ def compute(self, inputs): nz=self.config.npz, n_halo=3, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_mp_full.py b/tests/savepoint/translate/translate_mp_full.py index 037c58d..be45f57 100644 --- a/tests/savepoint/translate/translate_mp_full.py +++ b/tests/savepoint/translate/translate_mp_full.py @@ -466,6 +466,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -695,6 +696,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_particle_properties.py b/tests/savepoint/translate/translate_particle_properties.py index 2a57706..04561bf 100644 --- a/tests/savepoint/translate/translate_particle_properties.py +++ b/tests/savepoint/translate/translate_particle_properties.py @@ -500,6 +500,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_pbl.py b/tests/savepoint/translate/translate_pbl.py index 5b26d03..7f96ab9 100644 --- a/tests/savepoint/translate/translate_pbl.py +++ b/tests/savepoint/translate/translate_pbl.py @@ -97,6 +97,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_pbl_subtests.py b/tests/savepoint/translate/translate_pbl_subtests.py index fc60294..d9d9363 100644 --- a/tests/savepoint/translate/translate_pbl_subtests.py +++ b/tests/savepoint/translate/translate_pbl_subtests.py @@ -12,7 +12,7 @@ Int, IntFieldIJ, ) -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy from pyshield._config import TRACER_DIM, FloatFieldTracer from pyshield.stencils.pbl import PBLConfig from pyshield.stencils.pbl.mfpblt import PBLMassFlux @@ -1345,7 +1345,7 @@ def __init__( ) self._copy_stencil = stencil_factory.from_origin_domain( - func=copy_defn, + func=copy, origin=idx.origin_compute(), domain=idx.domain_compute(), ) @@ -3329,6 +3329,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3437,6 +3438,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3504,6 +3506,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3560,6 +3563,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3630,6 +3634,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3703,6 +3708,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3761,6 +3767,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3909,6 +3916,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -3970,6 +3978,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4064,6 +4073,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4146,6 +4156,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4212,6 +4223,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4295,6 +4307,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4383,6 +4396,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4604,6 +4618,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -4845,6 +4860,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_preliminary_mp.py b/tests/savepoint/translate/translate_preliminary_mp.py index 609e863..71d7d65 100644 --- a/tests/savepoint/translate/translate_preliminary_mp.py +++ b/tests/savepoint/translate/translate_preliminary_mp.py @@ -39,7 +39,7 @@ def make_quantity2d(**kwargs): self._bottom_density = make_quantity2d() self._copy_stencil = stencil_factory.from_origin_domain( - basic.copy_defn, + basic.copy, origin=self._idx.origin_compute(), domain=self._idx.domain_compute(), ) @@ -572,6 +572,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_samfshalconv.py b/tests/savepoint/translate/translate_samfshalconv.py index 20e7add..ae3eb59 100644 --- a/tests/savepoint/translate/translate_samfshalconv.py +++ b/tests/savepoint/translate/translate_samfshalconv.py @@ -79,6 +79,7 @@ def __init__(self, grid, config, stencil_factory): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -95,6 +96,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) diff --git a/tests/savepoint/translate/translate_sedimentation.py b/tests/savepoint/translate/translate_sedimentation.py index e74343e..5867bf4 100644 --- a/tests/savepoint/translate/translate_sedimentation.py +++ b/tests/savepoint/translate/translate_sedimentation.py @@ -695,6 +695,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( @@ -875,6 +876,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_terminal_fall.py b/tests/savepoint/translate/translate_terminal_fall.py index e68fbbb..e7f8157 100644 --- a/tests/savepoint/translate/translate_terminal_fall.py +++ b/tests/savepoint/translate/translate_terminal_fall.py @@ -135,6 +135,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_tracer_sedi.py b/tests/savepoint/translate/translate_tracer_sedi.py index 9a45071..c27b087 100644 --- a/tests/savepoint/translate/translate_tracer_sedi.py +++ b/tests/savepoint/translate/translate_tracer_sedi.py @@ -718,6 +718,7 @@ def __init__( n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) self.quantity_factory = QuantityFactory( diff --git a/tests/savepoint/translate/translate_tridiag.py b/tests/savepoint/translate/translate_tridiag.py index c808420..a1864da 100644 --- a/tests/savepoint/translate/translate_tridiag.py +++ b/tests/savepoint/translate/translate_tridiag.py @@ -3,7 +3,7 @@ from ndsl import QuantityFactory, StencilFactory, SubtileGridSizer from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float -from ndsl.stencils.basic_operations import copy_defn +from ndsl.stencils.basic_operations import copy from pyshield._config import TRACER_DIM, FloatFieldTracer from pyshield.stencils.pbl.tridiag import tridi2, tridin, tridit from tests.savepoint.translate.translate_physics import TranslatePhysicsFortranData2Py @@ -40,7 +40,7 @@ def __init__( domain=idx.domain_compute(), ) self._copy_stencil = stencil_factory.from_origin_domain( - func=copy_defn, + func=copy, origin=idx.origin_compute(), domain=idx.domain_compute(), ) @@ -83,7 +83,7 @@ def __init__( ) self._copy_stencil = stencil_factory.from_origin_domain( - func=copy_defn, + func=copy, origin=idx.origin_compute(), domain=idx.domain_compute(), ) @@ -156,7 +156,7 @@ def __init__( ) self._copy_stencil = stencil_factory.from_origin_domain( - func=copy_defn, + func=copy, origin=idx.origin_compute(), domain=idx.domain_compute(), ) @@ -227,6 +227,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) @@ -268,6 +269,7 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) compute_func = Tridi2(self.stencil_factory, quantity_factory) @@ -307,9 +309,9 @@ def compute(self, inputs): n_halo=3, data_dimensions={}, layout=self.config.layout, + backend=self.stencil_factory.backend, ) quantity_factory = QuantityFactory(sizer, backend=self.stencil_factory.backend) - config = self.config.pbl compute_func = TridiN(self.stencil_factory, quantity_factory, 8) compute_func(**inputs)