diff --git a/intake_xarray/base.py b/intake_xarray/base.py index 63cf85a..8595baa 100644 --- a/intake_xarray/base.py +++ b/intake_xarray/base.py @@ -1,3 +1,7 @@ +import json + +import dask.array + from . import __version__ from intake.source.base import DataSource, Schema @@ -24,7 +28,14 @@ def _get_schema(self): 'coords': tuple(self._ds.coords.keys()), } if getattr(self, 'on_server', False): - metadata['internal'] = serialize_zarr_ds(self._ds) + serialized = serialize_zarr_ds(self._ds) + metadata['internal'] = serialized + # The zarr serialization imposes a certain chunking, which will + # be reflected in the xarray.Dataset object constructed on the + # client side. We need to use that same chunking here on the + # server side. Extract it from the serialized zarr metadata. + self._chunks = {k.rsplit('/', 1)[0]: json.loads(v.decode())['chunks'] + for k, v in serialized.items() if k.endswith('/.zarray')} metadata.update(self._ds.attrs) self._schema = Schema( datashape=None, @@ -52,17 +63,16 @@ def read_partition(self, i): if not isinstance(i, (tuple, list)): raise TypeError('For Xarray sources, must specify partition as ' 'tuple') - if isinstance(i, list): - i = tuple(i) - if hasattr(self._ds, 'variables') or i[0] in self._ds.coords: - arr = self._ds[i[0]].data - i = i[1:] + variable, *part = i + part = tuple(part) + if hasattr(self._ds, 'variables') or variable in self._ds.coords: + arr = self._ds[variable].data else: arr = self._ds.data if isinstance(arr, np.ndarray): - return arr - # dask array - return arr.blocks[i].compute() + # Make a dask.array so that we can return the appropriate block. + arr = dask.array.from_array(arr, chunks=self._chunks[variable]) + return arr.blocks[part].compute() def to_dask(self): """Return xarray object where variables are dask arrays"""