Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions examples/NDSL/01_gt4py_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@
"shape = (nx, ny, nz)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)"
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)"
]
},
{
Expand Down Expand Up @@ -199,7 +199,7 @@
"outputs": [],
"source": [
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"print(\"Plotting values of qty_in at K = 0\")\n",
Expand All @@ -212,7 +212,7 @@
"qty_out.plot_k_level(0)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"print(\"Resetting qty_out to zero...\")\n",
Expand All @@ -224,7 +224,7 @@
"qty_out.plot_k_level(0)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"print(\"Resetting qty_out to zero...\")\n",
Expand All @@ -238,7 +238,7 @@
"qty_out.plot_k_level(1)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"print(\"Resetting qty_out to zero...\")\n",
"print(\"Plotting values of qty_in at K = 0\")\n",
Expand All @@ -251,7 +251,7 @@
"qty_out.plot_k_level(0)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros([nx, ny, nz]), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"print(\"Resetting qty_out to zero...\")\n",
"print(\"Plotting values of qty_out at K = 0\")\n",
Expand Down Expand Up @@ -294,13 +294,13 @@
"shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)\n",
"\n",
"print(\"Plotting values of qty_in at K = 0\")\n",
"qty_in.plot_k_level(0)\n",
Expand Down Expand Up @@ -344,7 +344,7 @@
"\n",
"print(\"Resetting qty_out to zeros\")\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"print(\"Executing 'copy_downward' with origin=(1, 1, 0), domain=(nx, ny, nz-1)\")\n",
Expand Down Expand Up @@ -401,13 +401,13 @@
"shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)\n",
"\n",
"\n",
"@stencil(backend=backend)\n",
Expand Down Expand Up @@ -444,13 +444,13 @@
"outputs": [],
"source": [
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)\n",
"\n",
"print(\"Plotting values of qty_in at K = 0\")\n",
"qty_in.plot_k_level(0)\n",
Expand Down Expand Up @@ -525,13 +525,13 @@
"shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)\n",
"\n",
"print(\"Plotting values of qty_in at K = 0\")\n",
"qty_in.plot_k_level(0)\n",
Expand All @@ -546,7 +546,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand Down
6 changes: 3 additions & 3 deletions examples/NDSL/02_NDSL_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@
"shape = (nx + 2 * nhalo, ny + 2 * nhalo, nz)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"\n",
"arr = np.indices(shape, dtype=float).sum(\n",
" axis=0\n",
") # Value of each entry is sum of the I and J index at each point\n",
"\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend)\n",
"qty_in = Quantity(data=arr, dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend)\n",
"\n",
"print(\"Plotting qty_in at K = 0\")\n",
"qty_in.plot_k_level(0)\n",
Expand Down Expand Up @@ -224,7 +224,7 @@
"copy_field_offset = CopyFieldOffset(stencil_factory)\n",
"\n",
"qty_out = Quantity(\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", gt4py_backend=backend\n",
" data=np.zeros(shape), dims=[\"I\", \"J\", \"K\"], units=\"m\", backend=backend\n",
")\n",
"\n",
"print(\"Initialize qty_out to zeros\")"
Expand Down
20 changes: 1 addition & 19 deletions ndsl/quantity/local.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from collections.abc import Sequence
from typing import Any

Expand All @@ -23,28 +22,11 @@ def __init__(
dims: Sequence[str],
units: str,
*,
backend: str | None = None,
backend: str,
origin: Sequence[int] | None = None,
extent: Sequence[int] | None = None,
gt4py_backend: str | None = None,
allow_mismatch_float_precision: bool = False,
):
if gt4py_backend is not None:
warnings.warn(
"gt4py_backend is deprecated. Use `backend` instead.",
DeprecationWarning,
stacklevel=2,
)
if backend is None:
backend = gt4py_backend

if backend is None:
warnings.warn(
"`backend` will be a required argument starting with the next version of NDSL.",
DeprecationWarning,
stacklevel=2,
)

# Initialize memory to obviously wrong value - Local should _not_ be expected
# to be zero'ed.
data[:] = 123456789
Expand Down
34 changes: 16 additions & 18 deletions ndsl/quantity/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,39 @@
@dataclasses.dataclass
class QuantityMetadata:
origin: tuple[int, ...]
"the start of the computational domain"
"The start of the computational domain."
extent: tuple[int, ...]
"the shape of the computational domain"
"The shape of the computational domain."
n_halo: int
"Number of halo-points used in the horizontal"
"Number of halo-points used in the horizontal."
dims: tuple[str, ...]
"names of each dimension"
"Names of each dimension."
units: str
"units of the quantity"
"Units of the quantity."
data_type: type
"ndarray-like type used to store the data"
"ndarray-like type used to store the data."
dtype: type
"dtype of the data in the ndarray-like object"
gt4py_backend: str | None = None
"Deprecated. Use backend instead."
backend: str | None = None
"dtype of the data in the ndarray-like object."
backend: str
"GT4Py backend name. Used for performance optimal data allocation."

@property
def dim_lengths(self) -> dict[str, int]:
"""mapping of dimension names to their lengths"""
"""Mapping of dimension names to their lengths."""
return dict(zip(self.dims, self.extent))

@property
def np(self) -> NumpyModule:
"""numpy-like module used to interact with the data"""
"""numpy-like module used to interact with the data."""
if issubclass(self.data_type, cupy.ndarray):
return cupy
elif issubclass(self.data_type, np.ndarray):

if issubclass(self.data_type, np.ndarray):
return np
else:
raise TypeError(
f"quantity underlying data is of unexpected type {self.data_type}"
)

raise TypeError(
f"Quantity underlying data is of unexpected type {self.data_type}"
)

def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None:
metadata_copy.origin = self.origin
Expand All @@ -58,7 +57,6 @@ def duplicate_metadata(self, metadata_copy: QuantityMetadata) -> None:
metadata_copy.units = self.units
metadata_copy.data_type = self.data_type
metadata_copy.dtype = self.dtype
metadata_copy.gt4py_backend = self.gt4py_backend
metadata_copy.backend = self.backend


Expand Down
Loading