[python] Remove AxisName.getattr_from from ExperimentAxisQuery (#3557) (#3577)

The `ExperimentAxisQuery` used the `AxisName` enumeration to access attributes with either `obs` or `var` in the name. This PR calls the appropriate functions directly, primarily by pushing the calls one function higher up the call stack (e.g., passing in `self.indexer.by_obs`/`self.indexer.by_var` instead of `AxisName.OBS`/`AxisName.VAR`).

This change allowed `ExperimentAxisQuery._convert_to_ndarray` and `ExperimentAxisQuery._axism_inner_ndarray` to be merged into a single non-member function, and `ExperimentAxisQuery._axisp_inner_sparray` to be replaced with a direct call to `_read_as_csr`.
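For readers skimming the diff below, a minimal, self-contained sketch of the pattern being removed versus the one adopted. The `Indexer` class and its `by_obs`/`by_var` bodies here are illustrative stand-ins, not tiledbsoma's real indexer; only the shape of `getattr_from` and the "pass the bound method directly" idea mirror this change.

```python
# Illustrative sketch only, not code from this commit.
import enum
from typing import Any, List


class AxisName(enum.Enum):
    OBS = "obs"
    VAR = "var"

    def getattr_from(self, source: Any, *, pre: str = "", suf: str = "") -> object:
        # Old pattern: resolve an "obs"/"var" attribute by string concatenation.
        return getattr(source, pre + self.value + suf)


class Indexer:
    # Stand-in methods; the real indexer maps join IDs to positional indices.
    def by_obs(self, keys: List[int]) -> List[int]:
        return keys

    def by_var(self, keys: List[int]) -> List[int]:
        return keys


indexer = Indexer()

# Before: helpers received the axis enum and looked the method up by string.
old_lookup = AxisName.OBS.getattr_from(indexer, pre="by_")

# After: callers pass the bound method (e.g. self.indexer.by_obs) directly,
# so the string-based indirection is no longer needed.
new_lookup = indexer.by_obs

assert old_lookup([1, 2]) == new_lookup([1, 2])
```

Passing the bound method keeps call sites statically checkable and lets the former member helpers become plain functions that take the indexer callable as an argument.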

Co-authored-by: Julia Dark <[email protected]>
github-actions[bot] and jp-dark authored Jan 16, 2025
1 parent 1542d1a commit 7987429
Showing 2 changed files with 51 additions and 101 deletions.
138 changes: 51 additions & 87 deletions apis/python/src/tiledbsoma/_query.py
@@ -17,12 +17,9 @@
Callable,
Dict,
Literal,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
overload,
)

import attrs
@@ -88,30 +85,7 @@ class AxisName(enum.Enum):

@property
def value(self) -> Literal["obs", "var"]:
return super().value # type: ignore[no-any-return]

@overload
def getattr_from(self, __source: _HasObsVar[_T]) -> _T: ...

@overload
def getattr_from(
self, __source: Any, *, pre: Literal[""], suf: Literal[""]
) -> object: ...

@overload
def getattr_from(
self, __source: Any, *, pre: str = ..., suf: str = ...
) -> object: ...

def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
"""Equivalent to ``something.<pre><obs/var><suf>``."""
return getattr(__source, pre + self.value + suf)

def getitem_from(
self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
) -> _T:
"""Equivalent to ``something[pre + "obs"/"var" + suf]``."""
return __source[pre + self.value + suf]
return super().value


@attrs.define
@@ -389,7 +363,7 @@ def obs_scene_ids(self) -> pa.Array:
)

full_table = obs_scene.read(
coords=((AxisName.OBS.getattr_from(self._joinids), slice(None))),
coords=(self._joinids.obs, slice(None)),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
@@ -416,7 +390,7 @@ def var_scene_ids(self) -> pa.Array:
)

full_table = var_scene.read(
coords=((AxisName.VAR.getattr_from(self._joinids), slice(None))),
coords=(self._joinids.var, slice(None)),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
@@ -477,6 +451,8 @@ def to_anndata(
obs_table, var_table = tp.map(
self._read_axis_dataframe,
(AxisName.OBS, AxisName.VAR),
(self._obs_df, self._var_df),
(self._matrix_axis_query.obs, self._matrix_axis_query.var),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
@@ -496,19 +472,43 @@
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
key: tp.submit(
_read_inner_ndarray,
self._get_annotation_layer("obsm", key),
obs_joinids,
self.indexer.by_obs,
)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
key: tp.submit(
_read_inner_ndarray,
self._get_annotation_layer("varm", key),
var_joinids,
self.indexer.by_var,
)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
key: tp.submit(
_read_as_csr,
self._get_annotation_layer("obsp", key),
obs_joinids,
obs_joinids,
self.indexer.by_obs,
self.indexer.by_obs,
)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
key: tp.submit(
_read_as_csr,
self._get_annotation_layer("varp", key),
var_joinids,
var_joinids,
self.indexer.by_var,
self.indexer.by_var,
)
for key in varp_layers
}

@@ -778,15 +778,13 @@ def __exit__(self, *_: Any) -> None:
def _read_axis_dataframe(
self,
axis: AxisName,
axis_df: DataFrame,
axis_query: AxisQuery,
axis_column_names: AxisColumnNames,
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
column_names = axis_column_names.get(axis.value)

axis_df = axis.getattr_from(self, pre="_", suf="_df")
assert isinstance(axis_df, DataFrame)
axis_query = axis.getattr_from(self._matrix_axis_query)

# If we can cache join IDs, prepare to add them to the cache.
joinids_cached = self._joinids._is_cached(axis)
query_columns = column_names
@@ -859,56 +857,6 @@ def _get_annotation_layer(
)
return layer

def _convert_to_ndarray(
self, axis: AxisName, table: pa.Table, n_row: int, n_col: int
) -> npt.NDArray[np.float32]:
indexer = cast(
Callable[[Numpyable], npt.NDArray[np.intp]],
axis.getattr_from(self.indexer, pre="by_"),
)
idx = indexer(table["soma_dim_0"])
z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return z.reshape(n_row, n_col)

def _axisp_inner_sparray(
self,
axis: AxisName,
layer: str,
) -> sp.csr_matrix:
joinids = axis.getattr_from(self._joinids)
indexer = cast(
Callable[[Numpyable], npt.NDArray[np.intp]],
axis.getattr_from(self.indexer, pre="by_"),
)
annotation_name = f"{axis.value}p"
return _read_as_csr(
self._get_annotation_layer(annotation_name, layer),
joinids,
joinids,
indexer,
indexer,
)

def _axism_inner_ndarray(
self,
axis: AxisName,
layer: str,
) -> npt.NDArray[np.float32]:
joinids = axis.getattr_from(self._joinids)
annotation_name = f"{axis.value}m"
table = (
self._get_annotation_layer(annotation_name, layer)
.read((joinids, slice(None)))
.tables()
.concat()
)

n_row = len(joinids)
n_col = len(table["soma_dim_1"].unique())

return self._convert_to_ndarray(axis, table, n_row, n_col)

@property
def _obs_df(self) -> DataFrame:
return self.experiment.obs
@@ -995,6 +943,22 @@ def load_joinids(df: DataFrame, axq: AxisQuery) -> pa.IntegerArray:
return tbl.column("soma_joinid").combine_chunks()


def _read_inner_ndarray(
matrix: SparseNDArray,
joinids: pa.IntegerArray,
indexer: Callable[[Numpyable], npt.NDArray[np.intp]],
) -> npt.NDArray[np.float32]:
table = matrix.read((joinids, slice(None))).tables().concat()

n_row = len(joinids)
n_col = len(table["soma_dim_1"].unique())

idx = indexer(table["soma_dim_0"])
z: npt.NDArray[np.float32] = np.zeros(n_row * n_col, dtype=np.float32)
np.put(z, idx * n_col + table["soma_dim_1"], table["soma_data"])
return z.reshape(n_row, n_col)


def _read_as_csr(
matrix: SparseNDArray,
d0_joinids_arr: pa.IntegerArray,
14 changes: 0 additions & 14 deletions apis/python/tests/test_experiment_query.py
@@ -22,7 +22,6 @@
)
from tiledbsoma._collection import CollectionBase
from tiledbsoma._experiment import Experiment
from tiledbsoma._query import AxisName
from tiledbsoma.experiment_query import X_as_series

from tests._util import raises_no_typeguard
@@ -965,16 +964,3 @@ class IHaveObsVarStuff:
var: int
the_obs_suf: str
the_var_suf: str


def test_axis_helpers() -> None:
thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
assert 1 == AxisName.OBS.getattr_from(thing)
assert 2 == AxisName.VAR.getattr_from(thing)
assert "observe" == AxisName.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == AxisName.VAR.getattr_from(thing, pre="the_", suf="_suf")
ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
assert "erve" == AxisName.OBS.getitem_from(ovdict)
assert "y" == AxisName.VAR.getitem_from(ovdict)
assert "hide" == AxisName.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == AxisName.VAR.getitem_from(ovdict, pre="i_", suf="cure")
