From c3ea543d2b4ee2d48c66e4c54ca124ff4305dd71 Mon Sep 17 00:00:00 2001 From: vicpaton Date: Tue, 7 May 2024 14:52:34 +0200 Subject: [PATCH] switched from R-deseq2 to python-deseq2 --- networkcommons/datasets.py | 86 +++++++++------------------- networkcommons/test/test_datasets.py | 15 ++--- 2 files changed, 36 insertions(+), 65 deletions(-) diff --git a/networkcommons/datasets.py b/networkcommons/datasets.py index a57ae88..703c854 100644 --- a/networkcommons/datasets.py +++ b/networkcommons/datasets.py @@ -11,6 +11,9 @@ from rpy2.robjects.packages import importr from rpy2.robjects.conversion import localconverter import decoupler as dc +from pydeseq2.dds import DeseqDataSet +from pydeseq2.default_inference import DefaultInference +from pydeseq2.ds import DeseqStats def get_available_datasets(): public_link="https://oc.embl.de/index.php/s/6KsHfeoqJOKLF6B" @@ -54,71 +57,38 @@ def download_url(url, save_path, chunk_size=128): fd.write(chunk) -def deseq2_analysis(counts, - metadata, - covariates="", - deseq2_test='Wald', - deseq2_fitType='parametric', - deseq2_betaprior=False, - deseq2_quiet=False, - deseq2_minReplicatesForReplace=7, - ): - """ - Perform DESeq2 analysis using rpy2. - - Parameters: - counts (DataFrame): A pandas DataFrame containing raw count data. - metadata (DataFrame): A pandas DataFrame containing metadata. - additional_args (dict): Additional arguments for DESeq2 analysis. - - Returns: - DESeq2 results as a DataFrame. - """ - # Importing required R packages - DESeq2 = importr("DESeq2") - base = importr("base") - - # Set genesymbol as rownames - counts.set_index('gene_symbol', inplace=True) - metadata.set_index('sample_ID', inplace=True) - - # Convert pandas DataFrames to R DataFrames - pandas2ri.activate() - gene_counts = pandas2ri.py2rpy(counts) - metadata_r = pandas2ri.py2rpy(metadata) - - if covariates != "" and len(covariates)>=1: - covariates = ["" + covariates] - # Create design formula - design_formula = robjects.Formula("~ 0 + group" + " + ".join(covariates)) +def run_deseq2_analysis(counts, + metadata, + test_group, + ref_group, + covariates=[]): + counts.set_index('gene_symbol', inplace=True) + metadata.set_index('sample_ID', inplace=True) - # Create DESeqDataSet object - formatted_data = DESeq2.DESeqDataSetFromMatrix(countData=gene_counts, - colData=metadata_r, - design=design_formula) - - # Get study groups - studygroups = list(set(metadata['group'])) + design_factors = ['group'] + if len(covariates) > 0: + if isinstance(covariates, str): + covariates = [covariates] + design_factors += covariates + + inference = DefaultInference(n_cpus=8) + dds = DeseqDataSet( + counts=counts.T, + metadata=metadata, + design_factors=design_factors, + refit_cooks=True, + inference=inference + ) + dds.deseq2() - # Run DESeq2 analysis - results = DESeq2.DESeq(formatted_data, - test=deseq2_test, - fitType=deseq2_fitType, - betaPrior=deseq2_betaprior, - quiet=deseq2_quiet, - minReplicatesForReplace=deseq2_minReplicatesForReplace) - results = DESeq2.results(results, contrast=robjects.StrVector(['group', studygroups[0], studygroups[1]])) - results = base.as_data_frame(results) + results = DeseqStats(dds, contrast=["group", test_group, ref_group], inference=inference) + results.summary() + return results.results_df.astype('float64') - # Convert DESeq2 results to pandas DataFrame - with localconverter(robjects.default_converter + pandas2ri.converter): - results_df = robjects.conversion.rpy2py(results) - - return results_df diff --git a/networkcommons/test/test_datasets.py b/networkcommons/test/test_datasets.py index 27cf619..86e852f 100644 --- a/networkcommons/test/test_datasets.py +++ b/networkcommons/test/test_datasets.py @@ -41,7 +41,7 @@ def test_deseq2_analysis(): }) # Call the deseq2_analysis function - result = deseq2_analysis(counts, metadata) + result = run_deseq2_analysis(counts, metadata, ref_group='Control', test_group='Treatment') # Assert that the returned value is a pandas DataFrame assert isinstance(result, pd.DataFrame) @@ -55,13 +55,14 @@ def test_deseq2_analysis(): # Assert that the DataFrame has the expected content data = { - 'baseMean': [93.233027, 101.285704, 11.793541], - 'log2FoldChange': [-0.218172, 0.682183, 0.052954], - 'lfcSE': [0.328036, 0.352393, 0.521659], - 'stat': [-0.665087, 1.935862, 0.101510], - 'pvalue': [0.505995, 0.052885, 0.919146], - 'padj': [0.758992, 0.158654, 0.919146] + 'baseMean': [93.233032, 101.285698, 11.793541], + 'log2FoldChange': [0.222414, -0.682183, -0.052951], + 'lfcSE': [0.150059, 0.352411, 0.521689], + 'stat': [1.482173, -1.935763, -0.101499], + 'pvalue': [0.138294, 0.052897, 0.919154], + 'padj': [0.207441, 0.158690, 0.919154] } expected_result = pd.DataFrame(data, index=['Gene1', 'Gene2', 'Gene3']) + expected_result.index.name = 'gene_symbol' pd.testing.assert_frame_equal(result, expected_result, check_exact=False)