Skip to content

sparse array support #3563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3563.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for {class}`scipy.sparse.csr_array` and {class}`scipy.sparse.csc_array` {smaller}`P Angerer`
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,11 @@ required-imports = [ "from __future__ import annotations" ]
"pandas.api.types.is_categorical_dtype".msg = "Use isinstance(s.dtype, CategoricalDtype) instead"
"pandas.value_counts".msg = "Use pd.Series(a).value_counts() instead"
"scipy.sparse.spmatrix".msg = "Use _compat.SpBase instead"
"scipy.sparse.sparray".msg = "Use _compat.SpBase instead"
"scipy.sparse.csr_matrix".msg = "Use _compat.CSRBase or _compat.CSBase for typing/type checks and add `# noqa: TID251` when constructing"
"scipy.sparse.csc_matrix".msg = "Use _compat.CSCBase or _compat.CSBase for typing/type checks and add `# noqa: TID251` when constructing"
"scipy.sparse.csr_array".msg = "Use _compat.CSRBase or _compat.CSBase for typing/type checks and add `# noqa: TID251` when constructing"
"scipy.sparse.csc_array".msg = "Use _compat.CSCBase or _compat.CSBase for typing/type checks and add `# noqa: TID251` when constructing"
"scipy.sparse.issparse".msg = "Use isinstance(_, _compat.CSBase) or isinstance(_, _compat.SpBase) instead"
"legacy_api_wrap.legacy_api".msg = "Use scanpy._compat.old_positionals instead"
"numpy.bool".msg = "Use `np.bool_` instead for numpy>=1.24<2 compatibility"
Expand Down
45 changes: 27 additions & 18 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,39 @@
from importlib.metadata import PackageMetadata


__all__ = [
"CSBase",
"CSCBase",
"CSRBase",
"DaskArray",
"SpBase",
"ZappyArray",
"_numba_threading_layer",
"deprecated",
"fullname",
"njit",
"old_positionals",
"pkg_metadata",
"pkg_version",
]


P = ParamSpec("P")
R = TypeVar("R")


SpBase = sparse.spmatrix | sparse.sparray # noqa: TID251
"""Only use when you directly convert it to a known subclass."""

_CSArray = sparse.csr_array | sparse.csc_array # noqa: TID251
"""Only use if you want to specially handle arrays as opposed to matrices."""

_CSMatrix = sparse.csr_matrix | sparse.csc_matrix # noqa: TID251
"""Only use if you want to specially handle matrices as opposed to arrays"""
"""Only use if you want to specially handle matrices as opposed to arrays."""

CSRBase = sparse.csr_matrix # noqa: TID251
CSCBase = sparse.csc_matrix # noqa: TID251
SpBase = sparse.spmatrix # noqa: TID251
CSBase = _CSMatrix
CSRBase = sparse.csr_matrix | sparse.csr_array # noqa: TID251
CSCBase = sparse.csc_matrix | sparse.csc_array # noqa: TID251
CSBase = _CSArray | _CSMatrix


if TYPE_CHECKING:
Expand All @@ -44,19 +66,6 @@
ZappyArray.__module__ = "zappy.base"


__all__ = [
"DaskArray",
"ZappyArray",
"_numba_threading_layer",
"deprecated",
"fullname",
"njit",
"old_positionals",
"pkg_metadata",
"pkg_version",
]


def fullname(typ: type) -> str:
module = typ.__module__
name = typ.__qualname__
Expand Down
24 changes: 16 additions & 8 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@
import numpy as np
from anndata import __version__ as anndata_version
from packaging.version import Version
from scipy import sparse

from .. import logging as logg
from .._compat import CSBase, DaskArray, _CSMatrix
from .._compat import CSBase, DaskArray, _CSArray, _CSMatrix, pkg_version
from .._settings import settings
from .compute.is_constant import is_constant # noqa: F401

Expand Down Expand Up @@ -639,9 +638,7 @@
if out is not None:
X.data = new_data_op(X)
return X
return sparse.csr_matrix( # noqa: TID251
(new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape
)
return type(X)((new_data_op(X), indices.copy(), indptr.copy()), shape=X.shape)
transposed = X.T
return axis_mul_or_truediv(
transposed,
Expand Down Expand Up @@ -722,9 +719,20 @@
return np.count_nonzero(X, axis=axis)


@axis_nnz.register(CSBase)
def _(X: CSBase, axis: Literal[0, 1]) -> np.ndarray:
return X.getnnz(axis=axis)
if pkg_version("scipy") >= Version("1.15"):
# newer scipy versions support the `axis` argument for count_nonzero
@axis_nnz.register(CSBase)
def _(X: CSBase, axis: Literal[0, 1]) -> np.ndarray:
return X.count_nonzero(axis=axis)
else:
# older scipy versions don't support `count_nonzero(axis=...)` and offer no public per-axis nnz for sparse arrays, so fall back to `getnnz`
@axis_nnz.register(CSBase)
def _(X: CSBase, axis: Literal[0, 1]) -> np.ndarray:
    """Return the per-axis count of stored entries of a CSR/CSC container.

    Fallback used when ``count_nonzero(axis=...)`` is unavailable
    (scipy < 1.15, see the surrounding version check).

    Parameters
    ----------
    X
        A compressed sparse row/column matrix or array.
    axis
        Axis along which to count (``0`` or ``1``).

    Returns
    -------
    The result of ``getnnz(axis=axis)`` on a sparse *matrix* view of ``X``.
    """
    if isinstance(X, _CSArray):
        # `getnnz` is only available on sparse matrices, not on sparse arrays,
        # so re-wrap the array as a matrix of the same format first.
        # Re-wrapping as `csr_array`/`csc_array` would be a no-op and leave
        # `X` without a `getnnz` method.
        # NOTE(review): Codecov flagged the statements in this branch as not
        # covered by tests.
        from scipy.sparse import csc_matrix, csr_matrix  # noqa: TID251

        X = (csr_matrix if X.format == "csr" else csc_matrix)(X)
    return X.getnnz(axis=axis)


@axis_nnz.register(DaskArray)
Expand Down
7 changes: 5 additions & 2 deletions src/scanpy/metrics/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy as np
import pandas as pd
from scipy import sparse

from .._compat import CSRBase, DaskArray, SpBase, fullname

Expand Down Expand Up @@ -108,7 +107,11 @@ def _(

@_resolve_vals.register(SpBase)
def _(val: SpBase) -> CSRBase:
return sparse.csr_matrix(val) # noqa: TID251
if TYPE_CHECKING:
from scipy.sparse._base import _spbase

assert isinstance(val, _spbase)
return val.tocsr()


@_resolve_vals.register(pd.DataFrame)
Expand Down
9 changes: 2 additions & 7 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy import sparse

from .. import logging as logg
from .._compat import CSBase, CSRBase, DaskArray, old_positionals
from .._compat import CSBase, DaskArray, old_positionals
from .._settings import Verbosity, settings
from .._utils import check_nonnegative_integers, sanitize_anndata
from ..get import _get_obs_rep
Expand Down Expand Up @@ -103,11 +102,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
vmax = np.sqrt(N)
clip_val = reg_std * vmax + mean
if isinstance(data_batch, CSBase):
if isinstance(data_batch, CSRBase):
batch_counts = data_batch
else:
batch_counts = sparse.csr_matrix(data_batch) # noqa: TID251

batch_counts = data_batch.tocsr()
squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped(
batch_counts.indices,
batch_counts.data,
Expand Down
12 changes: 7 additions & 5 deletions src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def describe_var(
use_raw: bool = False,
inplace: bool = False,
log1p: bool = True,
X: CSBase | sparse.coo_matrix | np.ndarray | None = None,
X: CSBase | np.ndarray | None = None,
) -> pd.DataFrame | None:
"""Describe variables of anndata.

Expand Down Expand Up @@ -314,7 +314,9 @@ def calculate_qc_metrics(
return obs_metrics, var_metrics


def top_proportions(mtx: np.ndarray | CSBase | sparse.coo_matrix, n: int):
def top_proportions(
mtx: np.ndarray | CSBase | sparse.coo_matrix | sparse.coo_array, n: int
):
"""Calculate cumulative proportions of top expressed genes.

Parameters
Expand All @@ -327,9 +329,9 @@ def top_proportions(mtx: np.ndarray | CSBase | sparse.coo_matrix, n: int):
expressed gene.

"""
if isinstance(mtx, CSBase | sparse.coo_matrix):
if isinstance(mtx, CSBase | sparse.coo_matrix | sparse.coo_array):
if not isinstance(mtx, CSRBase):
mtx = sparse.csr_matrix(mtx) # noqa: TID251
mtx = mtx.tocsr()
# Allowing numba to do more
return top_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(n))
else:
Expand Down Expand Up @@ -427,7 +429,7 @@ def _(mtx: DaskArray, ns: Collection[int]) -> DaskArray:
@check_ns
def _(mtx: CSBase | sparse.coo_matrix, ns: Collection[int]) -> DaskArray:
if not isinstance(mtx, CSRBase):
mtx = sparse.csr_matrix(mtx) # noqa: TID251
mtx = mtx.tocsr()
return top_segment_proportions_sparse_csr(mtx.data, mtx.indptr, np.array(ns))


Expand Down
5 changes: 2 additions & 3 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np
from anndata import AnnData
from pandas.api.types import CategoricalDtype
from scipy import sparse
from sklearn.utils import check_array, sparsefuncs

from .. import logging as logg
Expand Down Expand Up @@ -1051,7 +1050,7 @@ def _downsample_per_cell(
if isinstance(X, CSBase):
original_type = type(X)
if not isinstance(X, CSRBase):
X = sparse.csr_matrix(X) # noqa: TID251
X = X.tocsr()
totals = np.ravel(axis_sum(X, axis=1)) # Faster for csr matrix
under_target = np.nonzero(totals > counts_per_cell)[0]
rows = np.split(X.data, X.indptr[1:-1])
Expand Down Expand Up @@ -1096,7 +1095,7 @@ def _downsample_total_counts(
if isinstance(X, CSBase):
original_type = type(X)
if not isinstance(X, CSRBase):
X = sparse.csr_matrix(X) # noqa: TID251
X = X.tocsr()
_downsample_array(
X.data,
total_counts,
Expand Down
13 changes: 11 additions & 2 deletions src/testing/scanpy/_pytest/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

from importlib.metadata import version
from typing import TYPE_CHECKING

import pytest
from anndata.tests.helpers import asarray
from packaging.version import Version
from scipy import sparse

from .._helpers import (
Expand All @@ -21,6 +23,12 @@
from _pytest.mark.structures import ParameterSet


skipif_no_sparray = pytest.mark.skipif(
Version(version("anndata")) < Version("0.11"),
reason="scipy cs{rc}_array not supported in anndata<0.11",
)


def param_with(
at: ParameterSet,
transform: Callable[..., Iterable[Any]] = lambda x: (x,),
Expand All @@ -39,8 +47,9 @@ def param_with(
] = {
("mem", "dense"): (pytest.param(asarray, id="numpy_ndarray"),),
("mem", "sparse"): (
pytest.param(sparse.csr_matrix, id="scipy_csr"), # noqa: TID251
pytest.param(sparse.csc_matrix, id="scipy_csc"), # noqa: TID251
pytest.param(sparse.csr_matrix, id="scipy_csr_mat"), # noqa: TID251
pytest.param(sparse.csc_matrix, id="scipy_csc_mat"), # noqa: TID251
pytest.param(sparse.csr_array, id="scipy_csr_arr", marks=[skipif_no_sparray]), # noqa: TID251
),
("dask", "dense"): (
pytest.param(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_qc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_qc_metrics(adata_prepared: AnnData):
else adata_prepared.X
)
max_X = X.max(axis=0)
if isinstance(max_X, sparse.coo_matrix):
if isinstance(max_X, sparse.coo_matrix | sparse.coo_array):
max_X = max_X.toarray()
elif isinstance(max_X, DaskArray):
max_X = max_X.compute()
Expand Down
Loading