Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]: Force Zarr coordinate reads to be on the host #10079

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,14 +179,31 @@ def encode_zarr_attr_value(value):
return encoded


def _is_coordinate_variable(zarr_array, name):
    """Return True if *name* names one of *zarr_array*'s dimensions.

    Dimension names are stored differently depending on the zarr-python
    major version and the on-disk zarr format: zarr-python v2 (and v3
    reading a format-2 array) keeps them in the ``_ARRAY_DIMENSIONS``
    attribute, while a format-3 array records them in the metadata's
    ``dimension_names`` field.
    """
    if not _zarr_v3():
        # zarr-python v2: dimension names live only in the special attribute.
        return name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", [])
    meta = zarr_array.metadata
    if meta.zarr_format == 2:
        # v3 library reading a format-2 array: same attribute, via metadata.
        return name in meta.attributes.get("_ARRAY_DIMENSIONS", [])
    # Native format-3 array; dimension_names may be None for unnamed dims.
    return name in (meta.dimension_names or [])


class ZarrArrayWrapper(BackendArray):
__slots__ = ("_array", "dtype", "shape")
__slots__ = ("_array", "coords_buffer_prototype", "dtype", "is_coordinate", "shape")

def __init__(self, zarr_array):
def __init__(
self, zarr_array, is_coordinate: bool, coords_buffer_prototype: Any | None
):
# some callers attempt to evaluate an array if an `array` property exists on the object.
# we prefix with _ to avoid this inference.
self._array = zarr_array
self.shape = self._array.shape
self.is_coordinate = is_coordinate
self.coords_buffer_prototype = coords_buffer_prototype

# preserve vlen string object dtype (GH 7328)
if (
Expand All @@ -210,7 +227,14 @@ def _vindex(self, key):
return self._array.vindex[key]

def _getitem(self, key):
    """Read the selection *key* from the backing zarr array.

    Under zarr-python v3, a ``prototype`` keyword is forwarded to
    ``get_basic_selection`` so that coordinate variables are read through
    ``self.coords_buffer_prototype`` (forcing, by default, host/CPU
    buffers) while non-coordinate variables pass ``prototype=None`` and
    keep zarr's configured default. zarr-python v2 has no buffer-prototype
    concept, so the selection is performed without extra keywords.
    """
    if not _zarr_v3():
        return self._array.get_basic_selection(key)
    prototype = self.coords_buffer_prototype if self.is_coordinate else None
    return self._array.get_basic_selection(key, prototype=prototype)

def __getitem__(self, key):
array = self._array
Expand Down Expand Up @@ -605,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore):
"_cache_members",
"_close_store_on_close",
"_consolidate_on_close",
"_coords_buffer_prototype",
"_group",
"_members",
"_mode",
Expand Down Expand Up @@ -636,6 +661,7 @@ def open_store(
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
(
zarr_group,
Expand Down Expand Up @@ -668,6 +694,7 @@ def open_store(
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members=cache_members,
coords_buffer_prototype=coords_buffer_prototype,
)
for group in group_paths
}
Expand All @@ -691,6 +718,7 @@ def open_group(
use_zarr_fill_value_as_mask=None,
write_empty: bool | None = None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
(
zarr_group,
Expand Down Expand Up @@ -722,6 +750,7 @@ def open_group(
close_store_on_close,
use_zarr_fill_value_as_mask,
cache_members,
coords_buffer_prototype,
)

def __init__(
Expand All @@ -736,6 +765,7 @@ def __init__(
close_store_on_close: bool = False,
use_zarr_fill_value_as_mask=None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
):
self.zarr_group = zarr_group
self._read_only = self.zarr_group.read_only
Expand All @@ -751,6 +781,14 @@ def __init__(
self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
self._cache_members: bool = cache_members
self._members: dict[str, ZarrArray | ZarrGroup] = {}
if _zarr_v3() and coords_buffer_prototype is None:
# Once zarr-v3 is required we can just have this as the default
# https://github.com/zarr-developers/zarr-python/issues/2871
# Use the public API once available
from zarr.core.buffer.cpu import buffer_prototype

coords_buffer_prototype = buffer_prototype
self._coords_buffer_prototype = coords_buffer_prototype

if self._cache_members:
# initialize the cache
Expand Down Expand Up @@ -809,7 +847,15 @@ def ds(self):

def open_store_variable(self, name):
zarr_array = self.members[name]
data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array))
is_coordinate = _is_coordinate_variable(zarr_array, name)

data = indexing.LazilyIndexedArray(
ZarrArrayWrapper(
zarr_array,
is_coordinate=is_coordinate,
coords_buffer_prototype=self._coords_buffer_prototype,
)
)
try_nczarr = self._mode == "r"
dimensions, attributes = _get_zarr_dims_and_attrs(
zarr_array, DIMENSION_KEY, try_nczarr
Expand Down Expand Up @@ -1332,6 +1378,7 @@ def open_zarr(
use_zarr_fill_value_as_mask=None,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
coords_buffer_prototype: Any | None = None,
**kwargs,
):
"""Load and decode a dataset from a Zarr store.
Expand Down Expand Up @@ -1442,6 +1489,12 @@ def open_zarr(
chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
:py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
coords_buffer_prototype : zarr.buffer.BufferPrototype, optional
The buffer prototype to use for loading coordinate arrays. Zarr offers control over
which device's memory buffers are read into. By default, xarray will always load
*coordinate* buffers into host (CPU) memory, regardless of the global zarr
configuration. To override this behavior, explicitly pass the buffer prototype
to use for coordinates here.

Returns
-------
Expand Down Expand Up @@ -1485,6 +1538,7 @@ def open_zarr(
"storage_options": storage_options,
"zarr_version": zarr_version,
"zarr_format": zarr_format,
"coords_buffer_prototype": coords_buffer_prototype,
}

ds = open_dataset(
Expand Down Expand Up @@ -1557,6 +1611,7 @@ def open_dataset(
engine=None,
use_zarr_fill_value_as_mask=None,
cache_members: bool = True,
coords_buffer_prototype: Any | None = None,
) -> Dataset:
filename_or_obj = _normalize_path(filename_or_obj)
if not store:
Expand All @@ -1573,6 +1628,7 @@ def open_dataset(
use_zarr_fill_value_as_mask=None,
zarr_format=zarr_format,
cache_members=cache_members,
coords_buffer_prototype=coords_buffer_prototype,
)

store_entrypoint = StoreBackendEntrypoint()
Expand Down Expand Up @@ -1608,6 +1664,7 @@ def open_datatree(
storage_options=None,
zarr_version=None,
zarr_format=None,
coords_buffer_prototype: Any | None = None,
) -> DataTree:
filename_or_obj = _normalize_path(filename_or_obj)
groups_dict = self.open_groups_as_dict(
Expand All @@ -1627,6 +1684,7 @@ def open_datatree(
storage_options=storage_options,
zarr_version=zarr_version,
zarr_format=zarr_format,
coords_buffer_prototype=coords_buffer_prototype,
)

return datatree_from_dict_with_io_cleanup(groups_dict)
Expand All @@ -1650,6 +1708,7 @@ def open_groups_as_dict(
storage_options=None,
zarr_version=None,
zarr_format=None,
coords_buffer_prototype: Any | None = None,
) -> dict[str, Dataset]:
from xarray.core.treenode import NodePath

Expand All @@ -1672,6 +1731,7 @@ def open_groups_as_dict(
storage_options=storage_options,
zarr_version=zarr_version,
zarr_format=zarr_format,
coords_buffer_prototype=coords_buffer_prototype,
)

groups_dict = {}
Expand Down
31 changes: 31 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -3766,6 +3766,37 @@
xr.open_zarr(store=store, zarr_version=2, zarr_format=3)


@requires_zarr
def test_coords_buffer_prototype() -> None:
    """Coordinate reads honor a user-supplied zarr buffer prototype.

    Writes a small dataset to an in-memory zarr v3 store, then reopens it
    passing counting ``Buffer``/``NDBuffer`` subclasses as
    ``coords_buffer_prototype`` and checks that at least one buffer was
    allocated through the custom prototype (i.e. coordinate reads went
    through it).
    """
    pytest.importorskip("zarr", minversion="3")

    from zarr.core.buffer import cpu
    from zarr.core.buffer.core import BufferPrototype

    counter = 0

    class CountingBuffer(cpu.Buffer):
        def __init__(self, *args, **kwargs):
            nonlocal counter
            counter += 1
            super().__init__(*args, **kwargs)

    class CountingNDBuffer(cpu.NDBuffer):
        def __init__(self, *args, **kwargs):
            nonlocal counter
            counter += 1
            super().__init__(*args, **kwargs)

    prototype = BufferPrototype(buffer=CountingBuffer, nd_buffer=CountingNDBuffer)

    ds = create_test_data()
    store = KVStore()
    # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2
    ds.to_zarr(store=store, zarr_format=3)  # type: ignore[call-overload, unused-ignore]
    opened = xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype)  # type: ignore[arg-type, unused-ignore]
    try:
        assert counter > 0
    finally:
        # Close the dataset handle so the in-memory store is released cleanly.
        opened.close()


@requires_scipy
class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only):
engine: T_NetcdfEngine = "scipy"
Expand Down Expand Up @@ -4185,7 +4216,7 @@
fx.create_dataset(k, data=v)
with pytest.warns(UserWarning, match="The 'phony_dims' kwarg"):
with xr.open_dataset(tmp_file, engine="h5netcdf", group="bar") as ds:
assert ds.dims == {

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / ubuntu-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.

Check warning on line 4219 in xarray/tests/test_backends.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

The return type of `Dataset.dims` will be changed to return a set of dimension names in future, in order to be more consistent with `DataArray.dims`. To access a mapping from dimension names to lengths, please use `Dataset.sizes`.
"phony_dim_0": 5,
"phony_dim_1": 5,
"phony_dim_2": 5,
Expand Down
Loading