Fix presence matrix calculation (and derived values) (#1320)
* start

* Ensure some zero values in test data for builder

* Fix n_measured_var

* Fix calculation of n_measured_obs

* More helpful comment

* Minor cleanup

* Add validation check to compare presence matrix with original h5ad

* Simplify data passed to create presence matrix

* Remove unused PresenceResult attributes and types

* Fix typing

* Fix validation for presence matrix

* Fix import from earlier anndata
ivirshup authored Jan 30, 2025
1 parent fbb00fe commit 50fea93
Showing 5 changed files with 70 additions and 52 deletions.
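
In short: the old builder derived presence (and the downstream `n_measured_vars` / `n_measured_obs` statistics) from a per-feature `nnz > 0` mask over each dataset's counts, so a feature that appears in a dataset's `var` axis but has all-zero counts was silently dropped. The fix treats every feature in a dataset's `var` axis as measured. A toy sketch of the before/after semantics (illustrative only, not builder code):

```python
import numpy as np
import pandas as pd

# Hypothetical per-dataset var stats: one row per feature in the dataset's h5ad,
# indexed by global var soma_joinid.
var_stats = pd.DataFrame({"nnz": [3, 0, 7]}, index=[10, 11, 12])
n_obs = 5

# Old (buggy): a feature with all-zero counts (joinid 11) was treated as unmeasured.
old_n_measured_vars = (var_stats["nnz"] > 0).sum()  # 2

# New: every feature present in var was measured, regardless of its counts.
new_n_measured_vars = var_stats.shape[0]            # 3
var_stats["n_measured_obs"] = n_obs                 # each measured var saw all n_obs cells
presence_cols = var_stats.index.to_numpy()          # [10, 11, 12]
```
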
@@ -52,10 +52,8 @@

 @attrs.define
 class PresenceResult:
     dataset_id: str
     dataset_soma_joinid: int
     eb_name: str
-    data: npt.NDArray[np.bool_]
     cols: npt.NDArray[np.int64]


@@ -67,10 +65,6 @@ class AxisStats:
     var_stats: pd.DataFrame


-AccumulateXResult = tuple[PresenceResult, AxisStats]
-AccumulateXResults = Sequence[AccumulateXResult]
-
-
 def _assert_open_for_write(obj: somacore.SOMAObject | None) -> None:
     assert obj is not None
     assert obj.exists(obj.uri)
@@ -132,7 +126,7 @@ def __init__(self, specification: ExperimentSpecification):
         self.experiment: soma.Experiment | None = None  # initialized in create()
         self.experiment_uri: str | None = None  # initialized in create()
         self.global_var_joinids: pd.DataFrame | None = None
-        self.presence: dict[int, tuple[npt.NDArray[np.bool_], npt.NDArray[np.int64]]] = {}
+        self.presence: dict[int, npt.NDArray[np.int64]] = {}

     @property
     def name(self) -> str:
@@ -242,9 +236,8 @@ def populate_presence_matrix(self, datasets: list[Dataset]) -> None:

         # LIL is fast way to create spmatrix
         pm = sparse.lil_matrix((max_dataset_joinid + 1, self.n_var), dtype=bool)
-        for dataset_joinid, presence in self.presence.items():
-            data, cols = presence
-            pm[dataset_joinid, cols] = data
+        for dataset_joinid, cols in self.presence.items():
+            pm[dataset_joinid, cols] = 1

         pm = pm.tocoo()
         pm.eliminate_zeros()
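
Since `PresenceResult` now carries only the column joinids, the per-experiment `presence` dict maps dataset joinid directly to an array of var joinids, and the LIL write sets those entries to `True`. A runnable sketch of the same pattern (toy shapes and hypothetical joinids):

```python
import numpy as np
from scipy import sparse

n_var = 6
presence = {0: np.array([0, 2, 3]), 2: np.array([1, 2])}  # dataset joinid -> var joinids
max_dataset_joinid = max(presence)

# LIL supports fast row-wise assignment while the matrix is being built.
pm = sparse.lil_matrix((max_dataset_joinid + 1, n_var), dtype=bool)
for dataset_joinid, cols in presence.items():
    pm[dataset_joinid, cols] = 1  # every listed var is present; no bool mask needed

pm = pm.tocoo()       # convert once construction is done
pm.eliminate_zeros()  # no-op here, but keeps the "stored values are True" invariant
print(pm.toarray())
```
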
@@ -457,14 +450,12 @@ def compute_X_file_stats(

     obs_stats = res["obs_stats"]
     var_stats = res["var_stats"]
-    obs_stats["n_measured_vars"] = (var_stats.nnz > 0).sum()
-    var_stats.loc[var_stats.nnz > 0, "n_measured_obs"] = n_obs
+    obs_stats["n_measured_vars"] = var_stats.shape[0]
+    var_stats["n_measured_obs"] = n_obs
     res["presence"].append(
         PresenceResult(
             dataset_id,
             dataset_soma_joinid,
             eb_name,
-            (var_stats.nnz > 0).to_numpy(),
             var_stats.index.to_numpy(),
         ),
     )
@@ -713,10 +704,7 @@ def populate_X_layers(

     for presence in eb_summary["presence"]:
         assert presence.eb_name == eb.name
-        eb.presence[presence.dataset_soma_joinid] = (
-            presence.data,
-            presence.cols,
-        )
+        eb.presence[presence.dataset_soma_joinid] = presence.cols


class SummaryStats(TypedDict):
@@ -31,7 +31,7 @@ def get_obs_stats(
             "raw_variance_nnz": raw_variance_nnz.astype(
                 CENSUS_OBS_TABLE_SPEC.field("raw_variance_nnz").to_pandas_dtype()
             ),
-            "n_measured_vars": -1,  # placeholder
+            "n_measured_vars": -1,  # handled on dataset level in compute_X_file_stats
         }
     )
     assert len(obs_stats) == raw_X.shape[0]
@@ -53,7 +53,7 @@ def get_var_stats(
     var_stats = pd.DataFrame(
         data={
             "nnz": nnz.astype(CENSUS_VAR_TABLE_SPEC.field("nnz").to_pandas_dtype()),
-            "n_measured_obs": 0,  # placeholder
+            "n_measured_obs": 0,  # handled on dataset level in compute_X_file_stats
         }
     )
     assert len(var_stats) == raw_X.shape[1]
@@ -9,7 +9,7 @@
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Any, Self, TypeVar
+from typing import Any, Self, TypeVar, cast

 import dask
 import numpy as np
@@ -533,7 +533,7 @@ def _validate_X_layers_has_unique_coords(


 def validate_X_layers_presence(
-    soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification]
+    soma_path: str, datasets: list[Dataset], experiment_specifications: list[ExperimentSpecification], assets_path: str
 ) -> Delayed[bool]:
     """Validate that the presence matrix accurately summarizes X[raw] for each experiment.
@@ -543,6 +543,15 @@
     3. Presence mask per dataset is correct for each dataset
     """

+    def _read_var_names(path: str) -> npt.NDArray[np.object_]:
+        import h5py
+        from anndata.experimental import read_elem
+
+        with h5py.File(path) as f:
+            index_key = f["var"].attrs["_index"]
+            var_names = read_elem(f["var"][index_key])
+        return cast(npt.NDArray[np.object_], var_names)
+
     @logit(logger)
     def _validate_X_layers_presence_general(experiment_specifications: list[ExperimentSpecification]) -> bool:
         for es in experiment_specifications:
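
The final commit in the list ("Fix import from earlier anndata") is why `_read_var_names` imports `read_elem` from `anndata.experimental`: that is where it lived in earlier anndata releases. If compatibility with both old and new releases were needed, a version-tolerant import might look like the sketch below (the `anndata.io` location in newer releases is an assumption here):

```python
try:
    from anndata.io import read_elem  # newer anndata (assumed location)
except ImportError:
    from anndata.experimental import read_elem  # earlier anndata
```
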
@@ -570,29 +579,29 @@ def _validate_X_layers_presence_general(experiment_specifications: list[ExperimentSpecification]) -> bool:

     @logit(logger, msg="{0.dataset_id}")
     def _validate_X_layers_presence(
-        dataset: Dataset, experiment_specifications: list[ExperimentSpecification], soma_path: str
+        dataset: Dataset,
+        experiment_specifications: list[ExperimentSpecification],
+        soma_path: str,
+        assets_path: str,
     ) -> bool:
         """For a given dataset and experiment, confirm that the presence matrix matches contents of X[raw]."""
         for es in experiment_specifications:
             with open_experiment(soma_path, es) as exp:
                 obs_df = (
                     exp.obs.read(
-                        value_filter=f"dataset_id == '{dataset.soma_joinid}'",
+                        value_filter=f"dataset_id == '{dataset.dataset_id}'",
                         column_names=["soma_joinid", "n_measured_vars"],
                     )
                     .concat()
                     .to_pandas()
                 )
                 if len(obs_df) > 0:  # skip empty experiments
-                    X_raw = exp.ms[MEASUREMENT_RNA_NAME].X["raw"]
-
-                    presence_accumulator = np.zeros((X_raw.shape[1]), dtype=np.bool_)
-                    for block, _ in (
-                        X_raw.read(coords=(obs_df.soma_joinids.to_numpy(), slice(None)))
-                        .blockwise(axis=0, size=2**20, eager=False, reindex_disable_on_axis=[0, 1])
-                        .tables()
-                    ):
-                        presence_accumulator[block["soma_dim_1"].to_numpy()] = 1
+                    feature_ids = pd.Index(
+                        exp.ms[MEASUREMENT_RNA_NAME]
+                        .var.read(column_names=["feature_id"])
+                        .concat()
+                        .to_pandas()["feature_id"]
+                    )

                     presence = (
                         exp.ms[MEASUREMENT_RNA_NAME][FEATURE_DATASET_PRESENCE_MATRIX_NAME]
@@ -601,17 +610,22 @@ def _validate_X_layers_presence(
                         .concat()
                     )

-                    assert np.array_equal(presence_accumulator, presence), "Presence value does not match X[raw]"
+                    # Get soma_joinids for features in the original h5ad
+                    orig_feature_ids = _read_var_names(f"{assets_path}/{dataset.dataset_h5ad_path}")
+                    orig_indices = np.sort(feature_ids.get_indexer(feature_ids.intersection(orig_feature_ids)))

-                    assert (
-                        obs_df.n_measured_vars.to_numpy() == presence_accumulator.sum()
-                    ).all(), f"{es.name}:{dataset.dataset_id} obs.n_measured_vars incorrect."
+                    np.testing.assert_array_equal(presence["soma_dim_1"], orig_indices)

         return True

     check_presence_values = (
         dask.bag.from_sequence(datasets, partition_size=8)
-        .map(_validate_X_layers_presence, soma_path=soma_path, experiment_specifications=experiment_specifications)
+        .map(
+            _validate_X_layers_presence,
+            soma_path=soma_path,
+            experiment_specifications=experiment_specifications,
+            assets_path=assets_path,
+        )
         .reduction(all, all)
         .to_delayed()
     )
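
The rewritten check no longer scans X[raw]. Instead, the presence row's stored coordinates must equal the positions, within the experiment's feature index, of the var names read straight from the dataset's original h5ad. A toy version of that comparison, with hypothetical feature ids:

```python
import numpy as np
import pandas as pd

feature_ids = pd.Index(["ENSG01", "ENSG02", "ENSG03", "ENSG04"])  # experiment var order
orig_feature_ids = np.array(["ENSG04", "ENSG02"])                 # from the dataset's h5ad

# Positions of the dataset's features within the experiment's var axis, sorted.
expected = np.sort(feature_ids.get_indexer(feature_ids.intersection(orig_feature_ids)))

stored = np.array([1, 3])  # hypothetical soma_dim_1 coords of the presence row
np.testing.assert_array_equal(stored, expected)  # passes: both are [1, 3]
```
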
@@ -968,9 +982,14 @@ def validate_internal_consistency(
     """
     datasets_df["presence_sum_var_axis"] = presence.sum(axis=1).A1
     tmp = obs.merge(datasets_df, left_on="dataset_id", right_on="dataset_id")
-    assert (
-        tmp.n_measured_vars == tmp.presence_sum_var_axis
-    ).all(), f"{eb.name}: obs.n_measured_vars does not match presence matrix."
+    try:
+        np.testing.assert_array_equal(
+            tmp["n_measured_vars"],
+            tmp["presence_sum_var_axis"],
+        )
+    except AssertionError as e:
+        e.add_note(f"{eb.name}: obs.n_measured_vars does not match presence matrix.")
+        raise
     del tmp

     # Assertion 3 - var.n_measured_obs is consistent with presence matrix
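
Compared to a bare `assert`, `np.testing.assert_array_equal` reports exactly which elements differ, and `AssertionError.add_note` (Python 3.11+) attaches the experiment name without rebuilding the message. A minimal sketch of the same pattern:

```python
import numpy as np

def check(name: str, got: np.ndarray, expected: np.ndarray) -> None:
    try:
        np.testing.assert_array_equal(got, expected)  # detailed mismatch report
    except AssertionError as e:
        e.add_note(f"{name}: arrays disagree")  # add context, keep the report
        raise

check("demo", np.array([1, 2]), np.array([1, 2]))  # passes silently
```
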
@@ -1091,7 +1110,7 @@ def validate_soma(args: CensusBuildArgs, client: dask.distributed.Client) -> das
                 dask.delayed(validate_X_layers_schema)(soma_path, experiment_specifications, eb_info),
                 validate_X_layers_normalized(soma_path, experiment_specifications),
                 validate_X_layers_has_unique_coords(soma_path, experiment_specifications),
-                validate_X_layers_presence(soma_path, datasets, experiment_specifications),
+                validate_X_layers_presence(soma_path, datasets, experiment_specifications, assets_path),
             )
         )
     ],
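
Each validator returns a delayed boolean, and `validate_soma` reduces them all to a single result; the presence check simply gains the `assets_path` argument so workers can reach the original h5ads. A toy sketch of this compose-then-reduce pattern (assuming a local dask installation):

```python
import dask

@dask.delayed
def passes(x: int) -> bool:
    return x >= 0  # stand-in for a real validation check

checks = [passes(i) for i in range(4)]
ok = dask.delayed(all)(checks).compute()  # one boolean for the whole graph
assert ok
```
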
9 changes: 5 additions & 4 deletions tools/cellxgene_census_builder/tests/anndata/test_anndata.py
@@ -20,30 +20,31 @@
 from ..conftest import GENE_IDS, ORGANISMS, get_anndata


-def test_open_anndata(datasets: list[Dataset]) -> None:
+def test_open_anndata(datasets: list[Dataset], census_build_args: CensusBuildArgs) -> None:
     """`open_anndata` should open the h5ads for each of the dataset in the argument,
     and yield both the dataset and the corresponding AnnData object.
     This test does not involve additional filtering steps.
     The `datasets` used here have no raw layer.
     """
+    assets_path = census_build_args.h5ads_path.as_posix()

     def _todense(X: npt.NDArray[np.float32] | sparse.spmatrix) -> npt.NDArray[np.float32]:
         if isinstance(X, np.ndarray):
             return X
         else:
             return cast(npt.NDArray[np.float32], X.todense())

-    result = [(d, open_anndata(d, base_path=".")) for d in datasets]
+    result = [(d, open_anndata(d, base_path=assets_path)) for d in datasets]
     assert len(result) == len(datasets) and len(datasets) > 0
     for i, (dataset, anndata_obj) in enumerate(result):
         assert dataset == datasets[i]
-        opened_anndata = anndata.read_h5ad(dataset.dataset_h5ad_path)
+        opened_anndata = anndata.read_h5ad(f"{assets_path}/{dataset.dataset_h5ad_path}")
         assert opened_anndata.obs.equals(anndata_obj.obs)
         assert opened_anndata.var.equals(anndata_obj.var)
         assert np.array_equal(_todense(opened_anndata.X), _todense(anndata_obj.X))

     # also check context manager
-    with open_anndata(datasets[0], base_path=".") as ad:
+    with open_anndata(datasets[0], base_path=assets_path) as ad:
         assert ad.n_obs == len(ad.obs)


22 changes: 16 additions & 6 deletions tools/cellxgene_census_builder/tests/conftest.py
@@ -1,4 +1,5 @@
 import pathlib
+from functools import partial
 from typing import Literal

 import anndata
@@ -43,8 +44,17 @@ def get_anndata(
     n_cells = 4
     n_genes = len(gene_ids)
     rng = np.random.default_rng()
-    min_X_val = 1 if no_zero_counts else 0
-    X = rng.integers(min_X_val, min_X_val + 5, size=(n_cells, n_genes)).astype(np.float32)
+    if no_zero_counts:
+        X = rng.integers(1, 6, size=(n_cells, n_genes)).astype(np.float32)
+    else:
+        X = sparse.random(
+            n_cells,
+            n_genes,
+            density=0.5,
+            random_state=rng,
+            data_rvs=partial(rng.integers, 1, 6),
+            dtype=np.float32,
+        ).toarray()

     # Builder code currently assumes (and enforces) that ALL cells (rows) contain at least
     # one non-zero value in their count matrix. Enforce this assumption, as the rng will
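
With `density=0.5`, roughly half of the generated entries are zero, which is exactly what the fixed presence logic needs the test data to exercise, while `data_rvs=partial(rng.integers, 1, 6)` keeps the nonzero entries as small counts. A standalone sketch of the same recipe:

```python
from functools import partial

import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
X = sparse.random(
    4,
    8,
    density=0.5,                           # about half the entries stay zero
    random_state=rng,
    data_rvs=partial(rng.integers, 1, 6),  # nonzero values drawn from [1, 6)
    dtype=np.float32,
).toarray()
assert (X == 0).any()  # zeros present, unlike the old rng.integers(1, 6) data
```
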
@@ -148,10 +158,10 @@ def datasets(census_build_args: CensusBuildArgs) -> list[Dataset]:
     for organism in ORGANISMS:
         for i in range(NUM_DATASET):
             h5ad = get_anndata(
-                organism, GENE_IDS[i], no_zero_counts=True, assay_ontology_term_id=ASSAY_IDS[i], X_format=X_FORMAT[i]
+                organism, GENE_IDS[i], no_zero_counts=False, assay_ontology_term_id=ASSAY_IDS[i], X_format=X_FORMAT[i]
             )
-            h5ad_path = f"{assets_path}/{organism.name}_{i}.h5ad"
-            h5ad.write_h5ad(h5ad_path)
+            h5ad_name = f"{organism.name}_{i}.h5ad"
+            h5ad.write_h5ad(f"{assets_path}/{h5ad_name}")
             datasets.append(
                 Dataset(
                     dataset_id=f"{organism.name}_{i}",
@@ -160,7 +170,7 @@
                     collection_id=f"id_{organism.name}",
                     collection_name=f"collection_{organism.name}",
                     dataset_asset_h5ad_uri="mock",
-                    dataset_h5ad_path=h5ad_path,
+                    dataset_h5ad_path=h5ad_name,
                     dataset_version_id=f"{organism.name}_{i}_v0",
                 ),
             )