Skip to content

Commit

Permalink
Support scran's marker detection methods in trainSingle.
Browse files Browse the repository at this point in the history
This brings us in line with the functionality in the R package.
  • Loading branch information
LTLA committed Jan 7, 2025
1 parent 5f2e9c1 commit 54e73d4
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 7 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ testing =
celldex
scrnaseq
scipy
scranpy

[options.entry_points]
# Add here console scripts like:
Expand Down
64 changes: 57 additions & 7 deletions src/singler/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def train_single(
restrict_to: Optional[Union[set, dict]] = None,
check_missing: bool = True,
markers: Optional[dict[Any, dict[Any, Sequence]]] = None,
marker_method: Literal["classic"] = "classic",
marker_method: Literal["classic", "auc", "cohens_d"] = "classic",
num_de: Optional[int] = None,
marker_args: dict = {},
nn_parameters: Optional[knncolle.Parameters] = knncolle.VptreeParameters(),
num_threads: int = 1,
Expand Down Expand Up @@ -152,13 +153,20 @@ def train_single(
should have keys in the inner and outer dictionaries.
marker_method:
Method to identify markers from each pairwise comparisons between
labels in ``ref_data``. If "classic", we call
:py:meth:`~singler.get_classic_markers.get_classic_markers`.
Method to identify markers from each pairwise comparisons between labels in ``ref_data``.
If ``classic``, we call :py:func:`~singler.get_classic_markers.get_classic_markers`.
If ``auc`` or ``cohens_d``, we call :py:func:`~scranpy.score_markers.score_markers`.
Only used if ``markers`` is not supplied.
num_de:
Number of differentially expressed genes to use as markers for each pairwise comparison between labels.
If ``None`` and ``marker_method = "classic"``, an appropriate number of genes is determined by :py:func:`~singler.get_classic_markers.get_classic_markers`.
Otherwise, it is set to 10.
Only used if ``markers`` is not supplied.
marker_args:
Further arguments to pass to the chosen marker detection method.
If ``marker_method = "classic"``, this is :py:func:`~singler.get_classic_markers.get_classic_markers`, otherwise it is :py:func:`~scranpy.score_markers.score_markers`.
Only used if ``markers`` is not supplied.
nn_parameters:
Expand Down Expand Up @@ -214,6 +222,7 @@ def train_single(
unique_labels=unique_labels,
markers=markers,
marker_method=marker_method,
num_de=num_de,
test_features=test_features,
restrict_to=restrict_to,
marker_args=marker_args,
Expand Down Expand Up @@ -258,7 +267,7 @@ def train_single(
)


def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers, marker_method, test_features, restrict_to, marker_args, num_threads):
def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers, marker_method, test_features, restrict_to, num_de, marker_args, num_threads):
ref_data, ref_features = _restrict_features(ref_data, ref_features, test_features)
ref_data, ref_features = _restrict_features(ref_data, ref_features, restrict_to)

Expand All @@ -270,11 +279,52 @@ def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers,
ref_labels=[ref_labels],
ref_features=[ref_features],
num_threads=num_threads,
num_de=num_de,
**marker_args,
)
else:
raise NotImplementedError("other marker methods are not yet implemented, sorry")
return markers
if marker_method == "auc":
compute_auc = True
compute_cohens_d = False
effect_size = "auc"
boundary = 0.5
else:
compute_auc = False
compute_cohens_d = True
effect_size = "cohens_d"
boundary = 0

import scranpy
stats = scranpy.score_markers(
ref_data,
groups=ref_labels,
num_threads=num_threads,
all_pairwise=True,
compute_delta_detected=False,
compute_delta_mean=False,
compute_auc=compute_auc,
compute_cohens_d=compute_cohens_d,
**marker_args
)
pairwise = getattr(stats, effect_size)

if num_de is None:
num_de = 10

markers = {}
for g1, group1 in enumerate(stats.groups):
group_markers = {}
for g2, group2 in enumerate(stats.groups):
if g1 == g2:
group_markers[group2] = biocutils.StringList()
continue
cureffects = pairwise[g2, g1, :] # remember, second dimension is the first group in the comparison.
keep = biocutils.which(cureffects > boundary)
o = numpy.argsort(-cureffects[keep], stable=True)
group_markers[group2] = biocutils.StringList(biocutils.subset_sequence(ref_features, keep[o[:min(len(o), num_de)]]))
markers[group1] = group_markers

return markers

# Validating a user-supplied list of markers.
if not isinstance(markers, dict):
Expand Down
64 changes: 64 additions & 0 deletions tests/test_train_single.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import singler
import numpy
import biocutils


def test_train_single_basic():
Expand Down Expand Up @@ -82,3 +83,66 @@ def test_train_single_restricted():
expected_output = singler.classify_single(test[keep,:], expected)
assert (output.column("delta") == expected_output.column("delta")).all()
assert output.column("best") == expected_output.column("best")


def test_train_single_scranpy():
ref = numpy.random.rand(10000, 10)
labels = ["A", "B", "C", "D", "A", "B", "C", "A", "B", "A"]
features = ["gene_" + str(i) for i in range(ref.shape[0])]

import scranpy
effects = scranpy.score_markers(ref, labels, all_pairwise=True)

def verify(ref_markers, effect_sizes, hard_limit, extra):
all_labels = sorted(list(ref_markers.keys()))
assert all_labels == sorted(effects.groups)

for g1, group1 in enumerate(effects.groups):
current_markers = ref_markers[group1]
assert all_labels == sorted(list(current_markers.keys()))

for g2, group2 in enumerate(effects.groups):
if g1 == g2:
assert len(current_markers[group2]) == 0
else:
my_effects = effect_sizes[g2, g1, :]
assert len(my_effects) == 10000
my_markers = current_markers[group2]
assert len(my_markers) > 0
my_markers_set = set(my_markers)
is_chosen = numpy.array([f in my_markers_set for f in features])
min_chosen = my_effects[is_chosen].min()
assert min_chosen >= my_effects[numpy.logical_not(is_chosen)].max()
assert min_chosen > hard_limit
if extra is not None:
extra(group1, group2, my_markers)

built = singler.train_single(ref, labels, features, marker_method="auc")
verify(built.markers, effects.auc, 0.5, extra=None)

built = singler.train_single(ref, labels, features, marker_method="cohens_d")
def extra_cohen(n, n2, my_markers):
assert len(my_markers) <= 10
markerref = ref[biocutils.match(my_markers, features),:]
left = markerref[:,[n == l for l in labels]].mean(axis=1)
right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
assert (left > right).all()
verify(built.markers, effects.cohens_d, 0, extra=extra_cohen)

built = singler.train_single(ref, labels, features, marker_method="cohens_d", num_de=10000)
def extra_cohen(n, n2, my_markers):
assert len(my_markers) > 10
markerref = ref[biocutils.match(my_markers, features),:]
left = markerref[:,[n == l for l in labels]].mean(axis=1)
right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
assert (left > right).all()
verify(built.markers, effects.cohens_d, 0, extra=extra_cohen)

# Responds to threshold specification.
thresh_effects = scranpy.score_markers(ref, labels, threshold=1, all_pairwise=True)
def extra_threshold(n, n2, my_markers):
markerref = ref[biocutils.match(my_markers, features),:]
left = markerref[:,[n == l for l in labels]].mean(axis=1)
right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
assert (left > right + 1).all()
verify(built.markers, effects.cohens_d, 0, extra=extra_cohen)

0 comments on commit 54e73d4

Please sign in to comment.