@@ -179,15 +179,31 @@ def encode_zarr_attr_value(value):
179179 return encoded
180180
181181
182+ def _is_coordinate_variable (zarr_array , name ):
183+ if _zarr_v3 ():
184+ if zarr_array .metadata .zarr_format == 2 :
185+ is_coordinate = name in zarr_array .metadata .attributes .get (
186+ "_ARRAY_DIMENSIONS" , []
187+ )
188+ else :
189+ is_coordinate = name in (zarr_array .metadata .dimension_names or [])
190+ else :
191+ is_coordinate = name in zarr_array .attrs .get ("_ARRAY_DIMENSIONS" , [])
192+ return is_coordinate
193+
194+
182195class ZarrArrayWrapper (BackendArray ):
183- __slots__ = ("_array" , "dtype " , "shape " , "is_coordinate" )
196+ __slots__ = ("_array" , "coords_buffer_prototype " , "dtype " , "is_coordinate" , "shape " )
184197
185- def __init__ (self , zarr_array , is_coordinate : bool ):
198+ def __init__ (
199+ self , zarr_array , is_coordinate : bool , coords_buffer_prototype : Any | None
200+ ):
186201 # some callers attempt to evaluate an array if an `array` property exists on the object.
187202 # we prefix with _ to avoid this inference.
188203 self ._array = zarr_array
189204 self .shape = self ._array .shape
190205 self .is_coordinate = is_coordinate
206+ self .coords_buffer_prototype = coords_buffer_prototype
191207
192208 # preserve vlen string object dtype (GH 7328)
193209 if (
@@ -211,12 +227,14 @@ def _vindex(self, key):
211227 return self ._array .vindex [key ]
212228
213229 def _getitem (self , key ):
214- from zarr .core .buffer .cpu import buffer_prototype
215- if self .is_coordinate :
216- prototype = buffer_prototype
217- else :
218- prototype = None
219- return self ._array .get_basic_selection (key , prototype = prototype )
230+ kwargs = {}
231+ if _zarr_v3 ():
232+ if self .is_coordinate :
233+ prototype = self .coords_buffer_prototype
234+ else :
235+ prototype = None
236+ kwargs ["prototype" ] = prototype
237+ return self ._array .get_basic_selection (key , ** kwargs )
220238
221239 def __getitem__ (self , key ):
222240 array = self ._array
@@ -611,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore):
611629 "_cache_members" ,
612630 "_close_store_on_close" ,
613631 "_consolidate_on_close" ,
632+ "_coords_buffer_prototype" ,
614633 "_group" ,
615634 "_members" ,
616635 "_mode" ,
@@ -642,6 +661,7 @@ def open_store(
642661 use_zarr_fill_value_as_mask = None ,
643662 write_empty : bool | None = None ,
644663 cache_members : bool = True ,
664+ coords_buffer_prototype : Any | None = None ,
645665 ):
646666 (
647667 zarr_group ,
@@ -674,6 +694,7 @@ def open_store(
674694 close_store_on_close ,
675695 use_zarr_fill_value_as_mask ,
676696 cache_members = cache_members ,
697+ coords_buffer_prototype = coords_buffer_prototype ,
677698 )
678699 for group in group_paths
679700 }
@@ -697,6 +718,7 @@ def open_group(
697718 use_zarr_fill_value_as_mask = None ,
698719 write_empty : bool | None = None ,
699720 cache_members : bool = True ,
721+ coords_buffer_prototype : Any | None = None ,
700722 ):
701723 (
702724 zarr_group ,
@@ -728,6 +750,7 @@ def open_group(
728750 close_store_on_close ,
729751 use_zarr_fill_value_as_mask ,
730752 cache_members ,
753+ coords_buffer_prototype ,
731754 )
732755
733756 def __init__ (
@@ -742,6 +765,7 @@ def __init__(
742765 close_store_on_close : bool = False ,
743766 use_zarr_fill_value_as_mask = None ,
744767 cache_members : bool = True ,
768+ coords_buffer_prototype : Any | None = None ,
745769 ):
746770 self .zarr_group = zarr_group
747771 self ._read_only = self .zarr_group .read_only
@@ -757,6 +781,14 @@ def __init__(
757781 self ._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask
758782 self ._cache_members : bool = cache_members
759783 self ._members : dict [str , ZarrArray | ZarrGroup ] = {}
784+ if _zarr_v3 () and coords_buffer_prototype is None :
785+ # Once zarr-v3 is required we can just have this as the default
786+ # https://github.com/zarr-developers/zarr-python/issues/2871
787+ # Use the public API once available
788+ from zarr .core .buffer .cpu import buffer_prototype
789+
790+ coords_buffer_prototype = buffer_prototype
791+ self ._coords_buffer_prototype = coords_buffer_prototype
760792
761793 if self ._cache_members :
762794 # initialize the cache
@@ -815,8 +847,15 @@ def ds(self):
815847
816848 def open_store_variable (self , name ):
817849 zarr_array = self .members [name ]
818- is_coordinate = name in zarr_array .metadata .dimension_names
819- data = indexing .LazilyIndexedArray (ZarrArrayWrapper (zarr_array , is_coordinate = is_coordinate ))
850+ is_coordinate = _is_coordinate_variable (zarr_array , name )
851+
852+ data = indexing .LazilyIndexedArray (
853+ ZarrArrayWrapper (
854+ zarr_array ,
855+ is_coordinate = is_coordinate ,
856+ coords_buffer_prototype = self ._coords_buffer_prototype ,
857+ )
858+ )
820859 try_nczarr = self ._mode == "r"
821860 dimensions , attributes = _get_zarr_dims_and_attrs (
822861 zarr_array , DIMENSION_KEY , try_nczarr
@@ -1339,6 +1378,7 @@ def open_zarr(
13391378 use_zarr_fill_value_as_mask = None ,
13401379 chunked_array_type : str | None = None ,
13411380 from_array_kwargs : dict [str , Any ] | None = None ,
1381+ coords_buffer_prototype : Any | None = None ,
13421382 ** kwargs ,
13431383):
13441384 """Load and decode a dataset from a Zarr store.
@@ -1449,6 +1489,12 @@ def open_zarr(
14491489 chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg.
14501490 Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to
14511491 :py:func:`dask.array.from_array`. Experimental API that should not be relied upon.
1492+ coords_buffer_prototype : zarr.buffer.BufferPrototype, optional
1493+ The buffer prototype to use for loading coordinate arrays. Zarr offers control over
1494+ which device's memory buffers are read into. By default, xarray will always load
1495+ *coordinate* buffers into host (CPU) memory, regardless of the global zarr
1496+ configuration. To override this behavior, explicitly pass the buffer prototype
1497+ to use for coordinates here.
14521498
14531499 Returns
14541500 -------
@@ -1492,6 +1538,7 @@ def open_zarr(
14921538 "storage_options" : storage_options ,
14931539 "zarr_version" : zarr_version ,
14941540 "zarr_format" : zarr_format ,
1541+ "coords_buffer_prototype" : coords_buffer_prototype ,
14951542 }
14961543
14971544 ds = open_dataset (
@@ -1564,6 +1611,7 @@ def open_dataset(
15641611 engine = None ,
15651612 use_zarr_fill_value_as_mask = None ,
15661613 cache_members : bool = True ,
1614+ coords_buffer_prototype : Any | None = None ,
15671615 ) -> Dataset :
15681616 filename_or_obj = _normalize_path (filename_or_obj )
15691617 if not store :
@@ -1580,6 +1628,7 @@ def open_dataset(
15801628 use_zarr_fill_value_as_mask = None ,
15811629 zarr_format = zarr_format ,
15821630 cache_members = cache_members ,
1631+ coords_buffer_prototype = coords_buffer_prototype ,
15831632 )
15841633
15851634 store_entrypoint = StoreBackendEntrypoint ()
@@ -1615,6 +1664,7 @@ def open_datatree(
16151664 storage_options = None ,
16161665 zarr_version = None ,
16171666 zarr_format = None ,
1667+ coords_buffer_prototype : Any | None = None ,
16181668 ) -> DataTree :
16191669 filename_or_obj = _normalize_path (filename_or_obj )
16201670 groups_dict = self .open_groups_as_dict (
@@ -1634,6 +1684,7 @@ def open_datatree(
16341684 storage_options = storage_options ,
16351685 zarr_version = zarr_version ,
16361686 zarr_format = zarr_format ,
1687+ coords_buffer_prototype = coords_buffer_prototype ,
16371688 )
16381689
16391690 return datatree_from_dict_with_io_cleanup (groups_dict )
@@ -1657,6 +1708,7 @@ def open_groups_as_dict(
16571708 storage_options = None ,
16581709 zarr_version = None ,
16591710 zarr_format = None ,
1711+ coords_buffer_prototype : Any | None = None ,
16601712 ) -> dict [str , Dataset ]:
16611713 from xarray .core .treenode import NodePath
16621714
@@ -1679,6 +1731,7 @@ def open_groups_as_dict(
16791731 storage_options = storage_options ,
16801732 zarr_version = zarr_version ,
16811733 zarr_format = zarr_format ,
1734+ coords_buffer_prototype = coords_buffer_prototype ,
16821735 )
16831736
16841737 groups_dict = {}
0 commit comments