diff --git a/setup.cfg b/setup.cfg
index fa0c58c..95a8968 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -75,6 +75,7 @@ testing =
     celldex
     scrnaseq
     scipy
+    scranpy
 
 [options.entry_points]
 # Add here console scripts like:
diff --git a/src/singler/train_single.py b/src/singler/train_single.py
index 35e19d9..2d65264 100644
--- a/src/singler/train_single.py
+++ b/src/singler/train_single.py
@@ -101,7 +101,8 @@ def train_single(
     restrict_to: Optional[Union[set, dict]] = None,
     check_missing: bool = True,
     markers: Optional[dict[Any, dict[Any, Sequence]]] = None,
-    marker_method: Literal["classic"] = "classic",
+    marker_method: Literal["classic", "auc", "cohens_d"] = "classic",
+    num_de: Optional[int] = None,
     marker_args: dict = {},
     nn_parameters: Optional[knncolle.Parameters] = knncolle.VptreeParameters(),
     num_threads: int = 1,
@@ -152,13 +153,20 @@ def train_single(
             should have keys in the inner and outer dictionaries.
 
         marker_method:
-            Method to identify markers from each pairwise comparisons between
-            labels in ``ref_data``. If "classic", we call
-            :py:meth:`~singler.get_classic_markers.get_classic_markers`.
+            Method to identify markers from each pairwise comparison between labels in ``ref_data``.
+            If ``classic``, we call :py:func:`~singler.get_classic_markers.get_classic_markers`.
+            If ``auc`` or ``cohens_d``, we call :py:func:`~scranpy.score_markers.score_markers`.
+            Only used if ``markers`` is not supplied.
+
+        num_de:
+            Number of differentially expressed genes to use as markers for each pairwise comparison between labels.
+            If ``None`` and ``marker_method = "classic"``, an appropriate number of genes is determined by :py:func:`~singler.get_classic_markers.get_classic_markers`.
+            Otherwise, it is set to 10.
             Only used if ``markers`` is not supplied.
 
         marker_args:
             Further arguments to pass to the chosen marker detection method.
+            If ``marker_method = "classic"``, this is :py:func:`~singler.get_classic_markers.get_classic_markers`, otherwise it is :py:func:`~scranpy.score_markers.score_markers`.
             Only used if ``markers`` is not supplied.
 
         nn_parameters:
@@ -214,6 +222,7 @@ def train_single(
         unique_labels=unique_labels,
         markers=markers,
         marker_method=marker_method,
+        num_de=num_de,
         test_features=test_features,
         restrict_to=restrict_to,
         marker_args=marker_args,
@@ -258,7 +267,7 @@ def train_single(
     )
 
 
-def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers, marker_method, test_features, restrict_to, marker_args, num_threads):
+def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers, marker_method, test_features, restrict_to, num_de, marker_args, num_threads):
     ref_data, ref_features = _restrict_features(ref_data, ref_features, test_features)
     ref_data, ref_features = _restrict_features(ref_data, ref_features, restrict_to)
 
@@ -270,11 +279,56 @@ def _identify_genes(ref_data, ref_features, ref_labels, unique_labels, markers,
                 ref_labels=[ref_labels],
                 ref_features=[ref_features],
                 num_threads=num_threads,
+                num_de=num_de,
                 **marker_args,
             )
         else:
-            raise NotImplementedError("other marker methods are not yet implemented, sorry")
-        return markers
+            # Effects must exceed 'boundary' (no difference between groups) to be considered markers.
+            if marker_method == "auc":
+                compute_auc = True
+                compute_cohens_d = False
+                effect_size = "auc"
+                boundary = 0.5
+            elif marker_method == "cohens_d":
+                compute_auc = False
+                compute_cohens_d = True
+                effect_size = "cohens_d"
+                boundary = 0
+            else:
+                raise NotImplementedError(f"unsupported marker_method '{marker_method}'")
+
+            # Imported here so that scranpy is only needed when these methods are requested.
+            import scranpy
+            stats = scranpy.score_markers(
+                ref_data,
+                groups=ref_labels,
+                num_threads=num_threads,
+                all_pairwise=True,
+                compute_delta_detected=False,
+                compute_delta_mean=False,
+                compute_auc=compute_auc,
+                compute_cohens_d=compute_cohens_d,
+                **marker_args
+            )
+            pairwise = getattr(stats, effect_size)
+
+            if num_de is None:
+                num_de = 10
+
+            markers = {}
+            for g1, group1 in enumerate(stats.groups):
+                group_markers = {}
+                for g2, group2 in enumerate(stats.groups):
+                    if g1 == g2:
+                        group_markers[group2] = biocutils.StringList()
+                        continue
+                    cureffects = pairwise[g2, g1, :] # remember, second dimension is the first group in the comparison.
+                    keep = biocutils.which(cureffects > boundary)
+                    o = numpy.argsort(-cureffects[keep], stable=True)
+                    group_markers[group2] = biocutils.StringList(biocutils.subset_sequence(ref_features, keep[o[:num_de]]))
+                markers[group1] = group_markers
+
+        return markers
 
     # Validating a user-supplied list of markers.
     if not isinstance(markers, dict):
diff --git a/tests/test_train_single.py b/tests/test_train_single.py
index 596a6c2..0113486 100644
--- a/tests/test_train_single.py
+++ b/tests/test_train_single.py
@@ -1,5 +1,6 @@
 import singler
 import numpy
+import biocutils
 
 
 def test_train_single_basic():
@@ -82,3 +83,70 @@ def test_train_single_restricted():
     expected_output = singler.classify_single(test[keep,:], expected)
     assert (output.column("delta") == expected_output.column("delta")).all()
     assert output.column("best") == expected_output.column("best")
+
+
+def test_train_single_scranpy():
+    ref = numpy.random.rand(10000, 10)
+    labels = ["A", "B", "C", "D", "A", "B", "C", "A", "B", "A"]
+    features = ["gene_" + str(i) for i in range(ref.shape[0])]
+
+    import scranpy
+    effects = scranpy.score_markers(ref, labels, all_pairwise=True)
+
+    def verify(ref_markers, effect_sizes, hard_limit, extra):
+        all_labels = sorted(list(ref_markers.keys()))
+        assert all_labels == sorted(effects.groups)
+
+        for g1, group1 in enumerate(effects.groups):
+            current_markers = ref_markers[group1]
+            assert all_labels == sorted(list(current_markers.keys()))
+
+            for g2, group2 in enumerate(effects.groups):
+                if g1 == g2:
+                    assert len(current_markers[group2]) == 0
+                else:
+                    my_effects = effect_sizes[g2, g1, :]
+                    assert len(my_effects) == 10000
+                    my_markers = current_markers[group2]
+                    assert len(my_markers) > 0
+                    my_markers_set = set(my_markers)
+                    is_chosen = numpy.array([f in my_markers_set for f in features])
+                    min_chosen = my_effects[is_chosen].min()
+                    assert min_chosen >= my_effects[numpy.logical_not(is_chosen)].max()
+                    assert min_chosen > hard_limit
+                    if extra is not None:
+                        extra(group1, group2, my_markers)
+
+    built = singler.train_single(ref, labels, features, marker_method="auc")
+    verify(built.markers, effects.auc, 0.5, extra=None)
+
+    built = singler.train_single(ref, labels, features, marker_method="cohens_d")
+    def extra_cohen(n, n2, my_markers):
+        assert len(my_markers) <= 10
+        markerref = ref[biocutils.match(my_markers, features),:]
+        left = markerref[:,[n == l for l in labels]].mean(axis=1)
+        right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
+        assert (left > right).all()
+    verify(built.markers, effects.cohens_d, 0, extra=extra_cohen)
+
+    built = singler.train_single(ref, labels, features, marker_method="cohens_d", num_de=10000)
+    def extra_cohen(n, n2, my_markers):
+        assert len(my_markers) > 10
+        markerref = ref[biocutils.match(my_markers, features),:]
+        left = markerref[:,[n == l for l in labels]].mean(axis=1)
+        right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
+        assert (left > right).all()
+    verify(built.markers, effects.cohens_d, 0, extra=extra_cohen)
+
+    # Responds to threshold specification. Values from numpy.random.rand lie in [0, 1),
+    # so group means can never differ by 1 or more; use a smaller threshold so that
+    # some markers survive for every pairwise comparison.
+    threshold = 0.1
+    thresh_effects = scranpy.score_markers(ref, labels, threshold=threshold, all_pairwise=True)
+    built = singler.train_single(ref, labels, features, marker_method="cohens_d", marker_args={"threshold": threshold})
+    def extra_threshold(n, n2, my_markers):
+        markerref = ref[biocutils.match(my_markers, features),:]
+        left = markerref[:,[n == l for l in labels]].mean(axis=1)
+        right = markerref[:,[n2 == l for l in labels]].mean(axis=1)
+        assert (left > right + threshold).all()
+    verify(built.markers, thresh_effects.cohens_d, 0, extra=extra_threshold)