From 0c0476df8af810cecaf24143813a01600c574239 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 9 Oct 2025 11:11:01 +0200 Subject: [PATCH 01/25] add lognormalization to celltypist component --- src/annotate/celltypist/config.vsh.yaml | 6 +- src/annotate/celltypist/script.py | 92 ++++++++++++++----------- src/annotate/celltypist/test.py | 54 --------------- 3 files changed, 54 insertions(+), 98 deletions(-) diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index ccb8d9fc17c..fe5a8ed3460 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -26,7 +26,7 @@ argument_groups: required: false - name: "--input_layer" type: string - description: The layer in the input data containing log normalized counts to be used for cell type annotation if .X is not to be used. + description: The layer in the input data containing raw counts to be used for cell type annotation if .X is not to be used. - name: "--input_var_gene_names" type: string required: false @@ -50,7 +50,7 @@ argument_groups: required: false - name: "--reference_layer" type: string - description: The layer in the reference data to be used for cell type annotation if .X is not to be used. Data are expected to be processed in the same way as the --input query dataset. + description: The layer in the reference data containing raw counts to be used for cell type annotation if .X is not to be used. required: false - name: "--reference_obs_target" type: string @@ -152,7 +152,7 @@ engines: packages: - celltypist==1.6.3 - type: python - __merge__: [ /src/base/requirements/anndata_mudata.yaml, .] + __merge__: [ /src/base/requirements/anndata_mudata.yaml, /src/base/requirements/scanpy.yaml, .] __merge__: [ /src/base/requirements/python_test_setup.yaml, .] runners: - type: executable diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index e3efb749d3d..f9f61d59b9f 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -3,23 +3,20 @@ import mudata as mu import anndata as ad import pandas as pd -import numpy as np +import scanpy as sc ## VIASH START par = { "input": "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu", "output": "output.h5mu", "modality": "rna", - # "reference": None, "reference": "resources_test/annotation_test_data/TS_Blood_filtered.h5mu", "model": None, - # "model": "resources_test/annotation_test_data/celltypist_model_Immune_All_Low.pkl", "input_layer": "log_normalized", "reference_layer": "log_normalized", "input_reference_gene_overlap": 100, "reference_obs_target": "cell_ontology_class", "reference_var_input": None, - "check_expression": False, "feature_selection": True, "majority_voting": True, "output_compression": "gzip", @@ -44,10 +41,43 @@ logger = setup_logger() -def check_celltypist_format(indata): - if np.abs(np.expm1(indata[0]).sum() - 10000) > 1: - return False - return True +def setup_anndata( + adata: ad.AnnData, + layer: str | None = None, + gene_names: str | None = None, + var_input: str | None = None, +) -> ad.AnnData: + """Creates an AnnData object in the expected format for CellTypist, + with lognormalized data (with a target sum of 10000) in the .X slot. + + Parameters + ---------- + adata + AnnData object. + layer + Layer in AnnData object to lognormalize. + gene_names + .obs field with the gene names to be used + var_input + .var field with a boolean array of the genes to be used (e.g. highly variable genes) + Returns + ------- + AnnData object in CellTypist format. + """ + + adata = set_var_index(adata, gene_names) + + if var_input: + adata = subset_vars(adata, var_input) + + raw_counts = adata.layers[layer].copy() if layer else adata.X.copy() + + input_modality = ad.AnnData(X=raw_counts, var=pd.DataFrame(index=adata.var.index)) + + sc.pp.normalize_total(input_modality, target_sum=10000) + sc.pp.log1p(input_modality) + + return input_modality def main(par): @@ -63,17 +93,8 @@ def main(par): input_modality = input_adata.copy() # Provide correct format of query data for celltypist annotation - ## Sanitize gene names and set as index - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) - ## Fetch lognormalized counts - lognorm_counts = ( - input_modality.layers[par["input_layer"]].copy() - if par["input_layer"] - else input_modality.X.copy() - ) - ## Create AnnData object - input_modality = ad.AnnData( - X=lognorm_counts, var=pd.DataFrame(index=input_modality.var.index) + input_modality = setup_anndata( + input_modality, par["input_layer"], par["input_var_gene_names"] ) if par["model"]: @@ -86,18 +107,15 @@ def main(par): ) elif par["reference"]: - reference_modality = mu.read_h5mu(par["reference"]).mod[par["modality"]] - - # subset to HVG if required - if par["reference_var_input"]: - reference_modality = subset_vars( - reference_modality, par["reference_var_input"] - ) - - # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) - # CellTypist requires query gene names to be in index - reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_adata = mu.read_h5mu(par["reference"]).mod[par["modality"]] + reference_modality = reference_adata.copy() + + # Provide correct format of query data for celltypist annotation + reference_modality = setup_anndata( + reference_modality, + par["reference_layer"], + par["reference_var_gene_names"], + par["reference_var_input"], ) # Ensure enough overlap between genes in query and reference @@ -107,18 +125,10 @@ def main(par): min_gene_overlap=par["input_reference_gene_overlap"], ) - reference_matrix = ( - reference_modality.layers[par["reference_layer"]] - if par["reference_layer"] - else reference_modality.X - ) - - labels = reference_modality.obs[par["reference_obs_target"]] - logger.info("Training CellTypist model on reference") model = celltypist.train( - reference_matrix, - labels=labels, + reference_modality.X, + labels=reference_adata.obs[par["reference_obs_target"]], genes=reference_modality.var.index, C=par["C"], max_iter=par["max_iter"], diff --git a/src/annotate/celltypist/test.py b/src/annotate/celltypist/test.py index 60704b3c24c..52670523a91 100644 --- a/src/annotate/celltypist/test.py +++ b/src/annotate/celltypist/test.py @@ -1,8 +1,6 @@ import sys import os import pytest -import subprocess -import re import mudata as mu from openpipeline_testutils.asserters import assert_annotation_objects_equal @@ -27,12 +25,8 @@ def test_simple_execution(run_component, random_h5mu_path): [ "--input", input_file, - "--input_layer", - "log_normalized", "--reference", reference_file, - "--reference_layer", - "log_normalized", "--reference_obs_target", "cell_ontology_class", "--reference_var_gene_names", @@ -75,12 +69,8 @@ def test_set_params(run_component, random_h5mu_path): [ "--input", input_file, - "--input_layer", - "log_normalized", "--reference", reference_file, - "--reference_layer", - "log_normalized", "--reference_obs_target", "cell_ontology_class", "--reference_var_gene_names", @@ -159,49 +149,5 @@ def test_with_model(run_component, random_h5mu_path): ) -def test_fail_invalid_input_expression(run_component, random_h5mu_path): - output_file = random_h5mu_path() - - # fails because input data are not lognormalized - with pytest.raises(subprocess.CalledProcessError) as err: - run_component( - [ - "--input", - input_file, - "--reference", - reference_file, - "--reference_var_gene_names", - "ensemblid", - "--output", - output_file, - ] - ) - assert re.search( - r"Invalid expression matrix, expect log1p normalized expression to 10000 counts per cell", - err.value.stdout.decode("utf-8"), - ) - - # fails because reference data are not lognormalized - with pytest.raises(subprocess.CalledProcessError) as err: - run_component( - [ - "--input", - input_file, - "--layer", - "log_normalized", - "--reference", - reference_file, - "--reference_var_gene_names", - "ensemblid", - "--output", - output_file, - ] - ) - assert re.search( - r"Invalid expression matrix, expect log1p normalized expression to 10000 counts per cell", - err.value.stdout.decode("utf-8"), - ) - - if __name__ == "__main__": sys.exit(pytest.main([__file__])) From c853100941602d517133a569640a19fd5e39e047 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 9 Oct 2025 11:15:24 +0200 Subject: [PATCH 02/25] update changelog --- CHANGELOG.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 233c6c955a6..920776c9bc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,9 @@ ## BREAKING -* `differential_expression/create_pseudobulks`: Removed functionality to filter psuedobulk samples based on number of aggregated samples threshold, as this functionality is now covered in `filter/delimit_count` (PR #1044). +* `differential_expression/create_pseudobulks`: Removed functionality to filter pseudobulk samples based on number of aggregated samples threshold, as this functionality is now covered in `filter/delimit_count` (PR #1044). + +* `annotate/celtypist`: This component now requires to pass a raw count layer, that will be lognormalized with a target sum of 10000, the required count format for CellTypist (PR #1083). ## NEW FUNCTIONALITY From fac8a5b80dcb828d69b8fc013dc38d440b59c694 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 9 Oct 2025 16:44:42 +0200 Subject: [PATCH 03/25] make gene name sanitation optional --- src/annotate/celltypist/config.vsh.yaml | 6 ++- src/annotate/celltypist/script.py | 8 +++- src/annotate/onclass/config.vsh.yaml | 4 ++ src/annotate/onclass/script.py | 8 +++- .../random_forest_annotation/config.vsh.yaml | 4 ++ .../random_forest_annotation/script.py | 8 +++- src/annotate/scanvi/config.vsh.yaml | 6 ++- src/annotate/scanvi/script.py | 4 +- src/annotate/svm_annotation/config.vsh.yaml | 4 ++ src/annotate/svm_annotation/script.py | 8 +++- src/integrate/scarches/config.vsh.yaml | 4 ++ src/integrate/scarches/script.py | 4 +- src/integrate/scvi/config.vsh.yaml | 4 ++ src/integrate/scvi/script.py | 4 +- src/utils/set_var_index.py | 37 +++++++++++++++---- .../scanvi_scarches/config.vsh.yaml | 4 ++ .../annotation/scanvi_scarches/main.nf | 5 ++- .../annotation/scvi_knn/config.vsh.yaml | 4 ++ src/workflows/annotation/scvi_knn/main.nf | 3 +- .../integration/scvi_leiden/config.vsh.yaml | 5 +++ src/workflows/integration/scvi_leiden/main.nf | 1 + 21 files changed, 113 insertions(+), 22 deletions(-) diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index ccb8d9fc17c..7a0a9b5caa9 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -38,7 +38,11 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true + - name: Reference description: Arguments related to the reference dataset. arguments: diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index e3efb749d3d..ba64708dc36 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -64,7 +64,9 @@ def main(par): # Provide correct format of query data for celltypist annotation ## Sanitize gene names and set as index - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + input_modality = set_var_index( + input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + ) ## Fetch lognormalized counts lognorm_counts = ( input_modality.layers[par["input_layer"]].copy() @@ -97,7 +99,9 @@ def main(par): # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) # CellTypist requires query gene names to be in index reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_modality, + par["reference_var_gene_names"], + par["sanitize_gene_names"], ) # Ensure enough overlap between genes in query and reference diff --git a/src/annotate/onclass/config.vsh.yaml b/src/annotate/onclass/config.vsh.yaml index ff61b11cd5c..e20e9336bd2 100644 --- a/src/annotate/onclass/config.vsh.yaml +++ b/src/annotate/onclass/config.vsh.yaml @@ -39,6 +39,10 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Ontology description: Ontology input files diff --git a/src/annotate/onclass/script.py b/src/annotate/onclass/script.py index 0306c988f9b..a74f00a9760 100644 --- a/src/annotate/onclass/script.py +++ b/src/annotate/onclass/script.py @@ -123,7 +123,9 @@ def main(): input_modality = input_adata.copy() # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + input_modality = set_var_index( + input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + ) input_matrix = ( input_modality.layers[par["input_layer"]] if par["input_layer"] @@ -156,7 +158,9 @@ def main(): reference_mudata = mu.read_h5mu(par["reference"]) reference_modality = reference_mudata.mod[par["modality"]].copy() reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_modality, + par["reference_var_gene_names"], + par["sanitize_gene_names"], ) # subset to HVG if required diff --git a/src/annotate/random_forest_annotation/config.vsh.yaml b/src/annotate/random_forest_annotation/config.vsh.yaml index a8def936aad..7b2f3784acb 100644 --- a/src/annotate/random_forest_annotation/config.vsh.yaml +++ b/src/annotate/random_forest_annotation/config.vsh.yaml @@ -35,6 +35,10 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference description: Arguments related to the reference dataset. diff --git a/src/annotate/random_forest_annotation/script.py b/src/annotate/random_forest_annotation/script.py index df401315b62..5b36a88eede 100644 --- a/src/annotate/random_forest_annotation/script.py +++ b/src/annotate/random_forest_annotation/script.py @@ -47,7 +47,9 @@ def main(): input_mudata = mu.read_h5mu(par["input"]) input_adata = input_mudata.mod[par["modality"]] input_modality = input_adata.copy() - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + input_modality = set_var_index( + input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + ) # Handle max_features parameter max_features_conversion = { @@ -100,7 +102,9 @@ def main(): reference_mudata = mu.read_h5mu(par["reference"]) reference_modality = reference_mudata.mod[par["modality"]].copy() reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_modality, + par["reference_var_gene_names"], + par["sanitize_gene_names"], ) # subset to HVG if required diff --git a/src/annotate/scanvi/config.vsh.yaml b/src/annotate/scanvi/config.vsh.yaml index 7c1f26a5c6d..25d3598aa85 100644 --- a/src/annotate/scanvi/config.vsh.yaml +++ b/src/annotate/scanvi/config.vsh.yaml @@ -48,7 +48,11 @@ argument_groups: default: "Unknown" description: | Value in the --obs_labels field that indicates unlabeled observations - + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true + - name: scVI Model arguments: - name: "--scvi_model" diff --git a/src/annotate/scanvi/script.py b/src/annotate/scanvi/script.py index 6a89962c072..684cf7973b4 100644 --- a/src/annotate/scanvi/script.py +++ b/src/annotate/scanvi/script.py @@ -50,7 +50,9 @@ def main(): adata_subset = adata.copy() # Sanitize gene names and set as index of the AnnData object - adata_subset = set_var_index(adata_subset, par["var_gene_names"]) + adata_subset = set_var_index( + adata_subset, par["var_gene_names"], par["sanitize_gene_names"] + ) logger.info(f"Loading pre-trained scVI model from {par['scvi_model']}") scvi_model = scvi.model.SCVI.load( diff --git a/src/annotate/svm_annotation/config.vsh.yaml b/src/annotate/svm_annotation/config.vsh.yaml index 6d365fb41cb..5c5d67b4618 100644 --- a/src/annotate/svm_annotation/config.vsh.yaml +++ b/src/annotate/svm_annotation/config.vsh.yaml @@ -35,6 +35,10 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference description: Arguments related to the reference dataset. diff --git a/src/annotate/svm_annotation/script.py b/src/annotate/svm_annotation/script.py index ab4f8f69799..74ce56e6797 100644 --- a/src/annotate/svm_annotation/script.py +++ b/src/annotate/svm_annotation/script.py @@ -51,7 +51,9 @@ def main(): input_mudata = mu.read_h5mu(par["input"]) input_adata = input_mudata.mod[par["modality"]] input_modality = input_adata.copy() - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + input_modality = set_var_index( + input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + ) if par["model"]: logger.info("Loading a pre-trained model") @@ -82,7 +84,9 @@ def main(): reference_mudata = mu.read_h5mu(par["reference"]) reference_modality = reference_mudata.mod[par["modality"]].copy() reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_modality, + par["reference_var_gene_names"], + par["sanitize_gene_names"], ) # subset to HVG if required diff --git a/src/integrate/scarches/config.vsh.yaml b/src/integrate/scarches/config.vsh.yaml index 30dbe9f2cb6..cdc29dbceb0 100644 --- a/src/integrate/scarches/config.vsh.yaml +++ b/src/integrate/scarches/config.vsh.yaml @@ -66,6 +66,10 @@ argument_groups: (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for. Important: the order of the continuous covariates matters and should match the order of the covariates in the trained reference model. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference arguments: diff --git a/src/integrate/scarches/script.py b/src/integrate/scarches/script.py index f7d0c6a7fca..f16453dec04 100644 --- a/src/integrate/scarches/script.py +++ b/src/integrate/scarches/script.py @@ -133,7 +133,9 @@ def _align_query_with_registry(adata_query, model_path): # Sanitize gene names and set as index of the AnnData object # all scArches VAE models expect gene names to be in the .var index - adata_query = set_var_index(adata_query, par["input_var_gene_names"]) + adata_query = set_var_index( + adata_query, par["input_var_gene_names"], par["sanitize_gene_names"] + ) # align layer query_layer = ( diff --git a/src/integrate/scvi/config.vsh.yaml b/src/integrate/scvi/config.vsh.yaml index a4f566181d2..6b62d3b5244 100644 --- a/src/integrate/scvi/config.vsh.yaml +++ b/src/integrate/scvi/config.vsh.yaml @@ -73,6 +73,10 @@ argument_groups: addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Outputs arguments: - name: "--output" diff --git a/src/integrate/scvi/script.py b/src/integrate/scvi/script.py index bb88021b495..f82f060e3fc 100644 --- a/src/integrate/scvi/script.py +++ b/src/integrate/scvi/script.py @@ -84,7 +84,9 @@ def main(): adata_subset = adata.copy() # Sanitize gene names and set as index of the AnnData object - adata_subset = set_var_index(adata_subset, par["var_gene_names"]) + adata_subset = set_var_index( + adata_subset, par["var_gene_names"], par["sanitize_gene_names"] + ) check_validity_anndata( adata_subset, diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 3a1803ebf1d..65149675840 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -2,8 +2,26 @@ import re -def set_var_index(adata: ad.AnnData, var_name: str | None = None) -> ad.AnnData: - """Sanitize gene names and set the index of the .var DataFrame. +def sanitize_gene_names(gene_names: list[str]) -> list[str]: + """Sanitize gene names by removing version numbers. + + Parameters + ---------- + gene_names : list[str] + List of gene names to sanitize. + + Returns + ------- + list[str] + List of sanitized gene names. + """ + return [re.sub("\\.[0-9]+$", "", s) for s in gene_names] + + +def set_var_index( + adata: ad.AnnData, var_name: str | None = None, sanitise_gene_names: bool = True +) -> ad.AnnData: + """Sanitize gene names (optional) and set the index of the .var DataFrame. Parameters ---------- @@ -11,14 +29,19 @@ def set_var_index(adata: ad.AnnData, var_name: str | None = None) -> ad.AnnData: Annotated data object var_name : str | None Name of the column in `adata.var` that contains the gene names, if None, the existing index will be sanitized but not replaced. + sanitise_gene_names : bool + Whether to sanitize gene names by removing version numbers. Returns ------- AnnData - Copy of `adata` with sanitized and replaced index + Copy of `adata` with optionally sanitized and replaced index """ - if var_name: - adata.var.index = [re.sub("\\.[0-9]+$", "", s) for s in adata.var[var_name]] - else: - adata.var.index = [re.sub("\\.[0-9]+$", "", s) for s in adata.var.index] + gene_names = adata.var[var_name] if var_name else adata.var.index + + if sanitise_gene_names: + gene_names = sanitize_gene_names(gene_names) + + adata.var.index = gene_names + return adata diff --git a/src/workflows/annotation/scanvi_scarches/config.vsh.yaml b/src/workflows/annotation/scanvi_scarches/config.vsh.yaml index 09102e010c5..ef36879d02d 100644 --- a/src/workflows/annotation/scanvi_scarches/config.vsh.yaml +++ b/src/workflows/annotation/scanvi_scarches/config.vsh.yaml @@ -71,6 +71,10 @@ argument_groups: type: string required: false description: ".var column containing gene names. By default, use the index." + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference input arguments: diff --git a/src/workflows/annotation/scanvi_scarches/main.nf b/src/workflows/annotation/scanvi_scarches/main.nf index ba32fb7feb4..62738585da3 100644 --- a/src/workflows/annotation/scanvi_scarches/main.nf +++ b/src/workflows/annotation/scanvi_scarches/main.nf @@ -31,6 +31,7 @@ workflow run_wf { "reduce_lr_on_plateau": "reduce_lr_on_plateau", "lr_factor": "lr_factor", "lr_patience": "lr_patience", + "sanitize_gene_names": "sanitize_gene_names" ], args: [ "obsm_output": "X_integrated_scvi" @@ -63,6 +64,7 @@ workflow run_wf { "reduce_lr_on_plateau": "reduce_lr_on_plateau", "lr_factor": "lr_factor", "lr_patience": "lr_patience", + "sanitize_gene_names": "sanitize_gene_names" ], toState: [ "reference": "output", @@ -94,7 +96,8 @@ workflow run_wf { "lr_factor": "lr_factor", "lr_patience": "lr_patience", "output": "workflow_output", - "model_output": "workflow_output_model" + "model_output": "workflow_output_model", + "sanitize_gene_names": "sanitize_gene_names" ], toState: [ "input": "output", diff --git a/src/workflows/annotation/scvi_knn/config.vsh.yaml b/src/workflows/annotation/scvi_knn/config.vsh.yaml index b37003bf809..999068a840c 100644 --- a/src/workflows/annotation/scvi_knn/config.vsh.yaml +++ b/src/workflows/annotation/scvi_knn/config.vsh.yaml @@ -62,6 +62,10 @@ argument_groups: - name: "--overwrite_existing_key" type: boolean_true description: If provided, will overwrite existing fields in the input dataset when data are copied during the reference alignment process. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference input arguments: diff --git a/src/workflows/annotation/scvi_knn/main.nf b/src/workflows/annotation/scvi_knn/main.nf index 2a673f5652e..2eb6fade159 100644 --- a/src/workflows/annotation/scvi_knn/main.nf +++ b/src/workflows/annotation/scvi_knn/main.nf @@ -134,7 +134,8 @@ workflow run_wf { "max_epochs": state.scvi_max_epochs, "reduce_lr_on_plateau": state.scvi_reduce_lr_on_plateau, "lr_factor": state.scvi_lr_factor, - "lr_patience": state.scvi_lr_patience + "lr_patience": state.scvi_lr_patience, + "sanitize_gene_names": state.sanitize_gene_names ]}, args: [ "var_input": "_common_hvg", diff --git a/src/workflows/integration/scvi_leiden/config.vsh.yaml b/src/workflows/integration/scvi_leiden/config.vsh.yaml index ae564eeec87..bfab3b968d0 100644 --- a/src/workflows/integration/scvi_leiden/config.vsh.yaml +++ b/src/workflows/integration/scvi_leiden/config.vsh.yaml @@ -29,6 +29,11 @@ argument_groups: type: string default: "rna" required: false + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true + - name: "Outputs" arguments: - name: "--output" diff --git a/src/workflows/integration/scvi_leiden/main.nf b/src/workflows/integration/scvi_leiden/main.nf index baed2aa6c29..6b6c39144d6 100644 --- a/src/workflows/integration/scvi_leiden/main.nf +++ b/src/workflows/integration/scvi_leiden/main.nf @@ -28,6 +28,7 @@ workflow run_wf { "output_model": "output_model", "modality": "modality", "input_layer": "layer", + "sanitize_gene_names": "sanitize_gene_names" ], toState: [ "input": "output", From aec353529b760af9f05ef17998f51f8c1fdf22ca Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 9 Oct 2025 16:59:22 +0200 Subject: [PATCH 04/25] update changelog --- CHANGELOG.md | 2 ++ src/utils/set_var_index.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 233c6c955a6..329b7e93408 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ * `integrate/scarches` and `workflows/annotate/scanvi_scarches`: Enable correction for technical variability by multiple continuous and categorical covariates. +* Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Make feature name sanitation optional (PR #1084). + # openpipelines 3.0.0 ## BREAKING CHANGES diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 65149675840..822b0bc8269 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -19,7 +19,7 @@ def sanitize_gene_names(gene_names: list[str]) -> list[str]: def set_var_index( - adata: ad.AnnData, var_name: str | None = None, sanitise_gene_names: bool = True + adata: ad.AnnData, var_name: str | None = None, sanitize_gene_names: bool = True ) -> ad.AnnData: """Sanitize gene names (optional) and set the index of the .var DataFrame. @@ -29,7 +29,7 @@ def set_var_index( Annotated data object var_name : str | None Name of the column in `adata.var` that contains the gene names, if None, the existing index will be sanitized but not replaced. - sanitise_gene_names : bool + sanitize_gene_names : bool Whether to sanitize gene names by removing version numbers. Returns @@ -39,7 +39,7 @@ def set_var_index( """ gene_names = adata.var[var_name] if var_name else adata.var.index - if sanitise_gene_names: + if sanitize_gene_names: gene_names = sanitize_gene_names(gene_names) adata.var.index = gene_names From b2879ff79c548cbe457917d9d44a74675f8941ef Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 10 Oct 2025 08:31:24 +0200 Subject: [PATCH 05/25] update naming --- src/utils/set_var_index.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 822b0bc8269..8eca0214498 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -2,7 +2,7 @@ import re -def sanitize_gene_names(gene_names: list[str]) -> list[str]: +def strip_version_number(gene_names: list[str]) -> list[str]: """Sanitize gene names by removing version numbers. Parameters @@ -40,7 +40,7 @@ def set_var_index( gene_names = adata.var[var_name] if var_name else adata.var.index if sanitize_gene_names: - gene_names = sanitize_gene_names(gene_names) + gene_names = strip_version_number(gene_names) adata.var.index = gene_names From 747208a435f39a01235a572a4172dfdb8f6a3a3e Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 13:45:33 +0200 Subject: [PATCH 06/25] update descriptions component --- src/annotate/celltypist/config.vsh.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index fe5a8ed3460..8455d034d32 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -26,7 +26,7 @@ argument_groups: required: false - name: "--input_layer" type: string - description: The layer in the input data containing raw counts to be used for cell type annotation if .X is not to be used. + description: The layer in the input data containing counts that are lognormalized to 10000, .X is not to be used. - name: "--input_var_gene_names" type: string required: false @@ -50,7 +50,7 @@ argument_groups: required: false - name: "--reference_layer" type: string - description: The layer in the reference data containing raw counts to be used for cell type annotation if .X is not to be used. + description: The layer in the reference data containing counts that are lognormalized to 10000, if .X is not to be used. required: false - name: "--reference_obs_target" type: string From dbe3a51138b741fc721ddbdad8c9fb2bc7a4ad19 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 13:46:17 +0200 Subject: [PATCH 07/25] update changelog --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44ae42f8da5..d7124cc5e76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,6 @@ * `integrate/scarches` and `workflows/annotate/scanvi_scarches`: Enable correction for technical variability by multiple continuous and categorical covariates. - ## BUG FIXES * `filter/filter_with_counts`: this component would sometimes crash (segfault) when processing malformatted sparse matrices. A proper error message is now provided in this case (PR #1086). From 8f3ca12187cad63cf734005499f068f05448ab20 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 13:47:45 +0200 Subject: [PATCH 08/25] update changelo --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03a6b35fee4..30af778d5b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,10 @@ * Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Make feature name sanitation optional (PR #1084). +## BUG FIXES + +* `filter/filter_with_counts`: this component would sometimes crash (segfault) when processing malformatted sparse matrices. A proper error message is now provided in this case (PR #1086). + # openpipelines 3.0.0 ## BREAKING CHANGES From 791a418f4c80f816c83e9d0ec967de3b9a12ef0a Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 13:58:43 +0200 Subject: [PATCH 09/25] update changelog --- CHANGELOG.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7124cc5e76..f8a47e7e576 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,3 @@ -# openpipelines x.x.x - -## BREAKING - -* `annotate/celltypist`: This component now requires to pass a raw count layer, that will be lognormalized with a target sum of 10000, the required count format for CellTypist (PR #1083). - # openpipelines 3.1.0 ## NEW FUNCTIONALITY @@ -14,6 +8,10 @@ * `convert/from_seurat_to_h5mu`: Converts a Seurat object to a MuData object (PR #1078, #1079, #1082). +* `annotate/celltypist`: enable CUDA acceleration (PR #1083). + +* `workflows/annotation/celltypist`: Performs lognormalization (target count of 10000) followed by cell type annotation using CellTypist (PR #1083). + ## EXPERIMENTAL * `differential_expression/deseq2`: Performs differential expression analysis using DESeq2 on bulk or pseudobulk datasets (PR #1044). From 21eb3e2eaf260af969a36453dc81795e1162e9c8 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 13:59:08 +0200 Subject: [PATCH 10/25] update changelog --- src/annotate/celltypist/script.py | 92 ++++++++++++++----------------- 1 file changed, 41 insertions(+), 51 deletions(-) diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index f9f61d59b9f..e3efb749d3d 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -3,20 +3,23 @@ import mudata as mu import anndata as ad import pandas as pd -import scanpy as sc +import numpy as np ## VIASH START par = { "input": "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu", "output": "output.h5mu", "modality": "rna", + # "reference": None, "reference": "resources_test/annotation_test_data/TS_Blood_filtered.h5mu", "model": None, + # "model": "resources_test/annotation_test_data/celltypist_model_Immune_All_Low.pkl", "input_layer": "log_normalized", "reference_layer": "log_normalized", "input_reference_gene_overlap": 100, "reference_obs_target": "cell_ontology_class", "reference_var_input": None, + "check_expression": False, "feature_selection": True, "majority_voting": True, "output_compression": "gzip", @@ -41,43 +44,10 @@ logger = setup_logger() -def setup_anndata( - adata: ad.AnnData, - layer: str | None = None, - gene_names: str | None = None, - var_input: str | None = None, -) -> ad.AnnData: - """Creates an AnnData object in the expected format for CellTypist, - with lognormalized data (with a target sum of 10000) in the .X slot. - - Parameters - ---------- - adata - AnnData object. - layer - Layer in AnnData object to lognormalize. - gene_names - .obs field with the gene names to be used - var_input - .var field with a boolean array of the genes to be used (e.g. highly variable genes) - Returns - ------- - AnnData object in CellTypist format. - """ - - adata = set_var_index(adata, gene_names) - - if var_input: - adata = subset_vars(adata, var_input) - - raw_counts = adata.layers[layer].copy() if layer else adata.X.copy() - - input_modality = ad.AnnData(X=raw_counts, var=pd.DataFrame(index=adata.var.index)) - - sc.pp.normalize_total(input_modality, target_sum=10000) - sc.pp.log1p(input_modality) - - return input_modality +def check_celltypist_format(indata): + if np.abs(np.expm1(indata[0]).sum() - 10000) > 1: + return False + return True def main(par): @@ -93,8 +63,17 @@ def main(par): input_modality = input_adata.copy() # Provide correct format of query data for celltypist annotation - input_modality = setup_anndata( - input_modality, par["input_layer"], par["input_var_gene_names"] + ## Sanitize gene names and set as index + input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + ## Fetch lognormalized counts + lognorm_counts = ( + input_modality.layers[par["input_layer"]].copy() + if par["input_layer"] + else input_modality.X.copy() + ) + ## Create AnnData object + input_modality = ad.AnnData( + X=lognorm_counts, var=pd.DataFrame(index=input_modality.var.index) ) if par["model"]: @@ -107,15 +86,18 @@ def main(par): ) elif par["reference"]: - reference_adata = mu.read_h5mu(par["reference"]).mod[par["modality"]] - reference_modality = reference_adata.copy() - - # Provide correct format of query data for celltypist annotation - reference_modality = setup_anndata( - reference_modality, - par["reference_layer"], - par["reference_var_gene_names"], - par["reference_var_input"], + reference_modality = mu.read_h5mu(par["reference"]).mod[par["modality"]] + + # subset to HVG if required + if par["reference_var_input"]: + reference_modality = subset_vars( + reference_modality, par["reference_var_input"] + ) + + # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) + # CellTypist requires query gene names to be in index + reference_modality = set_var_index( + reference_modality, par["reference_var_gene_names"] ) # Ensure enough overlap between genes in query and reference @@ -125,10 +107,18 @@ def main(par): min_gene_overlap=par["input_reference_gene_overlap"], ) + reference_matrix = ( + reference_modality.layers[par["reference_layer"]] + if par["reference_layer"] + else reference_modality.X + ) + + labels = reference_modality.obs[par["reference_obs_target"]] + logger.info("Training CellTypist model on reference") model = celltypist.train( - reference_modality.X, - labels=reference_adata.obs[par["reference_obs_target"]], + reference_matrix, + labels=labels, genes=reference_modality.var.index, C=par["C"], max_iter=par["max_iter"], From 917a394242fbdd3db954b2368e9a6f5a37add4f7 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 14:04:54 +0200 Subject: [PATCH 11/25] raise when duplicating gene names --- src/utils/set_var_index.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 8eca0214498..9b551641657 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -40,7 +40,16 @@ def set_var_index( gene_names = adata.var[var_name] if var_name else adata.var.index if sanitize_gene_names: + ori_gene_names = len(gene_names) gene_names = strip_version_number(gene_names) + sanitized_gene_names = len(set(gene_names)) + + assert ori_gene_names == sanitized_gene_names, ( + "Sanitizing gene names resulted in duplicated gene names.\n" + "Please ensure unique gene names before proceeding.\n" + "Please make sure --var_gene_names contains ensembl IDs (not gene symbols) " + "when --sanitize_gene_names is set to True." + ) adata.var.index = gene_names From 7a062cde88ee34b704f8c0001e080b486ced797b Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 14:11:58 +0200 Subject: [PATCH 12/25] update changelog --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8a47e7e576..97ccc220ef9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,6 @@ * `convert/from_seurat_to_h5mu`: Converts a Seurat object to a MuData object (PR #1078, #1079, #1082). -* `annotate/celltypist`: enable CUDA acceleration (PR #1083). - * `workflows/annotation/celltypist`: Performs lognormalization (target count of 10000) followed by cell type annotation using CellTypist (PR #1083). ## EXPERIMENTAL From 737f771b9f07d86748c7e375c23217a328fd359a Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 14:23:49 +0200 Subject: [PATCH 13/25] undo test changes --- src/annotate/celltypist/test.py | 54 +++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/annotate/celltypist/test.py b/src/annotate/celltypist/test.py index 52670523a91..60704b3c24c 100644 --- a/src/annotate/celltypist/test.py +++ b/src/annotate/celltypist/test.py @@ -1,6 +1,8 @@ import sys import os import pytest +import subprocess +import re import mudata as mu from openpipeline_testutils.asserters import assert_annotation_objects_equal @@ -25,8 +27,12 @@ def test_simple_execution(run_component, random_h5mu_path): [ "--input", input_file, + "--input_layer", + "log_normalized", "--reference", reference_file, + "--reference_layer", + "log_normalized", "--reference_obs_target", "cell_ontology_class", "--reference_var_gene_names", @@ -69,8 +75,12 @@ def test_set_params(run_component, random_h5mu_path): [ "--input", input_file, + "--input_layer", + "log_normalized", "--reference", reference_file, + "--reference_layer", + "log_normalized", "--reference_obs_target", "cell_ontology_class", "--reference_var_gene_names", @@ -149,5 +159,49 @@ def test_with_model(run_component, random_h5mu_path): ) +def test_fail_invalid_input_expression(run_component, random_h5mu_path): + output_file = random_h5mu_path() + + # fails because input data are not lognormalized + with pytest.raises(subprocess.CalledProcessError) as err: + run_component( + [ + "--input", + input_file, + "--reference", + reference_file, + "--reference_var_gene_names", + "ensemblid", + "--output", + output_file, + ] + ) + assert re.search( + r"Invalid expression matrix, expect log1p normalized expression to 10000 counts per cell", + err.value.stdout.decode("utf-8"), + ) + + # fails because reference data are not lognormalized + with pytest.raises(subprocess.CalledProcessError) as err: + run_component( + [ + "--input", + input_file, + "--layer", + "log_normalized", + "--reference", + reference_file, + "--reference_var_gene_names", + "ensemblid", + "--output", + output_file, + ] + ) + assert re.search( + r"Invalid expression matrix, expect log1p normalized expression to 10000 counts per cell", + err.value.stdout.decode("utf-8"), + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__])) From dec15513cfda7b016e2bead0e98b4a3d5a1e5d8f Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 14:48:14 +0200 Subject: [PATCH 14/25] wip --- src/annotate/celltypist/config.vsh.yaml | 2 +- src/annotate/celltypist/script.py | 8 - .../annotation/celltypist/config.vsh.yaml | 150 ++++++++++ .../annotation/celltypist/integration_test.sh | 15 + src/workflows/annotation/celltypist/main.nf | 259 ++++++++++++++++++ .../annotation/celltypist/nextflow.config | 10 + src/workflows/annotation/celltypist/test.nf | 66 +++++ .../annotation/celltypist/config.vsh.yaml | 25 ++ .../annotation/celltypist/script.py | 48 ++++ 9 files changed, 574 insertions(+), 9 deletions(-) create mode 100644 src/workflows/annotation/celltypist/config.vsh.yaml create mode 100755 src/workflows/annotation/celltypist/integration_test.sh create mode 100644 src/workflows/annotation/celltypist/main.nf create mode 100644 src/workflows/annotation/celltypist/nextflow.config create mode 100644 src/workflows/annotation/celltypist/test.nf create mode 100644 src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml create mode 100644 src/workflows/test_workflows/annotation/celltypist/script.py diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index 8455d034d32..94acd7a591b 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -152,7 +152,7 @@ engines: packages: - celltypist==1.6.3 - type: python - __merge__: [ /src/base/requirements/anndata_mudata.yaml, /src/base/requirements/scanpy.yaml, .] + __merge__: [ /src/base/requirements/anndata_mudata.yaml, .] __merge__: [ /src/base/requirements/python_test_setup.yaml, .] runners: - type: executable diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index e3efb749d3d..c83debfb913 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -3,7 +3,6 @@ import mudata as mu import anndata as ad import pandas as pd -import numpy as np ## VIASH START par = { @@ -19,7 +18,6 @@ "input_reference_gene_overlap": 100, "reference_obs_target": "cell_ontology_class", "reference_var_input": None, - "check_expression": False, "feature_selection": True, "majority_voting": True, "output_compression": "gzip", @@ -44,12 +42,6 @@ logger = setup_logger() -def check_celltypist_format(indata): - if np.abs(np.expm1(indata[0]).sum() - 10000) > 1: - return False - return True - - def main(par): if (not par["model"] and not par["reference"]) or ( par["model"] and par["reference"] diff --git a/src/workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/annotation/celltypist/config.vsh.yaml new file mode 100644 index 00000000000..1fe4a0ef444 --- /dev/null +++ b/src/workflows/annotation/celltypist/config.vsh.yaml @@ -0,0 +1,150 @@ +name: "celltypist" +namespace: "workflows/annotation" +scope: "public" +description: "Cell type annotation workflow by performing lognormalization of the raw counts layer followed by cell type annotation with CellTypist." +info: + name: "CellTypist annotation" + test_dependencies: + - name: celltypist_test + namespace: test_workflows/annotation +authors: + - __merge__: /src/authors/dorien_roosen.yaml + roles: [ author, maintainer ] + - __merge__: /src/authors/weiwei_schultz.yaml + roles: [ contributor ] + +argument_groups: + - name: Inputs + description: Input dataset (query) arguments + arguments: + - name: "--input" + alternatives: [-i] + type: file + description: The input (query) data to be labeled. Should be a .h5mu file. + direction: input + required: true + example: input.h5mu + - name: "--modality" + description: Which modality to process. + type: string + default: "rna" + required: false + - name: "--input_layer" + type: string + description: The layer in the input data containing raw counts, if .X is not to be used. + - name: "--input_var_gene_names" + type: string + required: false + description: | + The name of the adata var column in the input data containing gene names; when no gene_name_layer is provided, the var index will be used. + - name: "--input_reference_gene_overlap" + type: integer + default: 100 + min: 1 + description: | + The minimum number of genes present in both the reference and query datasets. + + - name: Reference + description: Arguments related to the reference dataset. + arguments: + - name: "--reference" + type: file + description: "The reference data to train the CellTypist classifiers on. Only required if a pre-trained --model is not provided." + example: reference.h5mu + direction: input + required: false + - name: "--reference_layer" + type: string + description: The layer in the reference data containing raw counts, if .X is not to be used. + required: false + - name: "--reference_obs_target" + type: string + description: The name of the adata obs column in the reference data containing cell type annotations. + default: "cell_ontology_class" + - name: "--reference_var_gene_names" + type: string + required: false + description: | + The name of the adata var column in the reference data containing gene names; when no gene_name_layer is provided, the var index will be used. + - name: "--reference_var_input" + type: string + required: false + description: | + .var column containing highly variable genes. By default, do not subset genes. + + - name: Model arguments + description: Model arguments. + arguments: + - name: "--model" + type: file + description: "Pretrained model in pkl format. If not provided, the model will be trained on the reference data and --reference should be provided." + required: false + example: pretrained_model.pkl + - name: "--feature_selection" + type: boolean + description: "Whether to perform feature selection." + default: false + - name: "--majority_voting" + type: boolean + description: "Whether to refine the predicted labels by running the majority voting classifier after over-clustering." + default: false + - name: "--C" + type: double + description: "Inverse of regularization strength in logistic regression." + default: 1.0 + - name: "--max_iter" + type: integer + description: "Maximum number of iterations before reaching the minimum of the cost function." + default: 1000 + - name: "--use_SGD" + type: boolean_true + description: "Whether to use the stochastic gradient descent algorithm." + - name: "--min_prop" + type: double + description: | + "For the dominant cell type within a subcluster, the minimum proportion of cells required to + support naming of the subcluster by this cell type. Ignored if majority_voting is set to False. + Subcluster that fails to pass this proportion threshold will be assigned 'Heterogeneous'." + default: 0 + + - name: Outputs + description: Output arguments. + arguments: + - name: "--output" + type: file + description: Output h5mu file. + direction: output + example: output.h5mu + - name: "--output_obs_predictions" + type: string + default: celltypist_pred + required: false + description: | + In which `.obs` slots to store the predicted information. + - name: "--output_obs_probability" + type: string + default: celltypist_probability + required: false + description: | + In which `.obs` slots to store the probability of the predictions. + __merge__: [., /src/base/h5_compression_argument.yaml] + +dependencies: + - name: transform/normalize_total + - name: transform/log1p + - name: dataflow/merge + +resources: + - type: nextflow_script + path: main.nf + entrypoint: run_wf + +test_resources: + - type: nextflow_script + path: test.nf + entrypoint: test_wf + - path: /resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu + - path: /resources_test/annotation_test_data/TS_Blood_filtered.h5mu + +runners: + - type: nextflow diff --git a/src/workflows/annotation/celltypist/integration_test.sh b/src/workflows/annotation/celltypist/integration_test.sh new file mode 100755 index 00000000000..e8597bfef3e --- /dev/null +++ b/src/workflows/annotation/celltypist/integration_test.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# get the root of the directory +REPO_ROOT=$(git rev-parse --show-toplevel) + +# ensure that the command below is run from the root of the repository +cd "$REPO_ROOT" + +nextflow \ + run . \ + -main-script src/workflows/annotation/harmony_knn/test.nf \ + -entry test_wf \ + -profile docker,no_publish \ + -c src/workflows/utils/labels_ci.config \ + -c src/workflows/utils/integration_tests.config \ diff --git a/src/workflows/annotation/celltypist/main.nf b/src/workflows/annotation/celltypist/main.nf new file mode 100644 index 00000000000..935ab3a24cd --- /dev/null +++ b/src/workflows/annotation/celltypist/main.nf @@ -0,0 +1,259 @@ +workflow run_wf { + take: + input_ch + + main: + + modalities_ch = input_ch + // Set aside the output for this workflow to avoid conflicts + | map {id, state -> + def new_state = state + ["workflow_output": state.output] + [id, new_state] + } + // Align query and reference datasets + | align_query_reference.run( + fromState: [ + "input": "input", + "modality": "modality", + "input_layer": "input_layer", + "input_obs_batch": "input_obs_batch_label", + "input_var_gene_names": "input_var_gene_names", + "reference": "reference", + "reference_layer": "reference_layer", + "reference_obs_batch": "reference_obs_batch_label", + "reference_obs_label": "reference_obs_target", + "reference_var_gene_names": "reference_var_gene_names", + "input_reference_gene_overlap": "input_reference_gene_overlap", + "overwrite_existing_key": "overwrite_existing_key" + ], + args: [ + "input_id": "query", + "reference_id": "reference", + "output_layer": "_counts", + "output_var_gene_names": "_gene_names", + "output_obs_batch": "_sample_id", + "output_obs_label": "_cell_type", + "output_obs_id": "_dataset", + "output_var_common_genes": "_common_vars" + ], + toState: [ + "input": "output_query", + "reference": "output_reference" + ] + ) + + | split_modalities.run( + fromState: {id, state -> + def newState = ["input": state.input, "id": id] + }, + toState: ["output": "output", "output_types": "output_types"] + ) + | flatMap {id, state -> + def outputDir = state.output + def types = readCsv(state.output_types.toUriString()) + + types.collect{ dat -> + // def new_id = id + "_" + dat.name + def new_id = id // it's okay because the channel will get split up anyways + def new_data = outputDir.resolve(dat.filename) + [ new_id, state + ["input": new_data, modality: dat.name]] + } + } + // Remove arguments from split modalities from state + | map {id, state -> + def keysToRemove = ["output_types", "output_files"] + def newState = state.findAll{it.key !in keysToRemove} + [id, newState] + } + | view {"After splitting modalities: $it"} + + + rna_ch = modalities_ch + + | filter{id, state -> state.modality == "rna"} + + // Concatenate query and reference datasets prior to integration + // Only concatenate rna modality in this channel + | concatenate_h5mu.run( + fromState: { id, state -> [ + "input": [state.input, state.reference] + ] + }, + args: [ + "input_id": ["query", "reference"], + "modality": "rna", + "other_axis_mode": "move" + ], + toState: ["input": "output"] + ) + | view {"After concatenation: $it"} + | highly_variable_features_scanpy.run( + fromState: [ + "input": "input", + "modality": "modality", + "n_top_features": "n_hvg" + ], + args: [ + "layer": "_counts", + "var_input": "_common_vars", + "var_name_filter": "_common_hvg", + "obs_batch_key": "_sample_id" + ], + toState: [ + "input": "output" + ] + ) + | pca.run( + fromState: [ + "input": "input", + "modality": "modality", + "overwrite": "overwrite_existing_key", + "num_compontents": "pca_num_components" + ], + args: [ + "layer": "_counts", + "var_input": "_common_hvg", + "obsm_output": "X_pca_query_reference", + "varm_output": "pca_loadings_query_reference", + "uns_output": "pca_variance_query_reference", + ], + toState: [ + "input": "output" + ] + ) + | delete_layer.run( + key: "delete_aligned_lognormalized_counts_layer", + fromState: [ + "input": "input", + "modality": "modality", + ], + args: [ + "layer": "_counts", + "missing_ok": "true" + ], + toState: [ + "input": "output" + ] + ) + // Run harmony integration with leiden clustering + | harmony_leiden_workflow.run( + fromState: { id, state -> [ + "id": id, + "input": state.input, + "modality": state.modality, + "obsm_integrated": state.output_obsm_integrated, + "theta": state.harmony_theta, + "leiden_resolution": state.leiden_resolution, + ]}, + args: [ + "embedding": "X_pca_query_reference", + "uns_neighbors": "harmonypy_integration_neighbors", + "obsp_neighbor_distances": "harmonypy_integration_distances", + "obsp_neighbor_connectivities": "harmonypy_integration_connectivities", + "obs_cluster": "harmony_integration_leiden", + "obsm_umap": "X_leiden_harmony_umap", + "obs_covariates": "_sample_id" + ], + toState: ["input": "output"] + ) + | view {"After integration: $it"} + // Split integrated dataset back into a separate reference and query dataset + | split_h5mu.run( + fromState: [ + "input": "input", + "modality": "modality" + ], + args: [ + "obs_feature": "_dataset", + "output_files": "sample_files.csv", + "drop_obs_nan": "true", + "output": "ref_query" + ], + toState: [ + "output": "output", + "output_files": "output_files" + ], + auto: [ publish: true ] + ) + | view {"After sample splitting: $it"} + // map the integrated query and reference datasets back to the state + | map {id, state -> + def outputDir = state.output + if (workflow.stubRun) { + def output_files = outputDir.listFiles() + def new_state = state + [ + "input": output_files[0], + "reference": output_files[1], + ] + return [id, new_state] + } + def files = readCsv(state.output_files.toUriString()) + def query_file = files.findAll{ dat -> dat.name == 'query' } + assert query_file.size() == 1, 'there should only be one query file' + def reference_file = files.findAll{ dat -> dat.name == 'reference' } + assert reference_file.size() == 1, 'there should only be one reference file' + def integrated_query = outputDir.resolve(query_file.filename) + def integrated_reference = outputDir.resolve(reference_file.filename) + def newKeys = ["input": integrated_query, "reference": integrated_reference] + [id, state + newKeys] + } + // remove keys from split files + | map {id, state -> + def keysToRemove = ["output_files"] + def newState = state.findAll{it.key !in keysToRemove} + [id, newState] + } + // Perform KNN label transfer from integrated reference to integrated query + | knn.run( + fromState: [ + "input": "input", + "modality": "modality", + "input_obsm_features": "output_obsm_integrated", + "reference": "reference", + "reference_obsm_features": "output_obsm_integrated", + "reference_obs_targets": "reference_obs_target", + "output_obs_predictions": "output_obs_predictions", + "output_obs_probability": "output_obs_probability", + "output_compression": "output_compression", + "weights": "knn_weights", + "n_neighbors": "knn_n_neighbors", + "output": "workflow_output" + ], + toState: ["input": "output"] + // toState: {id, output, state -> ["output": output.output]} + ) + | view {"After processing RNA modality: $it"} + + other_mod_ch = modalities_ch + | filter{id, state -> state.modality != "rna"} + + output_ch = rna_ch.mix(other_mod_ch) + | groupTuple(by: 0, sort: "hash") + | map { id, states -> + def new_input = states.collect{it.input} + def modalities = states.collect{it.modality}.unique() + def other_state_keys = states.inject([].toSet()){ current_keys, state -> + def new_keys = current_keys + state.keySet() + return new_keys + }.minus(["output", "input", "modality", "reference"]) + def new_state = other_state_keys.inject([:]){ old_state, argument_name -> + argument_values = states.collect{it.get(argument_name)}.unique() + assert argument_values.size() == 1, "Arguments should be the same across modalities. Please report this \ + as a bug. Argument name: $argument_name, \ + argument value: $argument_values" + def argument_value + argument_values.each { argument_value = it } + def current_state = old_state + [(argument_name): argument_value] + return current_state + } + [id, new_state + ["input": new_input, "modalities": modalities]] + } + | merge.run( + fromState: ["input": "input"], + toState: ["output": "output"], + ) + | setState(["output"]) + + emit: + output_ch +} diff --git a/src/workflows/annotation/celltypist/nextflow.config b/src/workflows/annotation/celltypist/nextflow.config new file mode 100644 index 00000000000..059100c489c --- /dev/null +++ b/src/workflows/annotation/celltypist/nextflow.config @@ -0,0 +1,10 @@ +manifest { + nextflowVersion = '!>=20.12.1-edge' +} + +params { + rootDir = java.nio.file.Paths.get("$projectDir/../../../../").toAbsolutePath().normalize().toString() +} + +// include common settings +includeConfig("${params.rootDir}/src/workflows/utils/labels.config") diff --git a/src/workflows/annotation/celltypist/test.nf b/src/workflows/annotation/celltypist/test.nf new file mode 100644 index 00000000000..cbe7833fb20 --- /dev/null +++ b/src/workflows/annotation/celltypist/test.nf @@ -0,0 +1,66 @@ +nextflow.enable.dsl=2 + +include { harmony_knn } from params.rootDir + "/target/nextflow/workflows/annotation/harmony_knn/main.nf" +include { harmony_knn_test } from params.rootDir + "/target/_test/nextflow/test_workflows/annotation/harmony_knn_test/main.nf" +params.resources_test = params.rootDir + "/resources_test" + +workflow test_wf { + // allow changing the resources_test dir + resources_test = file(params.resources_test) + + output_ch = Channel.fromList( + [ + [ + id: "simple_execution_test", + input: resources_test.resolve("pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu"), + input_layer: "log_normalized", + reference: resources_test.resolve("annotation_test_data/TS_Blood_filtered.h5mu"), + reference_var_gene_names: "ensemblid", + input_obs_batch_label: "sample_id", + reference_layer: "log_normalized", + reference_obs_batch_label: "donor_assay", + reference_obs_target: "cell_type", + leiden_resolution: [1.0, 0.25] + ], + [ + id: "no_leiden_resolutions_test", + input: resources_test.resolve("pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu"), + input_layer: "log_normalized", + reference: resources_test.resolve("annotation_test_data/TS_Blood_filtered.h5mu"), + reference_var_gene_names: "ensemblid", + input_obs_batch_label: "sample_id", + reference_layer: "log_normalized", + reference_obs_batch_label: "donor_assay", + reference_obs_target: "cell_type", + leiden_resolution: [] + ] + ]) + | map{ state -> [state.id, state] } + | harmony_knn + | view { output -> + assert output.size() == 2 : "Outputs should contain two elements; [id, state]" + + // check id + def id = output[0] + assert id.endsWith("_test") : "Output ID should be same as input ID" + + // check output + def state = output[1] + assert state instanceof Map : "State should be a map. Found: ${state}" + assert state.containsKey("output") : "Output should contain key 'output'." + assert state.output.isFile() : "'output' should be a file." + assert state.output.toString().endsWith(".h5mu") : "Output file should end with '.h5mu'. Found: ${state.output}" + + "Output: $output" + } + | harmony_knn_test.run( + fromState: [ + "input": "output" + ] + ) + | toSortedList({a, b -> a[0] <=> b[0]}) + | map { output_list -> + assert output_list.size() == 2 : "output channel should contain 2 events" + assert output_list.collect{it[0]} == ["no_leiden_resolutions_test", "simple_execution_test"] + } + } \ No newline at end of file diff --git a/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml new file mode 100644 index 00000000000..dbd3a464dbf --- /dev/null +++ b/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml @@ -0,0 +1,25 @@ +name: "celltypist_test" +namespace: "test_workflows/annotation" +scope: "test" +description: "This component tests the output of the annotation of the celltypist workflow." +authors: + - __merge__: /src/authors/dorien_roosen.yaml +argument_groups: + - name: Inputs + arguments: + - name: "--input" + type: file + required: true + description: Path to h5mu output. + example: foo.final.h5mu +resources: + - type: python_script + path: script.py + - path: /src/utils/setup_logger.py +engines: + - type: docker + image: python:3.12-slim + __merge__: /src/base/requirements/testworkflows_setup.yaml +runners: + - type: executable + - type: nextflow \ No newline at end of file diff --git a/src/workflows/test_workflows/annotation/celltypist/script.py b/src/workflows/test_workflows/annotation/celltypist/script.py new file mode 100644 index 00000000000..2e1a71b2613 --- /dev/null +++ b/src/workflows/test_workflows/annotation/celltypist/script.py @@ -0,0 +1,48 @@ +from mudata import read_h5mu +import sys +import pytest + +##VIASH START +par = {"input": "harmony_knn/output.h5mu"} + +meta = {"resources_dir": "resources_test"} +##VIASH END + + +def test_run(): + input_mudata = read_h5mu(par["input"]) + expected_obsm = ["X_integrated_harmony", "X_leiden_harmony_umap"] + expected_obs = ["cell_type_pred", "cell_type_probability"] + expected_obsp = [ + "harmonypy_integration_distances", + "harmonypy_integration_connectivities", + ] + expected_mod = ["rna", "prot"] + + assert all(key in list(input_mudata.mod) for key in expected_mod), ( + f"Input modalities should be: {expected_mod}, found: {input_mudata.mod.keys()}." + ) + assert all(key in list(input_mudata.mod["rna"].obsm) for key in expected_obsm), ( + f"Input mod['rna'] obsm columns should be: {expected_obsm}, found: {input_mudata.mod['rna'].obsm.keys()}." + ) + assert all(key in list(input_mudata.mod["rna"].obs) for key in expected_obs), ( + f"Input mod['rna'] obs columns should be: {expected_obs}, found: {input_mudata.mod['rna'].obs.keys()}." + ) + assert all(key in list(input_mudata.mod["rna"].obsp) for key in expected_obsp), ( + f"Input mod['rna'] obsp columns should be: {expected_obsp}, found: {input_mudata.mod['rna'].obsp.keys()}." + ) + + assert input_mudata.mod["rna"].obs["cell_type_pred"].dtype == "category", ( + "Cell type predictions should be of dtype category." + ) + assert input_mudata.mod["rna"].obs["cell_type_probability"].dtype == "float64", ( + "Cell type probabilities should be of dtype float64." + ) + + assert input_mudata.mod["rna"].shape[0] == input_mudata.mod["prot"].shape[0], ( + "Number of observations should be equal in all modalities." + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "--import-mode=importlib"])) From 9a01efe99f00d72d030945ac83bb87c137651336 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 14:56:41 +0200 Subject: [PATCH 15/25] only sanitize for ensemble ids --- src/utils/set_var_index.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 9b551641657..3fbdeaedf51 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -3,19 +3,27 @@ def strip_version_number(gene_names: list[str]) -> list[str]: - """Sanitize gene names by removing version numbers. + """Sanitize ensemble ID's by removing version numbers. Parameters ---------- gene_names : list[str] - List of gene names to sanitize. + List of ensemble ID's to sanitize. Returns ------- list[str] - List of sanitized gene names. + List of sanitized ensemble ID's. """ - return [re.sub("\\.[0-9]+$", "", s) for s in gene_names] + + # Pattern matches Ensembl IDs: starts with ENS, followed by any characters, + # then an eleven digit number, optionally followed by .version_number + ensembl_pattern = re.compile(r"^(ENS.*\d{11})(?:\.\d+)?$") + + return [ + match.group(1) if (match := ensembl_pattern.match(gene)) else gene + for gene in gene_names + ] def set_var_index( From 2a11b496fe31e8d8d439e094fe87fd994527d913 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 16:41:15 +0200 Subject: [PATCH 16/25] create celltypist workflow --- .../annotation/celltypist/config.vsh.yaml | 6 +- .../annotation/celltypist/integration_test.sh | 10 +- src/workflows/annotation/celltypist/main.nf | 342 ++++++------------ src/workflows/annotation/celltypist/test.nf | 74 ++-- .../annotation/celltypist/config.vsh.yaml | 6 + .../annotation/celltypist/script.py | 27 +- 6 files changed, 182 insertions(+), 283 deletions(-) diff --git a/src/workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/annotation/celltypist/config.vsh.yaml index 1fe4a0ef444..ab9c266f6dc 100644 --- a/src/workflows/annotation/celltypist/config.vsh.yaml +++ b/src/workflows/annotation/celltypist/config.vsh.yaml @@ -132,7 +132,9 @@ argument_groups: dependencies: - name: transform/normalize_total - name: transform/log1p - - name: dataflow/merge + - name: transform/delete_layer + - name: annotate/celltypist + alias: celltypist_component resources: - type: nextflow_script @@ -145,6 +147,8 @@ test_resources: entrypoint: test_wf - path: /resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu - path: /resources_test/annotation_test_data/TS_Blood_filtered.h5mu + - path: /resources_test/annotation_test_data/celltypist_model_Immune_All_Low.pkl + - path: /resources_test/annotation_test_data/demo_2000_cells.h5mu runners: - type: nextflow diff --git a/src/workflows/annotation/celltypist/integration_test.sh b/src/workflows/annotation/celltypist/integration_test.sh index e8597bfef3e..e5870cfc714 100755 --- a/src/workflows/annotation/celltypist/integration_test.sh +++ b/src/workflows/annotation/celltypist/integration_test.sh @@ -8,8 +8,16 @@ cd "$REPO_ROOT" nextflow \ run . \ - -main-script src/workflows/annotation/harmony_knn/test.nf \ + -main-script src/workflows/annotation/celltypist/test.nf \ -entry test_wf \ -profile docker,no_publish \ -c src/workflows/utils/labels_ci.config \ -c src/workflows/utils/integration_tests.config \ + +nextflow \ + run . \ + -main-script src/workflows/annotation/celltypist/test.nf \ + -entry test_wf_2 \ + -profile docker,no_publish \ + -c src/workflows/utils/labels_ci.config \ + -c src/workflows/utils/integration_tests.config \ diff --git a/src/workflows/annotation/celltypist/main.nf b/src/workflows/annotation/celltypist/main.nf index 935ab3a24cd..ab1080bf044 100644 --- a/src/workflows/annotation/celltypist/main.nf +++ b/src/workflows/annotation/celltypist/main.nf @@ -4,256 +4,122 @@ workflow run_wf { main: - modalities_ch = input_ch + output_ch = input_ch // Set aside the output for this workflow to avoid conflicts | map {id, state -> - def new_state = state + ["workflow_output": state.output] - [id, new_state] + def new_state = state + ["workflow_output": state.output] + [id, new_state] } - // Align query and reference datasets - | align_query_reference.run( - fromState: [ - "input": "input", - "modality": "modality", - "input_layer": "input_layer", - "input_obs_batch": "input_obs_batch_label", - "input_var_gene_names": "input_var_gene_names", - "reference": "reference", - "reference_layer": "reference_layer", - "reference_obs_batch": "reference_obs_batch_label", - "reference_obs_label": "reference_obs_target", - "reference_var_gene_names": "reference_var_gene_names", - "input_reference_gene_overlap": "input_reference_gene_overlap", - "overwrite_existing_key": "overwrite_existing_key" - ], - args: [ - "input_id": "query", - "reference_id": "reference", - "output_layer": "_counts", - "output_var_gene_names": "_gene_names", - "output_obs_batch": "_sample_id", - "output_obs_label": "_cell_type", - "output_obs_id": "_dataset", - "output_var_common_genes": "_common_vars" - ], - toState: [ - "input": "output_query", - "reference": "output_reference" - ] + // Log normalize query dataset to target sum of 10000 + | normalize_total.run( + fromState: { id, state -> [ + "input": state.input, + "modality": state.modality, + "input_layer": state.input_layer, + ]}, + args: [ + "output_layer": "normalized_10k", + "target_sum": "10000", + ], + toState: [ + "input": "output", + ] ) - | split_modalities.run( - fromState: {id, state -> - def newState = ["input": state.input, "id": id] - }, - toState: ["output": "output", "output_types": "output_types"] - ) - | flatMap {id, state -> - def outputDir = state.output - def types = readCsv(state.output_types.toUriString()) - - types.collect{ dat -> - // def new_id = id + "_" + dat.name - def new_id = id // it's okay because the channel will get split up anyways - def new_data = outputDir.resolve(dat.filename) - [ new_id, state + ["input": new_data, modality: dat.name]] - } - } - // Remove arguments from split modalities from state - | map {id, state -> - def keysToRemove = ["output_types", "output_files"] - def newState = state.findAll{it.key !in keysToRemove} - [id, newState] - } - | view {"After splitting modalities: $it"} - - - rna_ch = modalities_ch - - | filter{id, state -> state.modality == "rna"} - - // Concatenate query and reference datasets prior to integration - // Only concatenate rna modality in this channel - | concatenate_h5mu.run( - fromState: { id, state -> [ - "input": [state.input, state.reference] - ] - }, - args: [ - "input_id": ["query", "reference"], - "modality": "rna", - "other_axis_mode": "move" - ], - toState: ["input": "output"] - ) - | view {"After concatenation: $it"} - | highly_variable_features_scanpy.run( - fromState: [ - "input": "input", - "modality": "modality", - "n_top_features": "n_hvg" - ], - args: [ - "layer": "_counts", - "var_input": "_common_vars", - "var_name_filter": "_common_hvg", - "obs_batch_key": "_sample_id" - ], - toState: [ - "input": "output" - ] - ) - | pca.run( - fromState: [ - "input": "input", - "modality": "modality", - "overwrite": "overwrite_existing_key", - "num_compontents": "pca_num_components" - ], - args: [ - "layer": "_counts", - "var_input": "_common_hvg", - "obsm_output": "X_pca_query_reference", - "varm_output": "pca_loadings_query_reference", - "uns_output": "pca_variance_query_reference", - ], - toState: [ - "input": "output" - ] + | log1p.run( + fromState: { id, state -> [ + "input": state.input, + "modality": state.modality + ]}, + args: [ + "input_layer": "normalized_10k", + "output_layer": "log_normalized_10k", + ], + toState: [ + "input": "output" + ] ) | delete_layer.run( - key: "delete_aligned_lognormalized_counts_layer", - fromState: [ - "input": "input", - "modality": "modality", - ], - args: [ - "layer": "_counts", - "missing_ok": "true" - ], - toState: [ - "input": "output" - ] + fromState: { id, state -> [ + "input": state.input, + "modality": state.modality + ]}, + args: [ + "layer": "normalized_10k" + ], + toState: [ + "input": "output" + ] ) - // Run harmony integration with leiden clustering - | harmony_leiden_workflow.run( - fromState: { id, state -> [ - "id": id, - "input": state.input, - "modality": state.modality, - "obsm_integrated": state.output_obsm_integrated, - "theta": state.harmony_theta, - "leiden_resolution": state.leiden_resolution, - ]}, - args: [ - "embedding": "X_pca_query_reference", - "uns_neighbors": "harmonypy_integration_neighbors", - "obsp_neighbor_distances": "harmonypy_integration_distances", - "obsp_neighbor_connectivities": "harmonypy_integration_connectivities", - "obs_cluster": "harmony_integration_leiden", - "obsm_umap": "X_leiden_harmony_umap", - "obs_covariates": "_sample_id" - ], - toState: ["input": "output"] + // Log normalize reference dataset to target sum of 10000 + | normalize_total.run( + key: "normalize_total_reference", + runIf: { id, state -> + state.reference + }, + fromState: { id, state -> [ + "input": state.reference, + "modality": state.modality, + "input_layer": state.reference_layer, + ]}, + args: [ + "output_layer": "normalized_10k", + "target_sum": "10000", + ], + toState: [ + "reference": "output", + ] ) - | view {"After integration: $it"} - // Split integrated dataset back into a separate reference and query dataset - | split_h5mu.run( - fromState: [ - "input": "input", - "modality": "modality" - ], - args: [ - "obs_feature": "_dataset", - "output_files": "sample_files.csv", - "drop_obs_nan": "true", - "output": "ref_query" - ], - toState: [ - "output": "output", - "output_files": "output_files" - ], - auto: [ publish: true ] + | log1p.run( + key: "log1p_reference", + runIf: { id, state -> + state.reference + }, + fromState: { id, state -> [ + "input": state.reference, + "modality": state.modality + ]}, + args: [ + "input_layer": "normalized_10k", + "output_layer": "log_normalized_10k", + ], + toState: [ + "reference": "output" + ] ) - | view {"After sample splitting: $it"} - // map the integrated query and reference datasets back to the state - | map {id, state -> - def outputDir = state.output - if (workflow.stubRun) { - def output_files = outputDir.listFiles() - def new_state = state + [ - "input": output_files[0], - "reference": output_files[1], - ] - return [id, new_state] - } - def files = readCsv(state.output_files.toUriString()) - def query_file = files.findAll{ dat -> dat.name == 'query' } - assert query_file.size() == 1, 'there should only be one query file' - def reference_file = files.findAll{ dat -> dat.name == 'reference' } - assert reference_file.size() == 1, 'there should only be one reference file' - def integrated_query = outputDir.resolve(query_file.filename) - def integrated_reference = outputDir.resolve(reference_file.filename) - def newKeys = ["input": integrated_query, "reference": integrated_reference] - [id, state + newKeys] - } - // remove keys from split files - | map {id, state -> - def keysToRemove = ["output_files"] - def newState = state.findAll{it.key !in keysToRemove} - [id, newState] - } - // Perform KNN label transfer from integrated reference to integrated query - | knn.run( - fromState: [ - "input": "input", - "modality": "modality", - "input_obsm_features": "output_obsm_integrated", - "reference": "reference", - "reference_obsm_features": "output_obsm_integrated", - "reference_obs_targets": "reference_obs_target", - "output_obs_predictions": "output_obs_predictions", - "output_obs_probability": "output_obs_probability", - "output_compression": "output_compression", - "weights": "knn_weights", - "n_neighbors": "knn_n_neighbors", - "output": "workflow_output" - ], - toState: ["input": "output"] - // toState: {id, output, state -> ["output": output.output]} + // Run harmony integration with leiden clustering + | celltypist_component.run( + fromState: { id, state -> [ + "input": state.input, + "modality": state.modality, + "input_var_gene_names": state.input_var_gene_names, + "input_reference_gene_overlap": state.input_reference_gene_overlap, + "reference": state.reference, + "reference_obs_target": state.reference_obs_target, + "reference_var_gene_names": state.reference_var_gene_names, + "reference_var_input": state.reference_var_input, + "model": state.model, + "feature_selection": state.feature_selection, + "majority_voting": state.majority_voting, + "C": state.C, + "max_iter": state.max_iter, + "use_SGD": state.use_SGD, + "min_prop": state.min_prop, + "output": state.workflow_output, + "output_obs_predictions": state.output_obs_predictions, + "output_obs_probability": state.output_obs_probability + ]}, + args: [ + "input_layer": "log_normalized_10k", + "reference_layer": "log_normalized_10k" + ], + toState: [ + "output": "output" + ] ) - | view {"After processing RNA modality: $it"} - - other_mod_ch = modalities_ch - | filter{id, state -> state.modality != "rna"} + | view {"After annotation: $it"} + | setState(["output"]) - output_ch = rna_ch.mix(other_mod_ch) - | groupTuple(by: 0, sort: "hash") - | map { id, states -> - def new_input = states.collect{it.input} - def modalities = states.collect{it.modality}.unique() - def other_state_keys = states.inject([].toSet()){ current_keys, state -> - def new_keys = current_keys + state.keySet() - return new_keys - }.minus(["output", "input", "modality", "reference"]) - def new_state = other_state_keys.inject([:]){ old_state, argument_name -> - argument_values = states.collect{it.get(argument_name)}.unique() - assert argument_values.size() == 1, "Arguments should be the same across modalities. Please report this \ - as a bug. Argument name: $argument_name, \ - argument value: $argument_values" - def argument_value - argument_values.each { argument_value = it } - def current_state = old_state + [(argument_name): argument_value] - return current_state - } - [id, new_state + ["input": new_input, "modalities": modalities]] - } - | merge.run( - fromState: ["input": "input"], - toState: ["output": "output"], - ) - | setState(["output"]) - emit: output_ch } diff --git a/src/workflows/annotation/celltypist/test.nf b/src/workflows/annotation/celltypist/test.nf index cbe7833fb20..3ff541fe65f 100644 --- a/src/workflows/annotation/celltypist/test.nf +++ b/src/workflows/annotation/celltypist/test.nf @@ -1,7 +1,7 @@ nextflow.enable.dsl=2 -include { harmony_knn } from params.rootDir + "/target/nextflow/workflows/annotation/harmony_knn/main.nf" -include { harmony_knn_test } from params.rootDir + "/target/_test/nextflow/test_workflows/annotation/harmony_knn_test/main.nf" +include { celltypist } from params.rootDir + "/target/nextflow/workflows/annotation/celltypist/main.nf" +include { celltypist_test } from params.rootDir + "/target/_test/nextflow/test_workflows/annotation/celltypist_test/main.nf" params.resources_test = params.rootDir + "/resources_test" workflow test_wf { @@ -11,32 +11,17 @@ workflow test_wf { output_ch = Channel.fromList( [ [ - id: "simple_execution_test", + id: "reference_dataset_test", input: resources_test.resolve("pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu"), - input_layer: "log_normalized", reference: resources_test.resolve("annotation_test_data/TS_Blood_filtered.h5mu"), reference_var_gene_names: "ensemblid", input_obs_batch_label: "sample_id", - reference_layer: "log_normalized", reference_obs_batch_label: "donor_assay", reference_obs_target: "cell_type", - leiden_resolution: [1.0, 0.25] - ], - [ - id: "no_leiden_resolutions_test", - input: resources_test.resolve("pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu"), - input_layer: "log_normalized", - reference: resources_test.resolve("annotation_test_data/TS_Blood_filtered.h5mu"), - reference_var_gene_names: "ensemblid", - input_obs_batch_label: "sample_id", - reference_layer: "log_normalized", - reference_obs_batch_label: "donor_assay", - reference_obs_target: "cell_type", - leiden_resolution: [] ] ]) | map{ state -> [state.id, state] } - | harmony_knn + | celltypist | view { output -> assert output.size() == 2 : "Outputs should contain two elements; [id, state]" @@ -53,14 +38,53 @@ workflow test_wf { "Output: $output" } - | harmony_knn_test.run( + | celltypist_test.run( fromState: [ "input": "output" + ], + args: [ + "expected_modalities": ["rna", "prot"] ] ) - | toSortedList({a, b -> a[0] <=> b[0]}) - | map { output_list -> - assert output_list.size() == 2 : "output channel should contain 2 events" - assert output_list.collect{it[0]} == ["no_leiden_resolutions_test", "simple_execution_test"] + } + +workflow test_wf_2 { + // allow changing the resources_test dir + resources_test = file(params.resources_test) + + output_ch = Channel.fromList( + [ + [ + id: "reference_model_test", + input: resources_test.resolve("annotation_test_data/demo_2000_cells.h5mu"), + model: resources_test.resolve("annotation_test_data/celltypist_model_Immune_All_Low.pkl"), + reference_obs_target: "cell_type", + ], + ]) + | map{ state -> [state.id, state] } + | celltypist + | view { output -> + assert output.size() == 2 : "Outputs should contain two elements; [id, state]" + + // check id + def id = output[0] + assert id.endsWith("_test") : "Output ID should be same as input ID" + + // check output + def state = output[1] + assert state instanceof Map : "State should be a map. Found: ${state}" + assert state.containsKey("output") : "Output should contain key 'output'." + assert state.output.isFile() : "'output' should be a file." + assert state.output.toString().endsWith(".h5mu") : "Output file should end with '.h5mu'. Found: ${state.output}" + + "Output: $output" } - } \ No newline at end of file + | celltypist_test.run( + fromState: [ + "input": "output" + ], + args: [ + "expected_modalities": ["rna"] + ] + ) + } diff --git a/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml index dbd3a464dbf..340ffaca47f 100644 --- a/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml +++ b/src/workflows/test_workflows/annotation/celltypist/config.vsh.yaml @@ -12,6 +12,12 @@ argument_groups: required: true description: Path to h5mu output. example: foo.final.h5mu + - name: "--expected_modalities" + type: string + multiple: true + required: true + description: List of expected modalities in the output h5mu. + example: ["rna", "prot"] resources: - type: python_script path: script.py diff --git a/src/workflows/test_workflows/annotation/celltypist/script.py b/src/workflows/test_workflows/annotation/celltypist/script.py index 2e1a71b2613..baaaa5a25b3 100644 --- a/src/workflows/test_workflows/annotation/celltypist/script.py +++ b/src/workflows/test_workflows/annotation/celltypist/script.py @@ -11,37 +11,28 @@ def test_run(): input_mudata = read_h5mu(par["input"]) - expected_obsm = ["X_integrated_harmony", "X_leiden_harmony_umap"] - expected_obs = ["cell_type_pred", "cell_type_probability"] - expected_obsp = [ - "harmonypy_integration_distances", - "harmonypy_integration_connectivities", - ] - expected_mod = ["rna", "prot"] + expected_obs = ["celltypist_pred", "celltypist_probability"] + expected_mod = par["expected_modalities"] assert all(key in list(input_mudata.mod) for key in expected_mod), ( f"Input modalities should be: {expected_mod}, found: {input_mudata.mod.keys()}." ) - assert all(key in list(input_mudata.mod["rna"].obsm) for key in expected_obsm), ( - f"Input mod['rna'] obsm columns should be: {expected_obsm}, found: {input_mudata.mod['rna'].obsm.keys()}." - ) assert all(key in list(input_mudata.mod["rna"].obs) for key in expected_obs), ( f"Input mod['rna'] obs columns should be: {expected_obs}, found: {input_mudata.mod['rna'].obs.keys()}." ) - assert all(key in list(input_mudata.mod["rna"].obsp) for key in expected_obsp), ( - f"Input mod['rna'] obsp columns should be: {expected_obsp}, found: {input_mudata.mod['rna'].obsp.keys()}." - ) - assert input_mudata.mod["rna"].obs["cell_type_pred"].dtype == "category", ( + assert input_mudata.mod["rna"].obs["celltypist_pred"].dtype == "category", ( "Cell type predictions should be of dtype category." ) - assert input_mudata.mod["rna"].obs["cell_type_probability"].dtype == "float64", ( + assert input_mudata.mod["rna"].obs["celltypist_probability"].dtype == "float64", ( "Cell type probabilities should be of dtype float64." ) - assert input_mudata.mod["rna"].shape[0] == input_mudata.mod["prot"].shape[0], ( - "Number of observations should be equal in all modalities." - ) + if len(expected_mod) == 2: + assert ( + input_mudata.mod[expected_mod[0]].shape[0] + == input_mudata.mod[expected_mod[1]].shape[0] + ), "Number of observations should be equal in all modalities." if __name__ == "__main__": From 14826d537c5c3b66ea2ee834c0712a70d0332204 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 17 Oct 2025 16:56:24 +0200 Subject: [PATCH 17/25] update onclass tests --- src/annotate/onclass/test.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/annotate/onclass/test.py b/src/annotate/onclass/test.py index 9212928df2d..905b6b7c6b5 100644 --- a/src/annotate/onclass/test.py +++ b/src/annotate/onclass/test.py @@ -29,12 +29,12 @@ def test_simple_execution(run_component, random_h5mu_path): [ "--input", input_file, - "--input_var_gene_names", - "gene_symbol", "--reference", reference_file, "--reference_obs_target", "cell_ontology_class", + "--reference_var_gene_names", + "ensemblid", "--cl_nlp_emb_file", cl_nlp_emb_file, "--cl_ontology_file", @@ -70,12 +70,12 @@ def test_custom_obs(run_component, random_h5mu_path): [ "--input", input_file, - "--input_var_gene_names", - "gene_symbol", "--reference", reference_file, "--reference_obs_target", "cell_ontology_class", + "--reference_var_gene_names", + "ensemblid", "--output_obs_predictions", "dummy_pred_1", "--output_obs_probability", @@ -116,8 +116,6 @@ def test_no_model_no_reference_error(run_component, random_h5mu_path): [ "--input", input_file, - "--input_var_gene_names", - "gene_symbol", "--output", output_file, "--cl_nlp_emb_file", @@ -128,6 +126,8 @@ def test_no_model_no_reference_error(run_component, random_h5mu_path): cl_obo_file, "--reference_obs_target", "cell_ontology_class", + "--reference_var_gene_names", + "ensemblid", ] ) assert re.search( @@ -145,6 +145,8 @@ def test_pretrained_model(run_component, random_h5mu_path): input_file, "--input_var_gene_names", "gene_symbol", + "--sanitize_gene_names", + "False", "--cl_nlp_emb_file", cl_nlp_emb_file, "--cl_ontology_file", From d3bed1904c4e33e35ac5b1c1182776da7fe93d28 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Tue, 21 Oct 2025 08:42:37 +0200 Subject: [PATCH 18/25] update celltypist workflow --- src/workflows/annotation/celltypist/config.vsh.yaml | 4 ++++ src/workflows/annotation/celltypist/main.nf | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/annotation/celltypist/config.vsh.yaml index ab9c266f6dc..90e4519d628 100644 --- a/src/workflows/annotation/celltypist/config.vsh.yaml +++ b/src/workflows/annotation/celltypist/config.vsh.yaml @@ -43,6 +43,10 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. + - name: "--sanitize_gene_names" + type: boolean + description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + default: true - name: Reference description: Arguments related to the reference dataset. diff --git a/src/workflows/annotation/celltypist/main.nf b/src/workflows/annotation/celltypist/main.nf index ab1080bf044..c810d79c3f6 100644 --- a/src/workflows/annotation/celltypist/main.nf +++ b/src/workflows/annotation/celltypist/main.nf @@ -107,7 +107,8 @@ workflow run_wf { "min_prop": state.min_prop, "output": state.workflow_output, "output_obs_predictions": state.output_obs_predictions, - "output_obs_probability": state.output_obs_probability + "output_obs_probability": state.output_obs_probability, + "sanitize_gene_names": state.sanitize_gene_names ]}, args: [ "input_layer": "log_normalized_10k", From e98dc1049622359a7ed0dd0e5715e68df2d6d864 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Tue, 28 Oct 2025 16:02:40 +0100 Subject: [PATCH 19/25] fixup --- src/annotate/celltypist/config.vsh.yaml | 5 ----- src/annotate/celltypist/script.py | 16 ++++++---------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index 3b32cbaad4a..0582331230e 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -65,11 +65,6 @@ argument_groups: required: false description: | The name of the adata var column in the reference data containing gene names; when no gene_name_layer is provided, the var index will be used. - - name: "--reference_var_input" - type: string - required: false - description: | - .var column containing highly variable genes. By default, do not subset genes. - name: Model arguments description: Model arguments. diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index c83debfb913..85a1ec3b4dd 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -17,7 +17,6 @@ "reference_layer": "log_normalized", "input_reference_gene_overlap": 100, "reference_obs_target": "cell_ontology_class", - "reference_var_input": None, "feature_selection": True, "majority_voting": True, "output_compression": "gzip", @@ -37,7 +36,6 @@ from setup_logger import setup_logger from cross_check_genes import cross_check_genes from set_var_index import set_var_index -from subset_vars import subset_vars logger = setup_logger() @@ -56,7 +54,9 @@ def main(par): # Provide correct format of query data for celltypist annotation ## Sanitize gene names and set as index - input_modality = set_var_index(input_modality, par["input_var_gene_names"]) + input_modality = set_var_index( + input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + ) ## Fetch lognormalized counts lognorm_counts = ( input_modality.layers[par["input_layer"]].copy() @@ -80,16 +80,12 @@ def main(par): elif par["reference"]: reference_modality = mu.read_h5mu(par["reference"]).mod[par["modality"]] - # subset to HVG if required - if par["reference_var_input"]: - reference_modality = subset_vars( - reference_modality, par["reference_var_input"] - ) - # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) # CellTypist requires query gene names to be in index reference_modality = set_var_index( - reference_modality, par["reference_var_gene_names"] + reference_modality, + par["reference_var_gene_names"], + par["sanitize_gene_names"], ) # Ensure enough overlap between genes in query and reference From 62d7c770557bd0ed65dbd2512bcbc53d8747ecb7 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 30 Oct 2025 15:55:21 +0100 Subject: [PATCH 20/25] update parameter name --- src/annotate/celltypist/config.vsh.yaml | 4 ++-- src/annotate/celltypist/script.py | 4 ++-- src/annotate/onclass/config.vsh.yaml | 4 ++-- src/annotate/onclass/script.py | 4 ++-- src/annotate/onclass/test.py | 2 +- .../random_forest_annotation/config.vsh.yaml | 4 ++-- .../random_forest_annotation/script.py | 4 ++-- src/annotate/scanvi/config.vsh.yaml | 4 ++-- src/annotate/scanvi/script.py | 2 +- src/annotate/singler/config.vsh.yaml | 5 ++++- src/annotate/singler/script.R | 21 ++++++++++++++----- src/annotate/singler/test.py | 4 ++-- src/annotate/svm_annotation/config.vsh.yaml | 4 ++-- src/annotate/svm_annotation/script.py | 4 ++-- src/integrate/scarches/config.vsh.yaml | 4 ++-- src/integrate/scarches/script.py | 2 +- src/integrate/scvi/config.vsh.yaml | 4 ++-- src/integrate/scvi/script.py | 2 +- src/utils/set_var_index.py | 8 +++---- .../annotation/celltypist/config.vsh.yaml | 4 ++-- src/workflows/annotation/celltypist/main.nf | 2 +- .../scanvi_scarches/config.vsh.yaml | 4 ++-- .../annotation/scanvi_scarches/main.nf | 6 +++--- .../annotation/scvi_knn/config.vsh.yaml | 4 ++-- src/workflows/annotation/scvi_knn/main.nf | 2 +- .../integration/scvi_leiden/config.vsh.yaml | 4 ++-- src/workflows/integration/scvi_leiden/main.nf | 2 +- 27 files changed, 66 insertions(+), 52 deletions(-) diff --git a/src/annotate/celltypist/config.vsh.yaml b/src/annotate/celltypist/config.vsh.yaml index 6839b9f78d8..b9d3d72cb48 100644 --- a/src/annotate/celltypist/config.vsh.yaml +++ b/src/annotate/celltypist/config.vsh.yaml @@ -38,9 +38,9 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference diff --git a/src/annotate/celltypist/script.py b/src/annotate/celltypist/script.py index cfd4fa01aef..13fbbb66885 100644 --- a/src/annotate/celltypist/script.py +++ b/src/annotate/celltypist/script.py @@ -58,7 +58,7 @@ def main(par): # Provide correct format of query data for celltypist annotation ## Sanitize gene names and set as index input_modality = set_var_index( - input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + input_modality, par["input_var_gene_names"], par["sanitize_ensembl_ids"] ) ## Fetch lognormalized counts lognorm_counts = ( @@ -88,7 +88,7 @@ def main(par): reference_modality = set_var_index( reference_modality, par["reference_var_gene_names"], - par["sanitize_gene_names"], + par["sanitize_ensembl_ids"], ) # Ensure enough overlap between genes in query and reference diff --git a/src/annotate/onclass/config.vsh.yaml b/src/annotate/onclass/config.vsh.yaml index e20e9336bd2..77b415f418f 100644 --- a/src/annotate/onclass/config.vsh.yaml +++ b/src/annotate/onclass/config.vsh.yaml @@ -39,9 +39,9 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Ontology diff --git a/src/annotate/onclass/script.py b/src/annotate/onclass/script.py index a74f00a9760..73d35139049 100644 --- a/src/annotate/onclass/script.py +++ b/src/annotate/onclass/script.py @@ -124,7 +124,7 @@ def main(): # Set var names to the desired gene name format (gene symbol, ensembl id, etc.) input_modality = set_var_index( - input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + input_modality, par["input_var_gene_names"], par["sanitize_ensembl_ids"] ) input_matrix = ( input_modality.layers[par["input_layer"]] @@ -160,7 +160,7 @@ def main(): reference_modality = set_var_index( reference_modality, par["reference_var_gene_names"], - par["sanitize_gene_names"], + par["sanitize_ensembl_ids"], ) # subset to HVG if required diff --git a/src/annotate/onclass/test.py b/src/annotate/onclass/test.py index 905b6b7c6b5..bf9d7826d13 100644 --- a/src/annotate/onclass/test.py +++ b/src/annotate/onclass/test.py @@ -145,7 +145,7 @@ def test_pretrained_model(run_component, random_h5mu_path): input_file, "--input_var_gene_names", "gene_symbol", - "--sanitize_gene_names", + "--sanitize_ensembl_ids", "False", "--cl_nlp_emb_file", cl_nlp_emb_file, diff --git a/src/annotate/random_forest_annotation/config.vsh.yaml b/src/annotate/random_forest_annotation/config.vsh.yaml index 7b2f3784acb..9d2ed5b9f31 100644 --- a/src/annotate/random_forest_annotation/config.vsh.yaml +++ b/src/annotate/random_forest_annotation/config.vsh.yaml @@ -35,9 +35,9 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference diff --git a/src/annotate/random_forest_annotation/script.py b/src/annotate/random_forest_annotation/script.py index 5b36a88eede..31671b4b781 100644 --- a/src/annotate/random_forest_annotation/script.py +++ b/src/annotate/random_forest_annotation/script.py @@ -48,7 +48,7 @@ def main(): input_adata = input_mudata.mod[par["modality"]] input_modality = input_adata.copy() input_modality = set_var_index( - input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + input_modality, par["input_var_gene_names"], par["sanitize_ensembl_ids"] ) # Handle max_features parameter @@ -104,7 +104,7 @@ def main(): reference_modality = set_var_index( reference_modality, par["reference_var_gene_names"], - par["sanitize_gene_names"], + par["sanitize_ensembl_ids"], ) # subset to HVG if required diff --git a/src/annotate/scanvi/config.vsh.yaml b/src/annotate/scanvi/config.vsh.yaml index 25d3598aa85..6b7077b7036 100644 --- a/src/annotate/scanvi/config.vsh.yaml +++ b/src/annotate/scanvi/config.vsh.yaml @@ -48,9 +48,9 @@ argument_groups: default: "Unknown" description: | Value in the --obs_labels field that indicates unlabeled observations - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: scVI Model diff --git a/src/annotate/scanvi/script.py b/src/annotate/scanvi/script.py index 684cf7973b4..d9c3987a9c8 100644 --- a/src/annotate/scanvi/script.py +++ b/src/annotate/scanvi/script.py @@ -51,7 +51,7 @@ def main(): # Sanitize gene names and set as index of the AnnData object adata_subset = set_var_index( - adata_subset, par["var_gene_names"], par["sanitize_gene_names"] + adata_subset, par["var_gene_names"], par["sanitize_ensembl_ids"] ) logger.info(f"Loading pre-trained scVI model from {par['scvi_model']}") diff --git a/src/annotate/singler/config.vsh.yaml b/src/annotate/singler/config.vsh.yaml index c3a4ea8ee45..337bae6f8c9 100644 --- a/src/annotate/singler/config.vsh.yaml +++ b/src/annotate/singler/config.vsh.yaml @@ -121,7 +121,10 @@ argument_groups: If set to True, an additional output .obs field `--output_obs_pruned_predictions` will be added to the `--output`, containing labels where 'low-quality' labels are replaced with NA's. Labels are considered 'low-quality' when their delta score (stored in `--output_obs_delta_next`) fall more than 3 median absolute deviations below the median for that label type. - + - name: "--sanitize_ensembl_ids" + type: boolean + description: Whether to sanitize ensembl ids by removing version numbers. + default: true - name: Outputs description: Output arguments. arguments: diff --git a/src/annotate/singler/script.R b/src/annotate/singler/script.R index a46b59d22f2..fbd3ad8c1c3 100644 --- a/src/annotate/singler/script.R +++ b/src/annotate/singler/script.R @@ -6,7 +6,7 @@ mudata <- reticulate::import("mudata") ### VIASH START par <- list( - input = "pbmc_1k_protein_v3_filtered_feature_bc_matrix.h5mu", + input = "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_filtered_feature_bc_matrix.h5mu", modality = "rna", input_layer = NULL, input_var_gene_names = "gene_symbol", @@ -16,6 +16,7 @@ par <- list( reference_layer = NULL, reference_var_input = NULL, reference_var_gene_names = NULL, + # reference_var_gene_names = "ensemblid", reference_obs_target = "cell_ontology_class", output = "singler_output.h5mu", output_compression = "gzip", @@ -61,21 +62,31 @@ get_layer <- function(adata, layer, var_gene_names) { } # Set matrix dimnames - input_gene_names <- sanitize_gene_names(adata, var_gene_names) + input_gene_names <- sanitize_ensembl_ids(adata, var_gene_names) dimnames(data) <- list(adata$obs_names, input_gene_names) # return output data } -sanitize_gene_names <- function(adata, gene_symbol = NULL) { +sanitize_ensembl_ids <- function(adata, gene_symbol = NULL) { if (is.null(gene_symbol)) { gene_names <- adata$var_names } else { gene_names <- adata$var[[gene_symbol]] } - # Remove version numbers (dot followed by digits at end of string) - sanitized <- gsub("\\.[0-9]+$", "", gene_names) + + # Pattern matches Ensembl IDs: starts with ENS, followed by any characters, + # then an eleven digit number, optionally followed by .version_number + ensembl_pattern <- "^(ENS.*\\d{11})(?:\\.\\d+)?$" + + # Remove version numbers for ensembl ids only + sanitized <- ifelse( + grepl(ensembl_pattern, gene_names, perl = TRUE), + gsub(ensembl_pattern, "\\1", gene_names, perl = TRUE), + as.character(gene_names) + ) + sanitized } diff --git a/src/annotate/singler/test.py b/src/annotate/singler/test.py index bbc5a65d549..c6f4be38e0d 100644 --- a/src/annotate/singler/test.py +++ b/src/annotate/singler/test.py @@ -79,14 +79,14 @@ def test_params(run_component, random_h5mu_path): [ "--input", input_file, - "--input_var_gene_names", - "gene_symbol", "--reference", reference_file, "--reference_obs_target", "cell_ontology_class", "--input_reference_gene_overlap", "1000", + "--reference_var_gene_names", + "ensemblid", "--reference_var_input", "highly_variable", "de_n_genes", diff --git a/src/annotate/svm_annotation/config.vsh.yaml b/src/annotate/svm_annotation/config.vsh.yaml index 5c5d67b4618..bc90e5d2921 100644 --- a/src/annotate/svm_annotation/config.vsh.yaml +++ b/src/annotate/svm_annotation/config.vsh.yaml @@ -35,9 +35,9 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference diff --git a/src/annotate/svm_annotation/script.py b/src/annotate/svm_annotation/script.py index 74ce56e6797..a81a60d1558 100644 --- a/src/annotate/svm_annotation/script.py +++ b/src/annotate/svm_annotation/script.py @@ -52,7 +52,7 @@ def main(): input_adata = input_mudata.mod[par["modality"]] input_modality = input_adata.copy() input_modality = set_var_index( - input_modality, par["input_var_gene_names"], par["sanitize_gene_names"] + input_modality, par["input_var_gene_names"], par["sanitize_ensembl_ids"] ) if par["model"]: @@ -86,7 +86,7 @@ def main(): reference_modality = set_var_index( reference_modality, par["reference_var_gene_names"], - par["sanitize_gene_names"], + par["sanitize_ensembl_ids"], ) # subset to HVG if required diff --git a/src/integrate/scarches/config.vsh.yaml b/src/integrate/scarches/config.vsh.yaml index cdc29dbceb0..fe815c75524 100644 --- a/src/integrate/scarches/config.vsh.yaml +++ b/src/integrate/scarches/config.vsh.yaml @@ -66,9 +66,9 @@ argument_groups: (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for. Important: the order of the continuous covariates matters and should match the order of the covariates in the trained reference model. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference diff --git a/src/integrate/scarches/script.py b/src/integrate/scarches/script.py index f16453dec04..7d4038ff3e0 100644 --- a/src/integrate/scarches/script.py +++ b/src/integrate/scarches/script.py @@ -134,7 +134,7 @@ def _align_query_with_registry(adata_query, model_path): # Sanitize gene names and set as index of the AnnData object # all scArches VAE models expect gene names to be in the .var index adata_query = set_var_index( - adata_query, par["input_var_gene_names"], par["sanitize_gene_names"] + adata_query, par["input_var_gene_names"], par["sanitize_ensembl_ids"] ) # align layer diff --git a/src/integrate/scvi/config.vsh.yaml b/src/integrate/scvi/config.vsh.yaml index 6b62d3b5244..8dda248a4c1 100644 --- a/src/integrate/scvi/config.vsh.yaml +++ b/src/integrate/scvi/config.vsh.yaml @@ -73,9 +73,9 @@ argument_groups: addition to the batch covariate and are also treated as nuisance factors (i.e., the model tries to minimize their effects on the latent space). Thus, these should not be used for biologically-relevant factors that you do _not_ want to correct for. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Outputs arguments: diff --git a/src/integrate/scvi/script.py b/src/integrate/scvi/script.py index f82f060e3fc..71ca0f1f924 100644 --- a/src/integrate/scvi/script.py +++ b/src/integrate/scvi/script.py @@ -85,7 +85,7 @@ def main(): # Sanitize gene names and set as index of the AnnData object adata_subset = set_var_index( - adata_subset, par["var_gene_names"], par["sanitize_gene_names"] + adata_subset, par["var_gene_names"], par["sanitize_ensembl_ids"] ) check_validity_anndata( diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 3fbdeaedf51..3d7669fc1d6 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -27,7 +27,7 @@ def strip_version_number(gene_names: list[str]) -> list[str]: def set_var_index( - adata: ad.AnnData, var_name: str | None = None, sanitize_gene_names: bool = True + adata: ad.AnnData, var_name: str | None = None, sanitize_ensembl_ids: bool = True ) -> ad.AnnData: """Sanitize gene names (optional) and set the index of the .var DataFrame. @@ -37,7 +37,7 @@ def set_var_index( Annotated data object var_name : str | None Name of the column in `adata.var` that contains the gene names, if None, the existing index will be sanitized but not replaced. - sanitize_gene_names : bool + sanitize_ensembl_ids : bool Whether to sanitize gene names by removing version numbers. Returns @@ -47,7 +47,7 @@ def set_var_index( """ gene_names = adata.var[var_name] if var_name else adata.var.index - if sanitize_gene_names: + if sanitize_ensembl_ids: ori_gene_names = len(gene_names) gene_names = strip_version_number(gene_names) sanitized_gene_names = len(set(gene_names)) @@ -56,7 +56,7 @@ def set_var_index( "Sanitizing gene names resulted in duplicated gene names.\n" "Please ensure unique gene names before proceeding.\n" "Please make sure --var_gene_names contains ensembl IDs (not gene symbols) " - "when --sanitize_gene_names is set to True." + "when --sanitize_ensembl_ids is set to True." ) adata.var.index = gene_names diff --git a/src/workflows/annotation/celltypist/config.vsh.yaml b/src/workflows/annotation/celltypist/config.vsh.yaml index 90e4519d628..9d1e60a3578 100644 --- a/src/workflows/annotation/celltypist/config.vsh.yaml +++ b/src/workflows/annotation/celltypist/config.vsh.yaml @@ -43,9 +43,9 @@ argument_groups: min: 1 description: | The minimum number of genes present in both the reference and query datasets. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference diff --git a/src/workflows/annotation/celltypist/main.nf b/src/workflows/annotation/celltypist/main.nf index bd5cf7ad62a..aeeb17e704b 100644 --- a/src/workflows/annotation/celltypist/main.nf +++ b/src/workflows/annotation/celltypist/main.nf @@ -115,7 +115,7 @@ workflow run_wf { "output": state.output, "output_obs_predictions": state.output_obs_predictions, "output_obs_probability": state.output_obs_probability, - "sanitize_gene_names": state.sanitize_gene_names + "sanitize_ensembl_ids": state.sanitize_ensembl_ids ]}, args: [ "input_layer": "log_normalized_10k", diff --git a/src/workflows/annotation/scanvi_scarches/config.vsh.yaml b/src/workflows/annotation/scanvi_scarches/config.vsh.yaml index ef36879d02d..98fdaa84a83 100644 --- a/src/workflows/annotation/scanvi_scarches/config.vsh.yaml +++ b/src/workflows/annotation/scanvi_scarches/config.vsh.yaml @@ -71,9 +71,9 @@ argument_groups: type: string required: false description: ".var column containing gene names. By default, use the index." - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference input diff --git a/src/workflows/annotation/scanvi_scarches/main.nf b/src/workflows/annotation/scanvi_scarches/main.nf index 62738585da3..d9d4312945b 100644 --- a/src/workflows/annotation/scanvi_scarches/main.nf +++ b/src/workflows/annotation/scanvi_scarches/main.nf @@ -31,7 +31,7 @@ workflow run_wf { "reduce_lr_on_plateau": "reduce_lr_on_plateau", "lr_factor": "lr_factor", "lr_patience": "lr_patience", - "sanitize_gene_names": "sanitize_gene_names" + "sanitize_ensembl_ids": "sanitize_ensembl_ids" ], args: [ "obsm_output": "X_integrated_scvi" @@ -64,7 +64,7 @@ workflow run_wf { "reduce_lr_on_plateau": "reduce_lr_on_plateau", "lr_factor": "lr_factor", "lr_patience": "lr_patience", - "sanitize_gene_names": "sanitize_gene_names" + "sanitize_ensembl_ids": "sanitize_ensembl_ids" ], toState: [ "reference": "output", @@ -97,7 +97,7 @@ workflow run_wf { "lr_patience": "lr_patience", "output": "workflow_output", "model_output": "workflow_output_model", - "sanitize_gene_names": "sanitize_gene_names" + "sanitize_ensembl_ids": "sanitize_ensembl_ids" ], toState: [ "input": "output", diff --git a/src/workflows/annotation/scvi_knn/config.vsh.yaml b/src/workflows/annotation/scvi_knn/config.vsh.yaml index 999068a840c..76b247d46a5 100644 --- a/src/workflows/annotation/scvi_knn/config.vsh.yaml +++ b/src/workflows/annotation/scvi_knn/config.vsh.yaml @@ -62,9 +62,9 @@ argument_groups: - name: "--overwrite_existing_key" type: boolean_true description: If provided, will overwrite existing fields in the input dataset when data are copied during the reference alignment process. - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: Reference input diff --git a/src/workflows/annotation/scvi_knn/main.nf b/src/workflows/annotation/scvi_knn/main.nf index 2eb6fade159..a4c8cc5783c 100644 --- a/src/workflows/annotation/scvi_knn/main.nf +++ b/src/workflows/annotation/scvi_knn/main.nf @@ -135,7 +135,7 @@ workflow run_wf { "reduce_lr_on_plateau": state.scvi_reduce_lr_on_plateau, "lr_factor": state.scvi_lr_factor, "lr_patience": state.scvi_lr_patience, - "sanitize_gene_names": state.sanitize_gene_names + "sanitize_ensembl_ids": state.sanitize_ensembl_ids ]}, args: [ "var_input": "_common_hvg", diff --git a/src/workflows/integration/scvi_leiden/config.vsh.yaml b/src/workflows/integration/scvi_leiden/config.vsh.yaml index bfab3b968d0..a64eed1e12c 100644 --- a/src/workflows/integration/scvi_leiden/config.vsh.yaml +++ b/src/workflows/integration/scvi_leiden/config.vsh.yaml @@ -29,9 +29,9 @@ argument_groups: type: string default: "rna" required: false - - name: "--sanitize_gene_names" + - name: "--sanitize_ensembl_ids" type: boolean - description: Whether to sanitize gene names by removing version numbers. Recommended when using ENSEMBL ids. + description: Whether to sanitize ensembl ids by removing version numbers. default: true - name: "Outputs" diff --git a/src/workflows/integration/scvi_leiden/main.nf b/src/workflows/integration/scvi_leiden/main.nf index 6b6c39144d6..e6a7e9e4a6f 100644 --- a/src/workflows/integration/scvi_leiden/main.nf +++ b/src/workflows/integration/scvi_leiden/main.nf @@ -28,7 +28,7 @@ workflow run_wf { "output_model": "output_model", "modality": "modality", "input_layer": "layer", - "sanitize_gene_names": "sanitize_gene_names" + "sanitize_ensembl_ids": "sanitize_ensembl_ids" ], toState: [ "input": "output", From 6eda0632f982704c1a9b055c2fe6731564ff7145 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Thu, 30 Oct 2025 15:55:37 +0100 Subject: [PATCH 21/25] update parameter name --- src/annotate/singler/script.R | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/annotate/singler/script.R b/src/annotate/singler/script.R index fbd3ad8c1c3..e85c87e7572 100644 --- a/src/annotate/singler/script.R +++ b/src/annotate/singler/script.R @@ -6,7 +6,7 @@ mudata <- reticulate::import("mudata") ### VIASH START par <- list( - input = "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_filtered_feature_bc_matrix.h5mu", + input = "pbmc_1k_protein_v3_filtered_feature_bc_matrix.h5mu", modality = "rna", input_layer = NULL, input_var_gene_names = "gene_symbol", @@ -16,12 +16,11 @@ par <- list( reference_layer = NULL, reference_var_input = NULL, reference_var_gene_names = NULL, - # reference_var_gene_names = "ensemblid", reference_obs_target = "cell_ontology_class", output = "singler_output.h5mu", output_compression = "gzip", output_obs_predictions = "singler_labels", - output_obs_probability = "singlr_proba", + output_obs_probability = "singler_proba", output_obsm_scores = "single_r_scores", output_obs_delta_next = "singler_delta_next", output_obs_pruned_predictions = "singler_pruned_labels", From fd015e4da21d903f10ff9c7f4d6870b8b0abafd6 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 31 Oct 2025 10:12:23 +0100 Subject: [PATCH 22/25] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3d085485df..ad7f17f98ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ * `integrate/scarches` and `workflows/annotate/scanvi_scarches`: Enable correction for technical variability by multiple continuous and categorical covariates. -* Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Make feature name sanitation optional (PR #1084). +* Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Perform optional ensembl id sanitation (by stripping the version number) using the `--sanitize_ensembl_ids` argument (PR #1084). * `genetic_demux/scsplit`: bump python to `3.13` and unpin pandas and numpy (were pinned to `<2.0` and `<2` respectively) (PR #1096). From fbfd642a19f5a4c70f9461419745301354904fa2 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 31 Oct 2025 10:58:45 +0100 Subject: [PATCH 23/25] update regex matching --- src/utils/set_var_index.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 3d7669fc1d6..7e2d8e064e4 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -1,5 +1,4 @@ import anndata as ad -import re def strip_version_number(gene_names: list[str]) -> list[str]: @@ -18,12 +17,15 @@ def strip_version_number(gene_names: list[str]) -> list[str]: # Pattern matches Ensembl IDs: starts with ENS, followed by any characters, # then an eleven digit number, optionally followed by .version_number - ensembl_pattern = re.compile(r"^(ENS.*\d{11})(?:\.\d+)?$") + gene_series = gene_names.to_series() + ensembl_pattern = r"^ENS.*\d{11}(?:\.\d+)?$" + ensembl_mask = gene_series.str.match(ensembl_pattern) - return [ - match.group(1) if (match := ensembl_pattern.match(gene)) else gene - for gene in gene_names - ] + sanitized = gene_series.where( + ~ensembl_mask, gene_series.str.extract(ensembl_pattern)[0] + ) + + return sanitized def set_var_index( From ceb7ae86307c549d3b732080215a59d1befbd3f6 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 31 Oct 2025 11:59:42 +0100 Subject: [PATCH 24/25] fixup --- CHANGELOG.md | 2 +- src/utils/set_var_index.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad7f17f98ba..2a3622af7e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,7 +38,7 @@ * `integrate/scarches` and `workflows/annotate/scanvi_scarches`: Enable correction for technical variability by multiple continuous and categorical covariates. -* Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Perform optional ensembl id sanitation (by stripping the version number) using the `--sanitize_ensembl_ids` argument (PR #1084). +* Various components and workflows in `integrate`, `annotate`, `workflows/integration` and `workflows/annotation`: Optionally disable ensembl id sanitation (by stripping the version number) using the `--sanitize_ensembl_ids` argument (PR #1084). * `genetic_demux/scsplit`: bump python to `3.13` and unpin pandas and numpy (were pinned to `<2.0` and `<2` respectively) (PR #1096). diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 7e2d8e064e4..63cfdf61b20 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -1,12 +1,12 @@ import anndata as ad -def strip_version_number(gene_names: list[str]) -> list[str]: +def strip_version_number(gene_series: list[str]) -> list[str]: """Sanitize ensemble ID's by removing version numbers. Parameters ---------- - gene_names : list[str] + gene_series : list[str] List of ensemble ID's to sanitize. Returns @@ -17,8 +17,7 @@ def strip_version_number(gene_names: list[str]) -> list[str]: # Pattern matches Ensembl IDs: starts with ENS, followed by any characters, # then an eleven digit number, optionally followed by .version_number - gene_series = gene_names.to_series() - ensembl_pattern = r"^ENS.*\d{11}(?:\.\d+)?$" + ensembl_pattern = r"^(ENS.*\d{11})(?:\.\d+)?$" ensembl_mask = gene_series.str.match(ensembl_pattern) sanitized = gene_series.where( @@ -47,7 +46,7 @@ def set_var_index( AnnData Copy of `adata` with optionally sanitized and replaced index """ - gene_names = adata.var[var_name] if var_name else adata.var.index + gene_names = adata.var[var_name] if var_name else adata.var.index.to_series() if sanitize_ensembl_ids: ori_gene_names = len(gene_names) From d1f3a2df098205816897aab69b5132eafd6787b1 Mon Sep 17 00:00:00 2001 From: dorien-er Date: Fri, 31 Oct 2025 14:39:03 +0100 Subject: [PATCH 25/25] handle categorical feature series --- src/utils/set_var_index.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/utils/set_var_index.py b/src/utils/set_var_index.py index 63cfdf61b20..e33664b5aef 100644 --- a/src/utils/set_var_index.py +++ b/src/utils/set_var_index.py @@ -15,6 +15,9 @@ def strip_version_number(gene_series: list[str]) -> list[str]: List of sanitized ensemble ID's. """ + # Convert to string type to handle Categorical series + gene_series = gene_series.astype(str) + # Pattern matches Ensembl IDs: starts with ENS, followed by any characters, # then an eleven digit number, optionally followed by .version_number ensembl_pattern = r"^(ENS.*\d{11})(?:\.\d+)?$"