From 23b816ee69480e622368a70af7aa20aa3a068fa3 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 19 Aug 2025 11:55:04 +0000 Subject: [PATCH 1/7] feat: remove `vstack` --- arrayloaders/io/zarr_loader.py | 98 +++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 38 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index b63a404..66f0416 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -59,6 +59,29 @@ async def index_datasets( return await asyncio.gather(*tasks) +def reindex_against_integer_indices( + indices: np.ndarray, chunks: list[InMemoryArray] +) -> tuple[np.ndarray, list[InMemoryArray]]: + upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks])) + lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]]) + reindexed, chunks_reindexed = list( + zip( + *( + (reindexed, c[reindexed - lower]) + for c, upper, lower in zip( + chunks, upper_bounds, lower_bounds, strict=False + ) + if (reindexed := indices[(indices < upper) & (indices >= lower)]).shape[ + 0 + ] + > 0 + ), + strict=False, + ) + ) + return np.concatenate(reindexed), list(chunks_reindexed) + + add_dataset_docstring = """\ Append datasets to this loader. @@ -263,7 +286,7 @@ def iter( ) # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 # we must keep track of the leftover data. - in_memory_data = None + chunks: list[InMemoryArray] = [] in_memory_labels = None in_memory_indices = None for chunk_indices in _batched( @@ -278,9 +301,7 @@ def iter( ] dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) # Fetch the data over slices - chunks: list[InMemoryArray] = zsync.sync( - index_datasets(dataset_index_to_slices, fetch_data) - ) + chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data)) # Accumulate labels labels: None | list[np.ndarray] = None if self.labels is not None: @@ -316,11 +337,10 @@ def iter( for index in dataset_indices ] # Do batch returns, handling leftover data as necessary - mod = sp if isinstance(chunks[0], sp.csr_matrix) else np - in_memory_data = ( - mod.vstack(chunks) - if in_memory_data is None - else mod.vstack([in_memory_data, *chunks]) + vstack = ( + sp.vstack + if isinstance(chunks[0], sp.csr_matrix | sp.csr_array) + else np.vstack ) if self.labels is not None: in_memory_labels = ( @@ -337,39 +357,41 @@ def iter( # Create random indices into in_memory_data and then index into it # If there is "leftover" at the end (see the modulo op), # save it for the next iteration. 
- batch_indices = np.arange(in_memory_data.shape[0]) - if shuffle: - np.random.default_rng().shuffle(batch_indices) - splits = split_given_size(batch_indices, self._batch_size) - for i, s in enumerate(splits): - if s.shape[0] == self._batch_size: - res = [ - in_memory_data[s], - in_memory_labels[s] if self.labels is not None else None, - ] - if self._return_index: - res += [in_memory_indices[s]] - yield tuple(res) - if i == ( - len(splits) - 1 - ): # end of iteration, leftover data needs be kept - if (s.shape[0] % self._batch_size) != 0: - in_memory_data = in_memory_data[s] - if in_memory_labels is not None: - in_memory_labels = in_memory_labels[s] - if in_memory_indices is not None: - in_memory_indices = in_memory_indices[s] - else: - in_memory_data = None - in_memory_labels = None - in_memory_indices = None - if in_memory_data is not None: # handle any leftover data + if self._batch_size != (num_obs := sum(c.shape[0] for c in chunks)): + batch_indices = np.arange(num_obs) + if shuffle: + np.random.default_rng().shuffle(batch_indices) + splits = split_given_size(batch_indices, self._batch_size) + for i, s in enumerate(splits): + s, chunks_reindexed = reindex_against_integer_indices(s, chunks) + if s.shape[0] == self._batch_size: + res = [ + vstack(chunks_reindexed), + in_memory_labels[s] if self.labels is not None else None, + ] + if self._return_index: + res += [in_memory_indices[s]] + yield tuple(res) + if i == ( + len(splits) - 1 + ): # end of iteration, leftover data needs be kept + if (s.shape[0] % self._batch_size) != 0: + chunks = chunks_reindexed + if in_memory_labels is not None: + in_memory_labels = in_memory_labels[s] + if in_memory_indices is not None: + in_memory_indices = in_memory_indices[s] + else: + chunks = [] + in_memory_labels = None + in_memory_indices = None + if len(chunks) > 0: # handle any leftover data res = [ - in_memory_data, + vstack(chunks), in_memory_labels if self.labels is not None else None, ] if self._return_index: - res += [in_memory_indices[s]] + res += [in_memory_indices] yield tuple(res) From 124dfdb878ca3863798c4d3c10bbc47f4dee5a9f Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 19 Aug 2025 12:44:49 +0000 Subject: [PATCH 2/7] fix: remove shape hotpath --- arrayloaders/io/zarr_loader.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index b63a404..baea703 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -82,6 +82,7 @@ class AnnDataManager(Generic[OnDiskArray, InMemoryArray]): _return_index: bool = False _on_add: Callable | None = None _batch_size: int = 1 + _shapes: list[tuple[int, int]] = [] def __init__( self, @@ -100,11 +101,11 @@ def dataset_type(self) -> type[OnDiskArray]: @property def n_obs(self) -> int: - return sum(ds.shape[0] for ds in self.train_datasets) + return sum(shape[0] for shape in self._shapes) @property def n_var(self) -> int: - return self.train_datasets[0].shape[1] + return self._shapes[0][1] def add_anndatas( self, @@ -150,7 +151,7 @@ def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> No ) datasets = self.train_datasets + [dataset] check_var_shapes(datasets) - self._var_size = datasets[0].shape[1] # TODO: joins + self._shapes += [dataset.shape] self.train_datasets = datasets if self.labels is not None: # labels exist self.labels += [obs] @@ -184,9 +185,8 @@ def _get_relative_obs_indices( max_idx = index.stop curr_pos = 0 slices = [] - for idx, array in 
enumerate(self.train_datasets): + for idx, (n_obs, _) in enumerate(self._shapes): array_start = curr_pos - n_obs = array.shape[0] array_end = curr_pos + n_obs start = max(min_idx, array_start) From 171c4729f4cdb2f71da13ecc35d8683fe4d818a9 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 19 Aug 2025 13:06:39 +0000 Subject: [PATCH 3/7] fix: use array --- arrayloaders/io/zarr_loader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index 1e9bd25..08f5744 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -27,7 +27,7 @@ def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array) accepted_on_disk_types = OnDiskArray.__constraints__ -InMemoryArray = TypeVar("InMemoryArray", sp.csr_matrix, np.ndarray) +InMemoryArray = TypeVar("InMemoryArray", sp.csr_array, np.ndarray) def _batched(iterable, n): @@ -339,7 +339,7 @@ def iter( # Do batch returns, handling leftover data as necessary vstack = ( sp.vstack - if isinstance(chunks[0], sp.csr_matrix | sp.csr_array) + if isinstance(chunks[0], sp.csr_array | sp.csr_array) else np.vstack ) if self.labels is not None: @@ -463,7 +463,7 @@ def __init__( raise NotImplementedError( "If you need batch loading that is bigger than the iterated in-memory size, please open an issue." ) - self._dataset_manager: AnnDataManager[ad.abc.CSRDataset, sp.csr_matrix] = ( + self._dataset_manager: AnnDataManager[ad.abc.CSRDataset, sp.csr_array] = ( AnnDataManager( # TODO: https://github.com/scverse/anndata/issues/2021 # on_add=self._cache_update_callback, @@ -643,7 +643,7 @@ async def _fetch_data( self, slices: list[slice], dataset_idx: int, - ) -> sp.csr_matrix: + ) -> sp.csr_array: # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295 # for the inspiration of this function. 
indptr, indices, data = await self._get_sparse_elems(dataset_idx) @@ -669,7 +669,7 @@ async def _fetch_data( offsets = accumulate(chain([indptr_limits[0].start], gaps)) start_indptr = indptr_indices[0] - next(offsets) if len(slices) < 2: # there is only one slice so no need to concatenate - return sp.csr_matrix( + return sp.csr_array( (data_np, indices_np, start_indptr), shape=(start_indptr.shape[0] - 1, self._dataset_manager.n_var), ) @@ -677,7 +677,7 @@ async def _fetch_data( [s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)] ) indptr_np = np.concatenate([start_indptr, end_indptr]) - return sp.csr_matrix( + return sp.csr_array( (data_np, indices_np, indptr_np), shape=(indptr_np.shape[0] - 1, self._dataset_manager.n_var), ) From 4a7518e05bddc403c3947f404fe403f065eb82d0 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 21 Aug 2025 23:09:35 +0000 Subject: [PATCH 4/7] feat: cupy early --- arrayloaders/io/zarr_loader.py | 119 +++++++++++++++++++++++---------- tests/test_dataset_loading.py | 52 ++++++++++++-- 2 files changed, 130 insertions(+), 41 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index 08f5744..a79aa40 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -4,6 +4,7 @@ import math from abc import ABCMeta, abstractmethod from collections import OrderedDict, defaultdict +from dataclasses import dataclass from itertools import accumulate, chain, islice, pairwise from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar, cast @@ -18,6 +19,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator + from types import ModuleType from typing import Self @@ -27,7 +29,15 @@ def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array) accepted_on_disk_types = OnDiskArray.__constraints__ -InMemoryArray = TypeVar("InMemoryArray", sp.csr_array, np.ndarray) + + +@dataclass +class CSRContainer: + elems: tuple[np.ndarray, np.ndarray, np.ndarray] + shape: tuple[int, int] + + +InMemoryArray = TypeVar("InMemoryArray", sp.csr_matrix, np.ndarray) def _batched(iterable, n): @@ -59,29 +69,6 @@ async def index_datasets( return await asyncio.gather(*tasks) -def reindex_against_integer_indices( - indices: np.ndarray, chunks: list[InMemoryArray] -) -> tuple[np.ndarray, list[InMemoryArray]]: - upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks])) - lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]]) - reindexed, chunks_reindexed = list( - zip( - *( - (reindexed, c[reindexed - lower]) - for c, upper, lower in zip( - chunks, upper_bounds, lower_bounds, strict=False - ) - if (reindexed := indices[(indices < upper) & (indices >= lower)]).shape[ - 0 - ] - > 0 - ), - strict=False, - ) - ) - return np.concatenate(reindexed), list(chunks_reindexed) - - add_dataset_docstring = """\ Append datasets to this loader. 
@@ -106,6 +93,8 @@ class AnnDataManager(Generic[OnDiskArray, InMemoryArray]): _on_add: Callable | None = None _batch_size: int = 1 _shapes: list[tuple[int, int]] = [] + sp_module: ModuleType + np_module: ModuleType def __init__( self, @@ -113,10 +102,25 @@ def __init__( on_add: Callable | None = None, return_index: bool = False, batch_size: int = 1, + use_cupy: bool = False, ): self._on_add = on_add self._return_index = return_index self._batch_size = batch_size + if use_cupy: + try: + import cupy as cp + import cupyx.scipy.sparse as cpx # pragma: no cover + + self.sp_module = cpx # pragma: no cover + self.np_module = cp # pragma: no cover + except ImportError: + raise ImportError( + "Cannot find cupy module even though `use_cupy` argument was set to `True`" + ) from None + else: + self.sp_module = sp + self.np_module = np @property def dataset_type(self) -> type[OnDiskArray]: @@ -174,7 +178,7 @@ def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> No ) datasets = self.train_datasets + [dataset] check_var_shapes(datasets) - self._shapes += [dataset.shape] + self._shapes = self._shapes + [dataset.shape] self.train_datasets = datasets if self.labels is not None: # labels exist self.labels += [obs] @@ -258,7 +262,7 @@ def _get_chunks( Returns: A :class:`numpy.ndarray` of chunk ids. """ - chunks = np.array(list(range(math.ceil(self.n_obs / chunk_size)))) + chunks = np.arange(math.ceil(self.n_obs / chunk_size)) if shuffle: worker_handle.shuffle(chunks) @@ -270,7 +274,9 @@ def iter( worker_handle: WorkerHandle, preload_nchunks: int, shuffle: bool, - fetch_data: Callable[[list[slice], int], Awaitable[InMemoryArray]], + fetch_data: Callable[ + [list[slice], int], Awaitable[InMemoryArray | CSRContainer] + ], ) -> Iterator[ tuple[InMemoryArray, None | np.ndarray] | tuple[InMemoryArray, None | np.ndarray, np.ndarray] @@ -302,6 +308,22 @@ def iter( dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) # Fetch the data over slices chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data)) + if any(isinstance(c, CSRContainer) for c in chunks): + chunks = [ + self.sp_module.csr_matrix( + tuple(self.np_module.array(e) for e in c.elems), shape=c.shape + ) + if not isinstance(c, self.sp_module.csr_matrix) + else c + for c in chunks + ] + else: + chunks = [ + self.np_module.array(c) + if not isinstance(c, self.np_module.ndarray) + else c + for c in chunks + ] # Accumulate labels labels: None | list[np.ndarray] = None if self.labels is not None: @@ -338,8 +360,8 @@ def iter( ] # Do batch returns, handling leftover data as necessary vstack = ( - sp.vstack - if isinstance(chunks[0], sp.csr_array | sp.csr_array) + self.sp_module.vstack + if isinstance(chunks[0], self.sp_module.csr_matrix) else np.vstack ) if self.labels is not None: @@ -363,7 +385,9 @@ def iter( np.random.default_rng().shuffle(batch_indices) splits = split_given_size(batch_indices, self._batch_size) for i, s in enumerate(splits): - s, chunks_reindexed = reindex_against_integer_indices(s, chunks) + s, chunks_reindexed = self.reindex_against_integer_indices( + s, chunks + ) if s.shape[0] == self._batch_size: res = [ vstack(chunks_reindexed), @@ -394,6 +418,28 @@ def iter( res += [in_memory_indices] yield tuple(res) + def reindex_against_integer_indices( + self, indices: np.ndarray, chunks: list[InMemoryArray] + ) -> tuple[np.ndarray, list[InMemoryArray]]: + upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks])) + lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]]) 
+ reindexed, chunks_reindexed = list( + zip( + *( + (reindexed, c[self.np_module.asarray(reindexed - lower)]) + for c, upper, lower in zip( + chunks, upper_bounds, lower_bounds, strict=False + ) + if ( + reindexed := indices[(indices < upper) & (indices >= lower)] + ).shape[0] + > 0 + ), + strict=False, + ) + ) + return np.concatenate(reindexed), list(chunks_reindexed) + AnnDataManager.add_datasets.__doc__ = add_dataset_docstring AnnDataManager.add_dataset.__doc__ = add_dataset_docstring @@ -409,6 +455,7 @@ def iter( preload_nchunks: The number of chunks of contiguous array data to fetch, by default 32 shuffle: Whether or not to shuffle the data, by default True return_index: Whether or not to return the index on each iteration, by default False + use_cupy: Whether or not to use cupy for non-io array operations like vstack and indexing. This option entails greater GPU memory usage. """ @@ -451,6 +498,7 @@ def __init__( shuffle: bool = True, return_index: bool = False, batch_size: int = 1, + use_cupy: bool = False, ): check_lt_1( [ @@ -469,6 +517,7 @@ def __init__( # on_add=self._cache_update_callback, return_index=return_index, batch_size=batch_size, + use_cupy=use_cupy, ) ) self._chunk_size = chunk_size @@ -643,7 +692,7 @@ async def _fetch_data( self, slices: list[slice], dataset_idx: int, - ) -> sp.csr_array: + ) -> CSRContainer: # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295 # for the inspiration of this function. indptr, indices, data = await self._get_sparse_elems(dataset_idx) @@ -669,16 +718,16 @@ async def _fetch_data( offsets = accumulate(chain([indptr_limits[0].start], gaps)) start_indptr = indptr_indices[0] - next(offsets) if len(slices) < 2: # there is only one slice so no need to concatenate - return sp.csr_array( - (data_np, indices_np, start_indptr), + return CSRContainer( + elems=(data_np, indices_np, start_indptr), shape=(start_indptr.shape[0] - 1, self._dataset_manager.n_var), ) end_indptr = np.concatenate( [s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)] ) indptr_np = np.concatenate([start_indptr, end_indptr]) - return sp.csr_array( - (data_np, indices_np, indptr_np), + return CSRContainer( + elems=(data_np, indices_np, indptr_np), shape=(indptr_np.shape[0] - 1, self._dataset_manager.n_var), ) diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py index 226b68c..da22d1f 100644 --- a/tests/test_dataset_loading.py +++ b/tests/test_dataset_loading.py @@ -1,6 +1,8 @@ from __future__ import annotations import platform +from importlib.util import find_spec +from types import NoneType from typing import TYPE_CHECKING, TypedDict import anndata as ad @@ -17,6 +19,13 @@ read_lazy_store, ) +try: + from cupy import ndarray as CupyArray + from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix +except ImportError: + CupyCSRMatrix = NoneType + CupyArray = NoneType + if TYPE_CHECKING: from pathlib import Path @@ -91,12 +100,14 @@ def concat(dicts: list[Data]) -> ListData: chunk_size=chunk_size, preload_nchunks=preload_nchunks, dataset_class=dataset_class, - batch_size=batch_size: dataset_class( + batch_size=batch_size, + use_cupy=use_cupy: dataset_class( shuffle=shuffle, chunk_size=chunk_size, preload_nchunks=preload_nchunks, return_index=True, batch_size=batch_size, + use_cupy=use_cupy, ).add_datasets( **concat( [ @@ -109,14 +120,26 @@ def concat(dicts: list[Data]) -> ListData: ] ) ), - 
id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if use_cupy else ''}", # type: ignore[attr-defined] + marks=pytest.mark.skipif( + find_spec("cupy") is None and use_cupy, reason="need cupy installed" + ), ) - for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size in [ + for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size, use_cupy in [ elem + for use_cupy in [True, False] for dataset_class in [ZarrDenseDataset, ZarrSparseDataset] # type: ignore[list-item] for elem in [ - [1, 5, None, dataset_class, None, 1], # singleton chunk size - [5, 1, None, dataset_class, None, 1], # singleton preload + [ + 1, + 5, + None, + dataset_class, + None, + 1, + use_cupy, + ], # singleton chunk size + [5, 1, None, dataset_class, None, 1, use_cupy], # singleton preload [ 10, 5, @@ -124,6 +147,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 5, + use_cupy, ], # batch size divides total in memory size evenly [ 10, @@ -132,6 +156,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 50, + use_cupy, ], # batch size equal to in-memory size loading [ 10, @@ -140,6 +165,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 15, + use_cupy, ], # batch size does not divide in memory size evenly ] ] @@ -174,7 +200,7 @@ def test_store_load_dataset(mock_store: Path, *, shuffle: bool, gen_loader, use_ n_elems += 1 if (is_dask := isinstance(loader, DaskDataset)) else x.shape[0] # Check feature dimension assert x.shape[0 if is_dask else 1] == 100 - batches += [x] + batches += [x.get() if isinstance(x, CupyCSRMatrix | CupyArray) else x] if label is not None: labels += [label] if index is not None: @@ -326,3 +352,17 @@ def test_torch_multiprocess_dataloading_zarr(mock_store, loader, use_zarrs): idxs = np.concatenate(idx_list) assert np.array_equal(x[np.argsort(idxs)], x_ref) + + +@pytest.mark.skipif( + find_spec("cupy") is not None, reason="Can't test for no cupy if cupy is there" +) +def test_no_cupy(): + with pytest.raises(ImportError, match=r"even though `use_cupy` argument"): + ZarrSparseDataset( + chunk_size=10, + preload_nchunks=4, + shuffle=True, + return_index=True, + use_cupy=True, + ) From 946fd658b0c3cfa49a3d9014c882de294f09ca1d Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Fri, 22 Aug 2025 12:09:11 +0000 Subject: [PATCH 5/7] Merge branch 'main' into ig/no_vstack_cupy --- arrayloaders/io/zarr_loader.py | 104 ++++++++++++++++++--------------- pyproject.toml | 1 + tests/test_dataset_loading.py | 35 ++++++----- 3 files changed, 81 insertions(+), 59 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index a79aa40..52c9be6 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -6,6 +6,7 @@ from collections import OrderedDict, defaultdict from dataclasses import dataclass from itertools import accumulate, chain, islice, pairwise +from types import NoneType from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar, cast import anndata as ad @@ -17,6 +18,13 @@ from .utils import WorkerHandle, check_lt_1, check_var_shapes +try: + from cupy import ndarray as CupyArray + from cupyx.scipy.sparse import csr_matrix as 
CupyCSRMatrix # pragma: no cover +except ImportError: + CupyCSRMatrix = NoneType + CupyArray = NoneType + if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterator from types import ModuleType @@ -24,6 +32,8 @@ def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: + if size == 1: + return a return np.split(a, np.arange(size, len(a), size)) @@ -37,7 +47,10 @@ class CSRContainer: shape: tuple[int, int] -InMemoryArray = TypeVar("InMemoryArray", sp.csr_matrix, np.ndarray) +OutputInMemoryArray = TypeVar( + "OutputInMemoryArray", sp.csr_matrix, np.ndarray, CupyCSRMatrix, CupyArray +) +InputInMemoryArray = TypeVar("InputInMemoryArray", CSRContainer, np.ndarray) def _batched(iterable, n): @@ -50,8 +63,8 @@ def _batched(iterable, n): async def index_datasets( dataset_index_to_slices: OrderedDict[int, list[slice]], - fetch_data: Callable[[list[slice], int], Awaitable[InMemoryArray]], -) -> list[InMemoryArray]: + fetch_data: Callable[[list[slice], int], Awaitable[CSRContainer | np.ndarray]], +) -> list[InputInMemoryArray]: """Helper function meant to encapsulate asynchronous calls so that we can use the same event loop as zarr. Args: @@ -86,7 +99,7 @@ async def index_datasets( """ -class AnnDataManager(Generic[OnDiskArray, InMemoryArray]): +class AnnDataManager(Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray]): train_datasets: list[OnDiskArray] = [] labels: list[np.ndarray] | None = None _return_index: bool = False @@ -102,12 +115,12 @@ def __init__( on_add: Callable | None = None, return_index: bool = False, batch_size: int = 1, - use_cupy: bool = False, + preload_to_gpu: bool = False, ): self._on_add = on_add self._return_index = return_index self._batch_size = batch_size - if use_cupy: + if preload_to_gpu: try: import cupy as cp import cupyx.scipy.sparse as cpx # pragma: no cover @@ -116,7 +129,7 @@ def __init__( self.np_module = cp # pragma: no cover except ImportError: raise ImportError( - "Cannot find cupy module even though `use_cupy` argument was set to `True`" + "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`" ) from None else: self.sp_module = sp @@ -274,12 +287,10 @@ def iter( worker_handle: WorkerHandle, preload_nchunks: int, shuffle: bool, - fetch_data: Callable[ - [list[slice], int], Awaitable[InMemoryArray | CSRContainer] - ], + fetch_data: Callable[[list[slice], int], Awaitable[np.ndarray | CSRContainer]], ) -> Iterator[ - tuple[InMemoryArray, None | np.ndarray] - | tuple[InMemoryArray, None | np.ndarray, np.ndarray] + tuple[InputInMemoryArray, None | np.ndarray] + | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] ]: """Iterate over the on-disk csr datasets. @@ -292,9 +303,14 @@ def iter( ) # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 # we must keep track of the leftover data. 
- chunks: list[InMemoryArray] = [] + chunks: list[InputInMemoryArray] = [] in_memory_labels = None in_memory_indices = None + vstack = ( + self.sp_module.vstack + if issubclass(self.dataset_type, ad.abc.CSRDataset) + else np.vstack + ) for chunk_indices in _batched( self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks ): @@ -309,21 +325,16 @@ def iter( # Fetch the data over slices chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data)) if any(isinstance(c, CSRContainer) for c in chunks): - chunks = [ + chunks_converted: list[OutputInMemoryArray] = [ self.sp_module.csr_matrix( - tuple(self.np_module.array(e) for e in c.elems), shape=c.shape + tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape ) - if not isinstance(c, self.sp_module.csr_matrix) + if isinstance(c, CSRContainer) else c for c in chunks ] else: - chunks = [ - self.np_module.array(c) - if not isinstance(c, self.np_module.ndarray) - else c - for c in chunks - ] + chunks_converted = [self.np_module.asarray(c) for c in chunks] # Accumulate labels labels: None | list[np.ndarray] = None if self.labels is not None: @@ -359,11 +370,6 @@ def iter( for index in dataset_indices ] # Do batch returns, handling leftover data as necessary - vstack = ( - self.sp_module.vstack - if isinstance(chunks[0], self.sp_module.csr_matrix) - else np.vstack - ) if self.labels is not None: in_memory_labels = ( np.concatenate(labels) @@ -386,11 +392,13 @@ def iter( splits = split_given_size(batch_indices, self._batch_size) for i, s in enumerate(splits): s, chunks_reindexed = self.reindex_against_integer_indices( - s, chunks + s, chunks_converted ) if s.shape[0] == self._batch_size: res = [ - vstack(chunks_reindexed), + vstack(chunks_reindexed) + if len(chunks_reindexed) > 1 + else chunks_reindexed[0], in_memory_labels[s] if self.labels is not None else None, ] if self._return_index: @@ -419,8 +427,8 @@ def iter( yield tuple(res) def reindex_against_integer_indices( - self, indices: np.ndarray, chunks: list[InMemoryArray] - ) -> tuple[np.ndarray, list[InMemoryArray]]: + self, indices: np.ndarray, chunks: list[OutputInMemoryArray] + ) -> tuple[np.ndarray, list[OutputInMemoryArray]]: upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks])) lower_bounds = np.concatenate([np.array([0]), upper_bounds[:-1]]) reindexed, chunks_reindexed = list( @@ -455,7 +463,7 @@ def reindex_against_integer_indices( preload_nchunks: The number of chunks of contiguous array data to fetch, by default 32 shuffle: Whether or not to shuffle the data, by default True return_index: Whether or not to return the index on each iteration, by default False - use_cupy: Whether or not to use cupy for non-io array operations like vstack and indexing. This option entails greater GPU memory usage. + preload_to_gpu: Whether or not to use cupy for non-io array operations like vstack and indexing. This option entails greater GPU memory usage. 
""" @@ -483,12 +491,16 @@ def __iter__(self): total += gap -class AbstractIterableDataset(Generic[OnDiskArray, InMemoryArray], metaclass=ABCMeta): +class AbstractIterableDataset( + Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray], metaclass=ABCMeta +): _shuffle: bool _preload_nchunks: int _worker_handle: WorkerHandle _chunk_size: int - _dataset_manager: AnnDataManager[OnDiskArray, InMemoryArray] + _dataset_manager: AnnDataManager[ + OnDiskArray, InputInMemoryArray, OutputInMemoryArray + ] def __init__( self, @@ -498,7 +510,7 @@ def __init__( shuffle: bool = True, return_index: bool = False, batch_size: int = 1, - use_cupy: bool = False, + preload_to_gpu: bool = False, ): check_lt_1( [ @@ -511,14 +523,12 @@ def __init__( raise NotImplementedError( "If you need batch loading that is bigger than the iterated in-memory size, please open an issue." ) - self._dataset_manager: AnnDataManager[ad.abc.CSRDataset, sp.csr_array] = ( - AnnDataManager( - # TODO: https://github.com/scverse/anndata/issues/2021 - # on_add=self._cache_update_callback, - return_index=return_index, - batch_size=batch_size, - use_cupy=use_cupy, - ) + self._dataset_manager = AnnDataManager( + # TODO: https://github.com/scverse/anndata/issues/2021 + # on_add=self._cache_update_callback, + return_index=return_index, + batch_size=batch_size, + preload_to_gpu=preload_to_gpu, ) self._chunk_size = chunk_size self._preload_nchunks = preload_nchunks @@ -529,7 +539,9 @@ async def _cache_update_callback(self): pass @abstractmethod - async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> InMemoryArray: + async def _fetch_data( + self, slices: list[slice], dataset_idx: int + ) -> InputInMemoryArray: """Fetch the data for given slices and the arrays representing a dataset on-disk. Args: @@ -573,8 +585,8 @@ def __len__(self) -> int: def __iter__( self, ) -> Iterator[ - tuple[InMemoryArray, None | np.ndarray] - | tuple[InMemoryArray, None | np.ndarray, np.ndarray] + tuple[InputInMemoryArray, None | np.ndarray] + | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] ]: """Iterate over the on-disk datasets. 
diff --git a/pyproject.toml b/pyproject.toml index 61be08c..2d82f46 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dev = [ "pytest-cov", "nbproject_test", ] +gpu = ["cupy-cuda12x"] [tool.pytest.ini_options] testpaths = [ diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py index da22d1f..32c93ae 100644 --- a/tests/test_dataset_loading.py +++ b/tests/test_dataset_loading.py @@ -101,13 +101,13 @@ def concat(dicts: list[Data]) -> ListData: preload_nchunks=preload_nchunks, dataset_class=dataset_class, batch_size=batch_size, - use_cupy=use_cupy: dataset_class( + preload_to_gpu=preload_to_gpu: dataset_class( shuffle=shuffle, chunk_size=chunk_size, preload_nchunks=preload_nchunks, return_index=True, batch_size=batch_size, - use_cupy=use_cupy, + preload_to_gpu=preload_to_gpu, ).add_datasets( **concat( [ @@ -120,14 +120,15 @@ def concat(dicts: list[Data]) -> ListData: ] ) ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if use_cupy else ''}", # type: ignore[attr-defined] + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] marks=pytest.mark.skipif( - find_spec("cupy") is None and use_cupy, reason="need cupy installed" + find_spec("cupy") is None and preload_to_gpu, + reason="need cupy installed", ), ) - for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size, use_cupy in [ + for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size, preload_to_gpu in [ elem - for use_cupy in [True, False] + for preload_to_gpu in [True, False] for dataset_class in [ZarrDenseDataset, ZarrSparseDataset] # type: ignore[list-item] for elem in [ [ @@ -137,9 +138,17 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 1, - use_cupy, + preload_to_gpu, ], # singleton chunk size - [5, 1, None, dataset_class, None, 1, use_cupy], # singleton preload + [ + 5, + 1, + None, + dataset_class, + None, + 1, + preload_to_gpu, + ], # singleton preload [ 10, 5, @@ -147,7 +156,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 5, - use_cupy, + preload_to_gpu, ], # batch size divides total in memory size evenly [ 10, @@ -156,7 +165,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 50, - use_cupy, + preload_to_gpu, ], # batch size equal to in-memory size loading [ 10, @@ -165,7 +174,7 @@ def concat(dicts: list[Data]) -> ListData: dataset_class, None, 15, - use_cupy, + preload_to_gpu, ], # batch size does not divide in memory size evenly ] ] @@ -358,11 +367,11 @@ def test_torch_multiprocess_dataloading_zarr(mock_store, loader, use_zarrs): find_spec("cupy") is not None, reason="Can't test for no cupy if cupy is there" ) def test_no_cupy(): - with pytest.raises(ImportError, match=r"even though `use_cupy` argument"): + with pytest.raises(ImportError, match=r"even though `preload_to_gpu` argument"): ZarrSparseDataset( chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True, - use_cupy=True, + preload_to_gpu=True, ) From bcb601677a0d58ae89b20faa1ad3f3db8f9c075e Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Fri, 22 Aug 2025 12:21:09 +0000 Subject: [PATCH 6/7] fix: oops, actually yield properly when batch size matches in-memory --- arrayloaders/io/zarr_loader.py | 19 
+++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index 52c9be6..997f9a3 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -417,14 +417,17 @@ def iter( chunks = [] in_memory_labels = None in_memory_indices = None - if len(chunks) > 0: # handle any leftover data - res = [ - vstack(chunks), - in_memory_labels if self.labels is not None else None, - ] - if self._return_index: - res += [in_memory_indices] - yield tuple(res) + elif len(chunks_converted) > 0: # handle any leftover data + res = [ + vstack(chunks_converted), + in_memory_labels if self.labels is not None else None, + ] + if self._return_index: + res += [in_memory_indices] + yield tuple(res) + chunks = [] + in_memory_labels = None + in_memory_indices = None def reindex_against_integer_indices( self, indices: np.ndarray, chunks: list[OutputInMemoryArray] From c446e63f7bbc208d64af01738c571e2ac74d963f Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Fri, 22 Aug 2025 12:43:46 +0000 Subject: [PATCH 7/7] fix: properly handle leftovers --- arrayloaders/io/zarr_loader.py | 43 +++++++++++++++++++++++----------- tests/test_dataset_loading.py | 10 +++++--- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py index 997f9a3..fc852a7 100644 --- a/arrayloaders/io/zarr_loader.py +++ b/arrayloaders/io/zarr_loader.py @@ -324,17 +324,8 @@ def iter( dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) # Fetch the data over slices chunks += zsync.sync(index_datasets(dataset_index_to_slices, fetch_data)) - if any(isinstance(c, CSRContainer) for c in chunks): - chunks_converted: list[OutputInMemoryArray] = [ - self.sp_module.csr_matrix( - tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape - ) - if isinstance(c, CSRContainer) - else c - for c in chunks - ] - else: - chunks_converted = [self.np_module.asarray(c) for c in chunks] + chunks_converted = self._to_output_array(chunks) + # Accumulate labels labels: None | list[np.ndarray] = None if self.labels is not None: @@ -391,7 +382,7 @@ def iter( np.random.default_rng().shuffle(batch_indices) splits = split_given_size(batch_indices, self._batch_size) for i, s in enumerate(splits): - s, chunks_reindexed = self.reindex_against_integer_indices( + s, chunks_reindexed = self._reindex_against_integer_indices( s, chunks_converted ) if s.shape[0] == self._batch_size: @@ -417,7 +408,7 @@ def iter( chunks = [] in_memory_labels = None in_memory_indices = None - elif len(chunks_converted) > 0: # handle any leftover data + elif len(chunks_converted) > 0: # handle batch size matches in-memory res = [ vstack(chunks_converted), in_memory_labels if self.labels is not None else None, @@ -428,8 +419,32 @@ def iter( chunks = [] in_memory_labels = None in_memory_indices = None + if len(chunks) > 0: # handle leftover data + res = [ + vstack(self._to_output_array(chunks)), + np.asarray(in_memory_labels) if self.labels is not None else None, + ] + if self._return_index: + res += [np.asarray(in_memory_indices)] + yield tuple(res) + + def _to_output_array( + self, chunks: list[InputInMemoryArray | OutputInMemoryArray] + ) -> list[OutputInMemoryArray]: + if any(isinstance(c, CSRContainer) for c in chunks): + return [ + self.sp_module.csr_matrix( + tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape + ) + if isinstance(c, CSRContainer) + else c + for c in chunks + ] + elif 
any(isinstance(c, np.ndarray) for c in chunks): + return [self.np_module.asarray(c) for c in chunks] + return chunks - def reindex_against_integer_indices( + def _reindex_against_integer_indices( self, indices: np.ndarray, chunks: list[OutputInMemoryArray] ) -> tuple[np.ndarray, list[OutputInMemoryArray]]: upper_bounds = np.cumsum(np.array([c.shape[0] for c in chunks])) diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py index 32c93ae..56d48ca 100644 --- a/tests/test_dataset_loading.py +++ b/tests/test_dataset_loading.py @@ -173,7 +173,7 @@ def concat(dicts: list[Data]) -> ListData: None, dataset_class, None, - 15, + 14, preload_to_gpu, ], # batch size does not divide in memory size evenly ] @@ -192,7 +192,8 @@ def test_store_load_dataset(mock_store: Path, *, shuffle: bool, gen_loader, use_ adata = read_lazy_store(mock_store, obs_columns=["label"]) loader = gen_loader(mock_store, shuffle, use_zarrs) - is_dense = isinstance(loader, ZarrDenseDataset | DaskDataset) + is_dask = isinstance(loader, DaskDataset) + is_dense = isinstance(loader, ZarrDenseDataset) or is_dask n_elems = 0 batches = [] labels = [] @@ -223,7 +224,10 @@ def test_store_load_dataset(mock_store: Path, *, shuffle: bool, gen_loader, use_ np.testing.assert_allclose(stacked, expected_data) if len(labels) > 0: expected_labels = adata.obs["label"] - np.testing.assert_allclose(np.array(labels).ravel(), expected_labels) + np.testing.assert_allclose( + (np.array(labels) if is_dask else np.concatenate(labels)).ravel(), + expected_labels, + ) else: if len(indices) > 0: indices = np.concatenate(indices).ravel()