diff --git a/examples/notebooks/functions.py b/examples/notebooks/functions.py index 0c81ff58..f0f7d84a 100644 --- a/examples/notebooks/functions.py +++ b/examples/notebooks/functions.py @@ -106,7 +106,7 @@ def init_quantity( units=units, origin=(nhalo, nhalo, 0)[:skip_z], extent=(nx, ny, nz)[:skip_z], - gt4py_backend=backend, + backend=backend, ) if grid == VariableGrid.CellCorners: @@ -116,7 +116,7 @@ def init_quantity( units=units, origin=(nhalo, nhalo, 0)[:skip_z], extent=(nx + 1, ny + 1, nz)[:skip_z], - gt4py_backend=backend, + backend=backend, ) elif grid == VariableGrid.StaggeredInX: @@ -126,7 +126,7 @@ def init_quantity( units=units, origin=(nhalo, nhalo, 0)[:skip_z], extent=(nx + 1, ny, nz)[:skip_z], - gt4py_backend=backend, + backend=backend, ) elif grid == VariableGrid.StaggeredInY: @@ -136,7 +136,7 @@ def init_quantity( units=units, origin=(nhalo, nhalo, 0)[:skip_z], extent=(nx, ny + 1, nz)[:skip_z], - gt4py_backend=backend, + backend=backend, ) return variable diff --git a/examples/notebooks/grid_generation.ipynb b/examples/notebooks/grid_generation.ipynb index eb331536..3442d50e 100644 --- a/examples/notebooks/grid_generation.ipynb +++ b/examples/notebooks/grid_generation.ipynb @@ -267,7 +267,7 @@ " units=\"degrees\",\n", " origin=(nhalo, nhalo),\n", " extent=(nx + 1, ny + 1),\n", - " gt4py_backend=backend,\n", + " backend=backend,\n", ")\n", "lat = Quantity(\n", " metric_terms.lat.data * 180 / np.pi,\n", @@ -275,7 +275,7 @@ " units=\"degrees\",\n", " origin=(nhalo, nhalo),\n", " extent=(nx + 1, ny + 1),\n", - " gt4py_backend=backend,\n", + " backend=backend,\n", ")\n", "\n", "# gather the distributed fields into a global field on the root rank\n", @@ -357,7 +357,7 @@ " units=\"m2\",\n", " origin=(nhalo, nhalo),\n", " extent=(nx, ny),\n", - " gt4py_backend=backend,\n", + " backend=backend,\n", ")\n", "\n", "# rescale to 10^3 km2\n", diff --git a/pace/grid.py b/pace/grid.py index f91434c6..2e681334 100644 --- a/pace/grid.py +++ b/pace/grid.py @@ -181,12 +181,8 @@ def get_grid( quantity_factory: QuantityFactory, communicator: Communicator, ) -> Tuple[DampingCoefficients, DriverGridData, GridData]: - backend = quantity_factory.zeros( - dims=[X_DIM, Y_DIM], units="unknown" - ).gt4py_backend - ndsl_log.info("Using serialized grid data") - grid = self._get_serialized_grid(communicator, backend) + grid = self._get_serialized_grid(communicator, quantity_factory.backend) grid_data = grid.grid_data driver_grid_data = grid.driver_grid_data damping_coefficients = grid.damping_coefficients diff --git a/pace/initialization.py b/pace/initialization.py index b3eed492..f4eff622 100644 --- a/pace/initialization.py +++ b/pace/initialization.py @@ -286,11 +286,9 @@ def get_driver_state( grid_data: GridData, schemes: List[PHYSICS_PACKAGES], ) -> DriverState: - backend = quantity_factory.zeros( - dims=[X_DIM, Y_DIM], units="unknown" - ).gt4py_backend - - dycore_state = self._initialize_dycore_state(communicator, backend) + dycore_state = self._initialize_dycore_state( + communicator, quantity_factory.backend + ) physics_state = PhysicsState.init_zeros( quantity_factory=quantity_factory, schemes=schemes, diff --git a/tests/main/driver/test_safety_checks.py b/tests/main/driver/test_safety_checks.py index 77560b88..9eec5c73 100644 --- a/tests/main/driver/test_safety_checks.py +++ b/tests/main/driver/test_safety_checks.py @@ -65,7 +65,7 @@ def test_check_state_domain_only(): "unknown", origin=(1, 1, 0), extent=(3, 3, 2), - gt4py_backend="numpy", + backend="numpy", ) dycore_state = unittest.mock.MagicMock(u=u_quantity) checker.check_state(dycore_state) @@ -83,7 +83,7 @@ def test_check_nan_value(): "unknown", origin=(0, 0, 0), extent=(4, 4, 2), - gt4py_backend="numpy", + backend="numpy", ) dycore_state = unittest.mock.MagicMock(u=u_quantity) with pytest.raises(RuntimeError):