diff --git a/docs/tutorials/3-finetune.html b/docs/tutorials/3-finetune.html new file mode 100644 index 0000000..91e1626 --- /dev/null +++ b/docs/tutorials/3-finetune.html @@ -0,0 +1,10657 @@
+ [new file: rendered HTML export of the 3-finetune notebook; 10,657 lines of generated markup omitted]
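For orientation before the notebook diff that follows: the tutorial reads a pseudobulk-by-gene AnnData from `data/data.h5ad`, with one row per pseudobulk in `.obs` (cell_type, tissue, disease, study) and one row per gene in `.var` (chrom, start, end, strand). A minimal sketch of an input with that layout — column names follow the tutorial, but every value below is made up for illustration:

import numpy as np
import pandas as pd
import anndata

# Made-up example data; only the column names follow the tutorial.
counts = np.random.poisson(5.0, size=(3, 2)).astype(np.float32)  # pseudobulks x genes

obs = pd.DataFrame(
    {
        "cell_type": ["ct_0", "ct_0", "ct_1"],
        "tissue": ["t_0", "t_0", "t_1"],
        "disease": ["d_0", "d_1", "d_0"],
        "study": ["st_0", "st_0", "st_1"],
    },
    index=["pseudobulk_0", "pseudobulk_1", "pseudobulk_2"],  # unique index per pseudobulk
)

var = pd.DataFrame(
    {
        "chrom": ["chr1", "chr19"],
        "start": [26846360, 40619897],
        "end": [27370648, 41144185],
        "strand": ["+", "-"],
    },
    index=["gene_0", "gene_1"],  # unique index per gene
)

ad = anndata.AnnData(X=counts, obs=obs, var=var)
ad.write_h5ad("data/data.h5ad")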
+ + diff --git a/docs/tutorials/3-finetune.ipynb b/docs/tutorials/3-finetune.ipynb index 7e64289..2c55b93 100644 --- a/docs/tutorials/3-finetune.ipynb +++ b/docs/tutorials/3-finetune.ipynb @@ -15,6 +15,8 @@ "metadata": {}, "outputs": [], "source": [ + "import glob\n", + "import anndata\n", "import scanpy as sc\n", "import pandas as pd\n", "import bioframe as bf\n", @@ -28,8 +30,9 @@ "metadata": {}, "outputs": [], "source": [ - "outdir = \".\"\n", - "ad_file_path = os.path.join(outdir, \"data.h5ad\")\n", + "inputdir = \"./data\"\n", + "outdir = \"./example\"\n", + "ad_file_path = os.path.join(inputdir, \"data.h5ad\")\n", "h5_file_path = os.path.join(outdir, \"data.h5\")" ] }, @@ -58,9 +61,10 @@ { "data": { "text/plain": [ - "AnnData object with n_obs × n_vars = 50 × 1000\n", + "AnnData object with n_obs × n_vars = 50 × 931\n", " obs: 'cell_type', 'tissue', 'disease', 'study'\n", - " var: 'chrom', 'start', 'end', 'strand'" + " var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'\n", + " uns: 'log1p'" ] }, "execution_count": 3, @@ -69,7 +73,7 @@ } ], "source": [ - "ad = sc.read(\"data/test_data.h5ad\")\n", + "ad = sc.read(ad_file_path)\n", "ad" ] }, @@ -78,7 +82,9 @@ "id": "dcb6c9a7-5e97-46fc-a2d3-fe029821c375", "metadata": {}, "source": [ - "`.obs` should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts." + "`.obs` should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts. \n", + "\n", + "Note that the original Decima model does NOT separate pseudobulks by sample, i.e. different samples from the same cell type, tissue, disease and study were merged. We also recommend filtering out pseudobulks with few cells or low read count. 
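The note above recommends dropping pseudobulks built from few cells or with low read counts. A one-line sketch of that filter, assuming the optional `.obs` columns are named `n_cells` and `total_counts` (both column names and both thresholds are assumptions; the tutorial only lists these as metadata you may include):

# Hypothetical filter; 'n_cells' and 'total_counts' are assumed .obs column names.
keep = (ad.obs["n_cells"] >= 20) & (ad.obs["total_counts"] >= 100_000)
ad = ad[keep].copy()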
" ] }, { @@ -213,55 +219,98 @@ " start\n", " end\n", " strand\n", + " gene_start\n", + " gene_end\n", + " gene_length\n", + " gene_mask_start\n", + " gene_mask_end\n", + " dataset\n", " \n", " \n", " \n", " \n", " gene_0\n", " chr1\n", - " 28648600\n", - " 28648730\n", + " 26846360\n", + " 27370648\n", " +\n", + " 27010200\n", + " 27534488\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_1\n", " chr19\n", - " 39341773\n", - " 39341945\n", + " 40619897\n", + " 41144185\n", " -\n", + " 40456057\n", + " 40980345\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_2\n", " chr1\n", - " 78004346\n", - " 78004554\n", + " 79282506\n", + " 79806794\n", " -\n", + " 79118666\n", + " 79642954\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_3\n", " chr8\n", - " 143290399\n", - " 143290621\n", + " 144568573\n", + " 145092861\n", " -\n", + " 144404733\n", + " 144929021\n", + " 524288\n", + " 163840\n", + " 524288\n", + " val\n", " \n", " \n", " gene_4\n", " chr16\n", - " 1971655\n", - " 1971896\n", + " 3249848\n", + " 3774136\n", " -\n", + " 3086008\n", + " 3610296\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", "\n", "" ], "text/plain": [ - " chrom start end strand\n", - "gene_0 chr1 28648600 28648730 +\n", - "gene_1 chr19 39341773 39341945 -\n", - "gene_2 chr1 78004346 78004554 -\n", - "gene_3 chr8 143290399 143290621 -\n", - "gene_4 chr16 1971655 1971896 -" + " chrom start end strand gene_start gene_end \\\n", + "gene_0 chr1 26846360 27370648 + 27010200 27534488 \n", + "gene_1 chr19 40619897 41144185 - 40456057 40980345 \n", + "gene_2 chr1 79282506 79806794 - 79118666 79642954 \n", + "gene_3 chr8 144568573 145092861 - 144404733 144929021 \n", + "gene_4 chr16 3249848 3774136 - 3086008 3610296 \n", + "\n", + " gene_length gene_mask_start gene_mask_end dataset \n", + "gene_0 524288 163840 524288 train \n", + "gene_1 524288 163840 524288 train \n", + "gene_2 524288 163840 524288 train \n", + "gene_3 524288 163840 524288 val \n", + "gene_4 524288 163840 524288 train " ] }, "execution_count": 5, @@ -290,11 +339,12 @@ { "data": { "text/plain": [ - "array([[ 0, 36, 82, 0, 53],\n", - " [29, 84, 0, 33, 27],\n", - " [12, 33, 24, 60, 57],\n", - " [32, 0, 51, 77, 42],\n", - " [37, 2, 0, 0, 80]])" + "array([[0. , 7.2824097, 7.2824097, 0. , 7.2824097],\n", + " [7.3014727, 7.3014727, 0. , 7.3014727, 7.3014727],\n", + " [7.2867765, 7.2867765, 7.2867765, 7.2867765, 7.2867765],\n", + " [7.283863 , 0. , 7.283863 , 7.283863 , 7.283863 ],\n", + " [7.3239307, 7.3239307, 0. , 0. , 7.3239307]],\n", + " dtype=float32)" ] }, "execution_count": 6, @@ -327,7 +377,15 @@ "execution_count": 7, "id": "34115f7a-aaf8-4ca3-abbb-a4fc552bf5a7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING: adata.X seems to be already log-transformed.\n" + ] + } + ], "source": [ "sc.pp.normalize_total(ad, target_sum=1e6)\n", "sc.pp.log1p(ad)" @@ -342,11 +400,11 @@ { "data": { "text/plain": [ - "array([[0. , 6.921574 , 7.7442207, 0. , 7.3080306],\n", - " [6.6934667, 7.756176 , 0. , 6.822528 , 6.6220994],\n", - " [5.8283887, 6.838115 , 6.5200634, 7.4354696, 7.3842077],\n", - " [6.832712 , 0. , 7.2984004, 7.7101517, 7.104389 ],\n", - " [6.996557 , 4.0946727, 0. , 0. , 7.767174 ]],\n", + "array([[0. , 7.2867765, 7.2867765, 0. , 7.2867765],\n", + " [7.305924 , 7.305924 , 0. 
, 7.305924 , 7.305924 ],\n", + " [7.2911625, 7.2911625, 7.2911625, 7.2911625, 7.2911625],\n", + " [7.2896986, 0. , 7.2896986, 7.2896986, 7.2896986],\n", + " [7.3284836, 7.3284836, 0. , 0. , 7.3284836]],\n", " dtype=float32)" ] }, @@ -385,8 +443,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "/opt/conda/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n" ] } ], @@ -425,55 +485,98 @@ " start\n", " end\n", " strand\n", + " gene_start\n", + " gene_end\n", + " gene_length\n", + " gene_mask_start\n", + " gene_mask_end\n", + " dataset\n", " \n", " \n", " \n", " \n", " gene_0\n", " chr1\n", - " 28648600\n", - " 28648730\n", + " 26846360\n", + " 27370648\n", " +\n", + " 27010200\n", + " 27534488\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_1\n", " chr19\n", - " 39341773\n", - " 39341945\n", + " 40619897\n", + " 41144185\n", " -\n", + " 40456057\n", + " 40980345\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_2\n", " chr1\n", - " 78004346\n", - " 78004554\n", + " 79282506\n", + " 79806794\n", " -\n", + " 79118666\n", + " 79642954\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_3\n", " chr8\n", - " 143290399\n", - " 143290621\n", + " 144568573\n", + " 145092861\n", " -\n", + " 144404733\n", + " 144929021\n", + " 524288\n", + " 163840\n", + " 524288\n", + " val\n", " \n", " \n", " gene_4\n", " chr16\n", - " 1971655\n", - " 1971896\n", + " 3249848\n", + " 3774136\n", " -\n", + " 3086008\n", + " 3610296\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", "\n", "" ], "text/plain": [ - " chrom start end strand\n", - "gene_0 chr1 28648600 28648730 +\n", - "gene_1 chr19 39341773 39341945 -\n", - "gene_2 chr1 78004346 78004554 -\n", - "gene_3 chr8 143290399 143290621 -\n", - "gene_4 chr16 1971655 1971896 -" + " chrom start end strand gene_start gene_end \\\n", + "gene_0 chr1 26846360 27370648 + 27010200 27534488 \n", + "gene_1 chr19 40619897 41144185 - 40456057 40980345 \n", + "gene_2 chr1 79282506 79806794 - 79118666 79642954 \n", + "gene_3 chr8 144568573 145092861 - 144404733 144929021 \n", + "gene_4 
chr16 3249848 3774136 - 3086008 3610296 \n", + "\n", + " gene_length gene_mask_start gene_mask_end dataset \n", + "gene_0 524288 163840 524288 train \n", + "gene_1 524288 163840 524288 train \n", + "gene_2 524288 163840 524288 train \n", + "gene_3 524288 163840 524288 val \n", + "gene_4 524288 163840 524288 train " ] }, "execution_count": 10, @@ -539,70 +642,95 @@ " gene_start\n", " gene_end\n", " gene_length\n", + " gene_mask_start\n", + " gene_mask_end\n", + " dataset\n", " \n", " \n", " \n", " \n", " gene_0\n", " chr1\n", - " 28648600\n", - " 28648730\n", + " 26846360\n", + " 27370648\n", " +\n", - " 28648600\n", - " 28648730\n", - " 130\n", + " 26846360\n", + " 27370648\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_1\n", " chr19\n", - " 39341773\n", - " 39341945\n", + " 40619897\n", + " 41144185\n", " -\n", - " 39341773\n", - " 39341945\n", - " 172\n", + " 40619897\n", + " 41144185\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_2\n", " chr1\n", - " 78004346\n", - " 78004554\n", + " 79282506\n", + " 79806794\n", " -\n", - " 78004346\n", - " 78004554\n", - " 208\n", + " 79282506\n", + " 79806794\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", " gene_3\n", " chr8\n", - " 143290399\n", - " 143290621\n", + " 144568573\n", + " 145092861\n", " -\n", - " 143290399\n", - " 143290621\n", - " 222\n", + " 144568573\n", + " 145092861\n", + " 524288\n", + " 163840\n", + " 524288\n", + " val\n", " \n", " \n", " gene_4\n", " chr16\n", - " 1971655\n", - " 1971896\n", + " 3249848\n", + " 3774136\n", " -\n", - " 1971655\n", - " 1971896\n", - " 241\n", + " 3249848\n", + " 3774136\n", + " 524288\n", + " 163840\n", + " 524288\n", + " train\n", " \n", " \n", "\n", "" ], "text/plain": [ - " chrom start end strand gene_start gene_end gene_length\n", - "gene_0 chr1 28648600 28648730 + 28648600 28648730 130\n", - "gene_1 chr19 39341773 39341945 - 39341773 39341945 172\n", - "gene_2 chr1 78004346 78004554 - 78004346 78004554 208\n", - "gene_3 chr8 143290399 143290621 - 143290399 143290621 222\n", - "gene_4 chr16 1971655 1971896 - 1971655 1971896 241" + " chrom start end strand gene_start gene_end \\\n", + "gene_0 chr1 26846360 27370648 + 26846360 27370648 \n", + "gene_1 chr19 40619897 41144185 - 40619897 41144185 \n", + "gene_2 chr1 79282506 79806794 - 79282506 79806794 \n", + "gene_3 chr8 144568573 145092861 - 144568573 145092861 \n", + "gene_4 chr16 3249848 3774136 - 3249848 3774136 \n", + "\n", + " gene_length gene_mask_start gene_mask_end dataset \n", + "gene_0 524288 163840 524288 train \n", + "gene_1 524288 163840 524288 train \n", + "gene_2 524288 163840 524288 train \n", + "gene_3 524288 163840 524288 val \n", + "gene_4 524288 163840 524288 train " ] }, "execution_count": 12, @@ -633,9 +761,9 @@ "output_type": "stream", "text": [ "The interval size is 524288 bases. 
Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.\n", - "2 intervals extended beyond the chromosome start and have been shifted\n", - "0 intervals extended beyond the chromosome end and have been shifted\n", - "0 intervals did not extend far enough upstream of the TSS and have been dropped\n" + "3 intervals extended beyond the chromosome start and have been shifted\n", + "2 intervals extended beyond the chromosome end and have been shifted\n", + "5 intervals did not extend far enough upstream of the TSS and have been dropped\n" ] } ], @@ -680,87 +808,93 @@ " gene_length\n", " gene_mask_start\n", " gene_mask_end\n", + " dataset\n", " \n", " \n", " \n", " \n", " gene_0\n", " chr1\n", - " 28484760\n", - " 29009048\n", + " 26682520\n", + " 27206808\n", " +\n", - " 28648600\n", - " 28648730\n", - " 130\n", + " 26846360\n", + " 27370648\n", + " 524288\n", " 163840\n", - " 163970\n", + " 524288\n", + " train\n", " \n", " \n", " gene_1\n", " chr19\n", - " 38981497\n", - " 39505785\n", + " 40783737\n", + " 41308025\n", " -\n", - " 39341773\n", - " 39341945\n", - " 172\n", + " 40619897\n", + " 41144185\n", + " 524288\n", " 163840\n", - " 164012\n", + " 524288\n", + " train\n", " \n", " \n", " gene_2\n", " chr1\n", - " 77644106\n", - " 78168394\n", + " 79446346\n", + " 79970634\n", " -\n", - " 78004346\n", - " 78004554\n", - " 208\n", + " 79282506\n", + " 79806794\n", + " 524288\n", " 163840\n", - " 164048\n", + " 524288\n", + " train\n", " \n", " \n", - " gene_3\n", - " chr8\n", - " 142930173\n", - " 143454461\n", + " gene_4\n", + " chr16\n", + " 3413688\n", + " 3937976\n", " -\n", - " 143290399\n", - " 143290621\n", - " 222\n", + " 3249848\n", + " 3774136\n", + " 524288\n", " 163840\n", - " 164062\n", + " 524288\n", + " train\n", " \n", " \n", - " gene_4\n", - " chr16\n", - " 1611448\n", - " 2135736\n", - " -\n", - " 1971655\n", - " 1971896\n", - " 241\n", + " gene_5\n", + " chr10\n", + " 22987161\n", + " 23511449\n", + " +\n", + " 23151001\n", + " 23675289\n", + " 524288\n", " 163840\n", - " 164081\n", + " 524288\n", + " train\n", " \n", " \n", "\n", "" ], "text/plain": [ - " chrom start end strand gene_start gene_end \\\n", - "gene_0 chr1 28484760 29009048 + 28648600 28648730 \n", - "gene_1 chr19 38981497 39505785 - 39341773 39341945 \n", - "gene_2 chr1 77644106 78168394 - 78004346 78004554 \n", - "gene_3 chr8 142930173 143454461 - 143290399 143290621 \n", - "gene_4 chr16 1611448 2135736 - 1971655 1971896 \n", + " chrom start end strand gene_start gene_end gene_length \\\n", + "gene_0 chr1 26682520 27206808 + 26846360 27370648 524288 \n", + "gene_1 chr19 40783737 41308025 - 40619897 41144185 524288 \n", + "gene_2 chr1 79446346 79970634 - 79282506 79806794 524288 \n", + "gene_4 chr16 3413688 3937976 - 3249848 3774136 524288 \n", + "gene_5 chr10 22987161 23511449 + 23151001 23675289 524288 \n", "\n", - " gene_length gene_mask_start gene_mask_end \n", - "gene_0 130 163840 163970 \n", - "gene_1 172 163840 164012 \n", - "gene_2 208 163840 164048 \n", - "gene_3 222 163840 164062 \n", - "gene_4 241 163840 164081 " + " gene_mask_start gene_mask_end dataset \n", + "gene_0 163840 524288 train \n", + "gene_1 163840 524288 train \n", + "gene_2 163840 524288 train \n", + "gene_4 163840 524288 train \n", + "gene_5 163840 524288 train " ] }, "execution_count": 14, @@ -946,14 +1080,14 @@ " fold0\n", " \n", " \n", - " 45\n", - " gene_3\n", - " fold4\n", + " 44\n", + " gene_4\n", + " fold0\n", " \n", " \n", - " 60\n", + " 45\n", " gene_4\n", - " fold0\n", + " 
fold2\n", " \n", " \n", "\n", @@ -964,8 +1098,8 @@ "0 gene_0 fold5\n", "15 gene_1 fold0\n", "30 gene_2 fold0\n", - "45 gene_3 fold4\n", - "60 gene_4 fold0" + "44 gene_4 fold0\n", + "45 gene_4 fold2" ] }, "execution_count": 16, @@ -1017,7 +1151,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/tmp/ipykernel_60192/1980240170.py:1: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual.\n" + "/tmp/ipykernel_4100753/3109841685.py:1: ImplicitModificationWarning: Trying to modify attribute `.var` of view, initializing view as actual.\n" ] } ], @@ -1070,66 +1204,66 @@ " \n", " gene_0\n", " chr1\n", - " 28484760\n", - " 29009048\n", + " 26682520\n", + " 27206808\n", " +\n", - " 28648600\n", - " 28648730\n", - " 130\n", + " 26846360\n", + " 27370648\n", + " 524288\n", " 163840\n", - " 163970\n", + " 524288\n", " train\n", " \n", " \n", " gene_1\n", " chr19\n", - " 38981497\n", - " 39505785\n", + " 40783737\n", + " 41308025\n", " -\n", - " 39341773\n", - " 39341945\n", - " 172\n", + " 40619897\n", + " 41144185\n", + " 524288\n", " 163840\n", - " 164012\n", + " 524288\n", " train\n", " \n", " \n", " gene_2\n", " chr1\n", - " 77644106\n", - " 78168394\n", + " 79446346\n", + " 79970634\n", " -\n", - " 78004346\n", - " 78004554\n", - " 208\n", + " 79282506\n", + " 79806794\n", + " 524288\n", " 163840\n", - " 164048\n", + " 524288\n", " train\n", " \n", " \n", - " gene_3\n", - " chr8\n", - " 142930173\n", - " 143454461\n", + " gene_4\n", + " chr16\n", + " 3413688\n", + " 3937976\n", " -\n", - " 143290399\n", - " 143290621\n", - " 222\n", + " 3249848\n", + " 3774136\n", + " 524288\n", " 163840\n", - " 164062\n", - " val\n", + " 524288\n", + " train\n", " \n", " \n", - " gene_4\n", - " chr16\n", - " 1611448\n", - " 2135736\n", - " -\n", - " 1971655\n", - " 1971896\n", - " 241\n", + " gene_5\n", + " chr10\n", + " 22987161\n", + " 23511449\n", + " +\n", + " 23151001\n", + " 23675289\n", + " 524288\n", " 163840\n", - " 164081\n", + " 524288\n", " train\n", " \n", " \n", @@ -1137,19 +1271,19 @@ "" ], "text/plain": [ - " chrom start end strand gene_start gene_end \\\n", - "gene_0 chr1 28484760 29009048 + 28648600 28648730 \n", - "gene_1 chr19 38981497 39505785 - 39341773 39341945 \n", - "gene_2 chr1 77644106 78168394 - 78004346 78004554 \n", - "gene_3 chr8 142930173 143454461 - 143290399 143290621 \n", - "gene_4 chr16 1611448 2135736 - 1971655 1971896 \n", + " chrom start end strand gene_start gene_end gene_length \\\n", + "gene_0 chr1 26682520 27206808 + 26846360 27370648 524288 \n", + "gene_1 chr19 40783737 41308025 - 40619897 41144185 524288 \n", + "gene_2 chr1 79446346 79970634 - 79282506 79806794 524288 \n", + "gene_4 chr16 3413688 3937976 - 3249848 3774136 524288 \n", + "gene_5 chr10 22987161 23511449 + 23151001 23675289 524288 \n", "\n", - " gene_length gene_mask_start gene_mask_end dataset \n", - "gene_0 130 163840 163970 train \n", - "gene_1 172 163840 164012 train \n", - "gene_2 208 163840 164048 train \n", - "gene_3 222 163840 164062 val \n", - "gene_4 241 163840 164081 train " + " gene_mask_start gene_mask_end dataset \n", + "gene_0 163840 524288 train \n", + "gene_1 163840 524288 train \n", + "gene_2 163840 524288 train \n", + "gene_4 163840 524288 train \n", + "gene_5 163840 524288 train " ] }, "execution_count": 19, @@ -1171,9 +1305,9 @@ "data": { "text/plain": [ "dataset\n", - "train 824\n", - "test 98\n", - "val 78\n", + "train 769\n", + "test 86\n", + "val 71\n", "Name: count, dtype: int64" ] }, @@ -1258,12 +1392,12 @@ "text": [ "Writing 
metadata\n", "Writing task indices\n", - "Writing genes array of shape: (1000, 2)\n", - "Writing labels array of shape: (1000, 50, 1)\n", + "Writing genes array of shape: (926, 2)\n", + "Writing labels array of shape: (926, 50, 1)\n", "Making gene masks\n", - "Writing mask array of shape: (1000, 534288)\n", + "Writing mask array of shape: (926, 534288)\n", "Encoding sequences\n", - "Writing sequence array of shape: (1000, 534288)\n", + "Writing sequence array of shape: (926, 534288)\n", "Done!\n" ] } @@ -1351,10 +1485,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", - "decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", - "decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", - "decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n" + "decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", + "decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", + "decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n", + "decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data/data.h5ad --h5-file ./example/data.h5 --outdir ./example --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16\n" ] } ], @@ -1365,110 +1499,141 @@ }, { "cell_type": "markdown", - "id": "9ad995d0-3176-4f70-ac01-dc3c7b1aa556", + "id": "4133e741", "metadata": {}, "source": [ - "## Test" + "Here, we train the model for 1 epoch for quick progressing in tutorial. Run the training for more epochs in your training." 
] }, { "cell_type": "code", - "execution_count": 28, - "id": "538d1250-8fc2-460b-b5fc-61ec083cca86", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlal-avantika\u001b[0m (\u001b[33mgrelu\u001b[0m) to \u001b[32mhttps://genentech.wandb.io\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" - ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import wandb\n", - "\n", - "wandb.login(host=\"https://genentech.wandb.io\", anonymous=\"never\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d741430f-9fdb-4e25-806d-1c28db345b0a", + "execution_count": 27, + "id": "d0fdaa9d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "decima - INFO - Data paths: matrix_file=./data.h5ad, h5_file=./data.h5\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'repr' attribute with value False was provided to the `Field()` function, which has no effect in the context it was used. 'repr' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'frozen' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'frozen' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", + " warnings.warn(\n", + "decima - INFO - Data paths: matrix_file=./data/data.h5ad, h5_file=./example/data.h5\n", "decima - INFO - Reading anndata\n", "decima - INFO - Making dataset objects\n", - "decima - INFO - train_params: {'name': 'finetune_test_3', 'batch_size': 4, 'num_workers': 16, 'devices': 3, 'logger': 'wandb', 'save_dir': '.', 'max_epochs': 15, 'lr': 5e-05, 'total_weight': 0.0001, 'accumulate_grad_batches': 5, 'loss': 'poisson_multinomial', 'clip': 0.0, 'save_top_k': 1, 'pin_memory': True}\n", - "decima - INFO - model_params: {'n_tasks': 50, 'init_borzoi': True, 'replicate': '3'}\n", + "decima - INFO - train_params: {'batch_size': 1, 'num_workers': 16, 'devices': 0, 'logger': 'wandb', 'save_dir': './example', 'max_epochs': 1, 'lr': 5e-05, 'total_weight': 0.0001, 'accumulate_grad_batches': 5, 'loss': 'poisson_multinomial', 'clip': 0.0, 'save_top_k': 1, 'pin_memory': True}\n", + "decima - INFO - model_params: {'n_tasks': 50, 'init_borzoi': True, 'replicate': '0'}\n", "decima - INFO - Initializing model\n", - "decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 3\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33manony-mouse-891169334544049289\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact human_state_dict_fold3:latest, 709.30MB. 1 files... \n", + "decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33manony-mouse-591272909468377997\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...\n", "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n", - "Done. 0:0:1.2 (583.7MB/s)\n", - "/opt/conda/lib/python3.11/site-packages/decima/model/decima_model.py:68: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", + "Done. 00:00:01.6 (439.0MB/s)\n", "decima - INFO - Connecting to wandb.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33manony-mouse-891169334544049289\u001b[0m to \u001b[32mhttps://genentech.wandb.io\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.19.11\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1mfinetune_test_3/wandb/run-20250910_180859-8m9dcnvo\u001b[0m\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33manony-mouse-591272909468377997\u001b[0m to \u001b[32mhttps://genentech.wandb.io\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[38;5;178m⢿\u001b[0m Waiting for wandb.init()...\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[38;5;178m⣻\u001b[0m Waiting for wandb.init()...\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[38;5;178m⣽\u001b[0m setting up run bj42z19b (0.2s)\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.22.2\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Run data is saved locally in \u001b[35m\u001b[1mfinetune_test_0/wandb/run-20251028_141723-bj42z19b\u001b[0m\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb offline`\u001b[0m to turn off syncing.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mfinetune_test_3\u001b[0m\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://genentech.wandb.io/grelu/decima\u001b[0m\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://genentech.wandb.io/grelu/decima/runs/8m9dcnvo\u001b[0m\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Syncing run \u001b[33mfinetune_test_0\u001b[0m\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: ⭐️ View project at \u001b[34m\u001b[4mhttps://genentech.wandb.io/celik-muhammed_hasan/decima\u001b[0m\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: 🚀 View run at \u001b[34m\u001b[4mhttps://genentech.wandb.io/celik-muhammed_hasan/decima/runs/bj42z19b\u001b[0m\n", "decima - INFO - Training\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/torch/__init__.py:1617: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)\n", + "/home/celikm5/miniforge3/envs/decima2/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:204: PossibleUserWarning: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python3.11 /home/celikm5/miniforge3/envs/decima2/bin/decima ...\n", "Using 16bit Automatic Mixed Precision (AMP)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", - "/opt/conda/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: UserWarning: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. 
If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n", - "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", - "Validation DataLoader 0: 0%| | 0/20 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
              cell_type tissue disease study  size_factor  train_pearson  val_pearson  test_pearson
pseudobulk_0       ct_0    t_0     d_0   st_0  4976.871582      -0.016080     0.038915      0.077958
pseudobulk_1       ct_0    t_0     d_1   st_0  4887.680664      -0.026668     0.149208     -0.025352
pseudobulk_2       ct_0    t_0     d_2   st_1  4950.704590      -0.044364     0.008472      0.115734
pseudobulk_3       ct_0    t_0     d_0   st_1  4949.709961       0.021054    -0.091807      0.048898
pseudobulk_4       ct_0    t_0     d_1   st_2  4792.810547      -0.046708    -0.071214     -0.104898
\n", + "" + ], + "text/plain": [ + " cell_type tissue disease study size_factor train_pearson \\\n", + "pseudobulk_0 ct_0 t_0 d_0 st_0 4976.871582 -0.016080 \n", + "pseudobulk_1 ct_0 t_0 d_1 st_0 4887.680664 -0.026668 \n", + "pseudobulk_2 ct_0 t_0 d_2 st_1 4950.704590 -0.044364 \n", + "pseudobulk_3 ct_0 t_0 d_0 st_1 4949.709961 0.021054 \n", + "pseudobulk_4 ct_0 t_0 d_1 st_2 4792.810547 -0.046708 \n", + "\n", + " val_pearson test_pearson \n", + "pseudobulk_0 0.038915 0.077958 \n", + "pseudobulk_1 0.149208 -0.025352 \n", + "pseudobulk_2 0.008472 0.115734 \n", + "pseudobulk_3 -0.091807 0.048898 \n", + "pseudobulk_4 -0.071214 -0.104898 " + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ad_out.obs.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "121a7787-4c74-465f-ae93-b529564cc2fa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chromstartendstrandgene_startgene_endgene_lengthgene_mask_startgene_mask_enddatasetpearsonsize_factor_pearson
gene_0chr12668252027206808+2684636027370648524288163840524288train0.306280-0.059291
gene_1chr194078373741308025-4061989741144185524288163840524288train0.014492-0.035897
gene_2chr17944634679970634-7928250679806794524288163840524288train0.1821720.226918
gene_4chr1634136883937976-32498483774136524288163840524288train0.098095-0.032441
gene_5chr102298716123511449+2315100123675289524288163840524288train0.016748-0.059998
\n", + "
" + ], + "text/plain": [ + " chrom start end strand gene_start gene_end gene_length \\\n", + "gene_0 chr1 26682520 27206808 + 26846360 27370648 524288 \n", + "gene_1 chr19 40783737 41308025 - 40619897 41144185 524288 \n", + "gene_2 chr1 79446346 79970634 - 79282506 79806794 524288 \n", + "gene_4 chr16 3413688 3937976 - 3249848 3774136 524288 \n", + "gene_5 chr10 22987161 23511449 + 23151001 23675289 524288 \n", + "\n", + " gene_mask_start gene_mask_end dataset pearson size_factor_pearson \n", + "gene_0 163840 524288 train 0.306280 -0.059291 \n", + "gene_1 163840 524288 train 0.014492 -0.035897 \n", + "gene_2 163840 524288 train 0.182172 0.226918 \n", + "gene_4 163840 524288 train 0.098095 -0.032441 \n", + "gene_5 163840 524288 train 0.016748 -0.059998 " + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ad_out.var.head()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "decima2", "language": "python", "name": "python3" }, @@ -1526,7 +3077,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.11.14" } }, "nbformat": 4, diff --git a/docs/tutorials/data/data.h5ad b/docs/tutorials/data/data.h5ad new file mode 100644 index 0000000..da3afa9 Binary files /dev/null and b/docs/tutorials/data/data.h5ad differ diff --git a/src/decima/cli/__init__.py b/src/decima/cli/__init__.py index 930f915..9ae6871 100644 --- a/src/decima/cli/__init__.py +++ b/src/decima/cli/__init__.py @@ -2,7 +2,7 @@ import click from decima.cli.predict_genes import cli_predict_genes -from decima.cli.download import cli_download +from decima.cli.download import cli_cache, cli_download_weights, cli_download_metadata, cli_download from decima.cli.attributions import ( cli_attributions, cli_attributions_plot, @@ -40,6 +40,9 @@ def main(): main.add_command(cli_predict_genes, name="predict-genes") +main.add_command(cli_cache, name="cache") +main.add_command(cli_download_weights, name="download-weights") +main.add_command(cli_download_metadata, name="download-metadata") main.add_command(cli_download, name="download") main.add_command(cli_query_cell, name="query-cell") main.add_command(cli_attributions, name="attributions") diff --git a/src/decima/cli/attributions.py b/src/decima/cli/attributions.py index 1f6c059..f2dab90 100644 --- a/src/decima/cli/attributions.py +++ b/src/decima/cli/attributions.py @@ -17,6 +17,7 @@ """ import click +from decima.cli.callback import parse_genes, parse_model, parse_attributions from decima.interpret.attributions import ( plot_attributions, predict_save_attributions, @@ -47,6 +48,7 @@ type=str, required=False, default=0, + callback=parse_model, help="Model to use for attribution analysis either replicate number or path to the model.", show_default=True, ) @@ -87,6 +89,7 @@ type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.", + callback=parse_genes, show_default=True, ) @click.option( @@ -149,16 +152,6 @@ def cli_attributions_predict( └── {output_prefix}.attributions.bigwig # Genome browser track of attribution as bigwig file obtained with averaging the attribution scores across the genes for genomics coordinates. 
""" - - if model in ["0", "1", "2", "3"]: # replicate index - model = int(model) - - if isinstance(device, str) and device.isdigit(): - device = int(device) - - if genes is not None: - genes = genes.split(",") - predict_save_attributions( output_prefix=output_prefix, tasks=tasks, @@ -181,7 +174,14 @@ def cli_attributions_predict( @click.command() @click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files") -@click.option("-g", "--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.") +@click.option( + "-g", + "--genes", + type=str, + required=False, + callback=parse_genes, + help="Comma-separated list of gene symbols or IDs to analyze.", +) @click.option("--seqs", type=str, required=False, help="Path to a file containing sequences to analyze") @click.option( "--tasks", @@ -197,6 +197,7 @@ def cli_attributions_predict( type=str, required=False, default="ensemble", + callback=parse_model, help="Model to use for attribution analysis either replicate number or path to the model.", show_default=True, ) @@ -288,12 +289,6 @@ def cli_attributions( >>> decima attributions -o output_prefix --seqs tests/data/seqs.fasta --tasks "cell_type == 'classical monocyte'" --device 0 """ - if model in ["0", "1", "2", "3"]: # replicate index - model = int(model) - - if isinstance(genes, str): - genes = genes.split(",") - predict_attributions_seqlet_calling( output_prefix=output_prefix, genes=genes, @@ -321,7 +316,9 @@ def cli_attributions( @click.command() @click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files") -@click.option("--attributions", type=str, required=True, help="Path to the attribution files") +@click.option( + "--attributions", type=str, callback=parse_attributions, required=True, help="Path to the attribution files" +) @click.option( "--tasks", type=str, @@ -333,7 +330,13 @@ def cli_attributions( ) @click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.") @click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") -@click.option("--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.") +@click.option( + "--genes", + type=str, + required=False, + callback=parse_genes, + help="Comma-separated list of gene symbols or IDs to analyze.", +) @click.option( "--top-n-markers", type=int, @@ -393,12 +396,6 @@ def cli_attributions_recursive_seqlet_calling( >>> decima attributions-recursive-seqlet-calling --attributions attributions_0.h5,attributions_1.h5 -o output_prefix --genes SPI1 """ - if isinstance(attributions, str): - attributions = attributions.split(",") - - if genes is not None: - genes = genes.split(",") - recursive_seqlet_calling( output_prefix=output_prefix, attributions=attributions, @@ -422,7 +419,14 @@ def cli_attributions_recursive_seqlet_calling( @click.command() @click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files") -@click.option("-g", "--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.") +@click.option( + "-g", + "--genes", + type=str, + required=False, + callback=parse_genes, + help="Comma-separated list of gene symbols or IDs to analyze.", +) @click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") @click.option("--tss-distance", type=int, 
required=False, default=None, help="TSS distance for attribution analysis.") @click.option("--seqlogo-window", type=int, default=50, help="Window size for sequence logo plots") @@ -449,8 +453,6 @@ def cli_attributions_plot( >>> decima attributions-plot -o output_prefix -g SPI1 """ - genes = genes.split(",") - plot_attributions( output_prefix=output_prefix, genes=genes, diff --git a/src/decima/cli/callback.py b/src/decima/cli/callback.py new file mode 100644 index 0000000..1a24cc5 --- /dev/null +++ b/src/decima/cli/callback.py @@ -0,0 +1,57 @@ +import click +from pathlib import Path + + +def parse_model(ctx, param, value): + if value is None: + return None + elif isinstance(value, str): + if value == "ensemble": + return "ensemble" + elif value in ["0", "1", "2", "3"]: + return int(value) + + paths = value.split(",") + for path in paths: + if not Path(path).exists(): + raise click.ClickException( + f"Model path {path} does not exist. Check if the path is correct and the file exists." + ) + return paths + + return value + + +def parse_genes(ctx, param, value): + if value is None: + return None + elif isinstance(value, str): + return value.split(",") + raise ValueError(f"Invalid genes: {value}. Genes should be a comma-separated list of gene names or None.") + + +def validate_save_replicates(ctx, param, value): + if value: + if ctx.params["model"] == "ensemble": + return value + elif isinstance(ctx.params["model"], list) and (len(ctx.params["model"]) > 1): + return value + else: + raise ValueError( + "`--save-replicates` is only supported for ensemble models. Pass `ensemble` or list of models as the model argument." + ) + return value + + +def parse_attributions(ctx, param, value): + value = value.split(",") + for i in value: + if not Path(i).exists(): + raise click.ClickException( + f"Attribution path {i} does not exist. Check if the path is correct and the file exists." + ) + elif not i.endswith(".h5"): + raise click.ClickException( + f"Attribution path {i} is not a h5 file. Check if the path is correct and the file is a h5 file." + ) + return value diff --git a/src/decima/cli/download.py b/src/decima/cli/download.py index 94be9f8..a145cd1 100644 --- a/src/decima/cli/download.py +++ b/src/decima/cli/download.py @@ -6,14 +6,56 @@ `decima download` is the main command for downloading the required data and model weights. It includes subcommands for: -- Downloading the required data and model weights. `download` +- Caching the required data and model weights. `cache` """ import click -from decima.hub.download import download_decima_data +from decima.cli.callback import parse_model +from decima.hub.download import ( + cache_decima_data, + download_decima_weights, + download_decima_metadata, + download_decima, +) @click.command() -def cli_download(): - """Download all required data and model weights.""" - download_decima_data() +def cli_cache(): + """Cache all required data and model weights.""" + cache_decima_data() + + +@click.command() +@click.option( + "--model", type=str, default="ensemble", help="Model to download. Default: ensemble.", callback=parse_model +) +@click.option( + "--download-dir", + type=click.Path(), + default=".", + help="Directory to download the model weights. 
Default: current directory.", +) +def cli_download_weights(model, download_dir): + """Download pre-trained Decima model weights.""" + download_decima_weights(model, str(download_dir)) + + +@click.command() +@click.option( + "--download-dir", + type=click.Path(), + default=".", + help="Directory to download the metadata. Default: current directory.", +) +def cli_download_metadata(download_dir): + """Download pre-trained Decima metadata.""" + download_decima_metadata(str(download_dir)) + + +@click.command() +@click.option( + "--download-dir", type=click.Path(), default=".", help="Directory to download the data. Default: current directory." +) +def cli_download(download_dir): + """Download model weights and metadata for Decima.""" + download_decima(str(download_dir)) diff --git a/src/decima/cli/finetune.py b/src/decima/cli/finetune.py index 7a49c5f..3105797 100755 --- a/src/decima/cli/finetune.py +++ b/src/decima/cli/finetune.py @@ -64,7 +64,27 @@ def cli_finetune( num_workers, seed, ): - """Finetune the Decima model.""" + """Finetune the Decima model. + + Args: + name: Name of the run for logging and checkpointing + model: Model path or replication number (0-3) + device: Device to use for training. Default: "0" + matrix_file: Path to the matrix file containing training data + h5_file: Path to the H5 file containing sequences + outdir: Output directory path to save model checkpoints + learning_rate: Learning rate for training. Default: 0.001 + loss_total_weight: Total weight parameter for the loss function + gradient_accumulation: Number of gradient accumulation steps + batch_size: Batch size for training. Default: 1 + max_seq_shift: Maximum sequence shift for data augmentation. Default: 5000 + gradient_clipping: Gradient clipping value. Default: 0.0 (disabled) + save_top_k: Number of best checkpoints to save. Default: 1 + epochs: Number of training epochs. Default: 1 + logger: Logger type to use. Default: "wandb" + num_workers: Number of data loading workers. Default: 16 + seed: Random seed for reproducibility. Default: 0 + """ train_logger = logger logger = logging.getLogger("decima") logger.info(f"Data paths: matrix_file={matrix_file}, h5_file={h5_file}") @@ -86,7 +106,6 @@ def cli_finetune( device = int(device) train_params = { - "name": name, "batch_size": batch_size, "num_workers": num_workers, "devices": device, @@ -97,7 +116,6 @@ def cli_finetune( "total_weight": loss_total_weight, "accumulate_grad_batches": gradient_accumulation, "loss": "poisson_multinomial", - # "pairs": ad.uns["disease_pairs"].values, "clip": gradient_clipping, "save_top_k": save_top_k, "pin_memory": True, @@ -111,7 +129,7 @@ def cli_finetune( logger.info(f"model_params: {model_params}") logger.info("Initializing model") - model = LightningModel(model_params=model_params, train_params=train_params) + model = LightningModel(name=name, model_params=model_params, train_params=train_params) if train_logger == "wandb": logger.info("Connecting to wandb.") diff --git a/src/decima/cli/modisco.py b/src/decima/cli/modisco.py index a4b1afb..264f285 100644 --- a/src/decima/cli/modisco.py +++ b/src/decima/cli/modisco.py @@ -12,18 +12,16 @@ - Extracting the seqlets from the modisco results. `modisco-seqlet-bed` Examples: - >>> decima modisco -o output_prefix -t tasks -o off_tasks -m model -m metadata -m method -m transform -m batch_size -m genes -m top_n_markers -m disable_bigwig -m disable_correct_grad_bigwig -m device -m genome -m num_workers - ... 
+ >>> decima modisco -o output_prefix --tasks "cell_type == 'classical monocyte'" --genes SPI1,CD68 - >>> decima modisco -o output_prefix -t tasks -o off_tasks -m model -m metadata -m method -m transform -m batch_size -m genes -m top_n_markers -m disable_bigwig -m disable_correct_grad_bigwig -m device -m genome -m num_workers - ... + >>> decima modisco -o output_prefix --tasks "cell_type == 'B cell'" --device 0 --genome hg38 - >>> decima modisco -o output_prefix -t tasks -o off_tasks -m model -m metadata -m method -m transform -m batch_size -m genes -m top_n_markers -m disable_bigwig -m disable_correct_grad_bigwig -m device -m genome -m num_workers - ... + >>> decima modisco -o output_prefix --genes SPI1 --method saliency --batch-size 2 """ import click from typing import List, Optional, Union +from decima.cli.callback import parse_model, parse_genes, parse_attributions from decima.interpret.modisco import ( predict_save_modisco_attributions, modisco_patterns, @@ -47,7 +45,7 @@ default=None, help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.", ) -@click.option("--model", type=str, default="0", help="Model to use for the prediction.") +@click.option("--model", type=str, default=0, help="Model to use for the prediction.", callback=parse_model) @click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") @click.option( "--method", @@ -64,7 +62,13 @@ help="Transform to use for attribution analysis.", ) @click.option("--batch-size", type=int, default=1, show_default=True, help="Batch size for the prediction.") -@click.option("--genes", type=str, default=None, help="Genes to predict. If not provided, all genes will be predicted.") +@click.option( + "--genes", + type=str, + default=None, + callback=parse_genes, + help="Genes to predict. If not provided, all genes will be predicted.", +) @click.option( "--top-n-markers", type=int, @@ -97,15 +101,6 @@ def cli_modisco_attributions( num_workers: int = 4, genome: str = "hg38", ): - if model in ["0", "1", "2", "3"]: # replicate index - model = int(model) - - if isinstance(device, str) and device.isdigit(): - device = int(device) - - if genes is not None: - genes = genes.split(",") - predict_save_modisco_attributions( output_prefix=output_prefix, tasks=tasks, @@ -131,7 +126,10 @@ def cli_modisco_attributions( "--attributions", type=str, required=True, - help="Path to the attributions HDF5 file. If multiple files are provided, they will be averaged.", + callback=parse_attributions, + help="Comma-separated list of paths to the attributions HDF5 files." + " If multiple files are provided, they will be averaged." + " All files must be h5 files generated with `decima modisco-attributions` command.", ) @click.option( "--tasks", @@ -147,7 +145,13 @@ def cli_modisco_attributions( ) @click.option("--tss-distance", type=int, default=10_000, show_default=True, help="TSS distance for the prediction.") @click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.") -@click.option("--genes", type=str, default=None, help="Genes to predict. If not provided, all genes will be predicted.") +@click.option( + "--genes", + type=str, + default=None, + callback=parse_genes, + help="Genes to predict. 
If not provided, all genes will be predicted.", +) @click.option( "--top-n-markers", type=int, @@ -209,12 +213,6 @@ def cli_modisco_patterns( stranded: bool = False, pattern_type: str = "both", ): - if isinstance(attributions, str): - attributions = attributions.split(",") - - if genes is not None: - genes = genes.split(",") - modisco_patterns( output_prefix=output_prefix, attributions=attributions, @@ -312,7 +310,14 @@ def cli_modisco_seqlet_bed( help="Set of tasks will be subtracted from the attributions to calculate attribution on `specificity` transform. If not provided, all tasks will be computed.", ) @click.option("--tss-distance", type=int, default=10_000, show_default=True, help="TSS distance for the prediction.") -@click.option("--model", type=str, default="ensemble", show_default=True, help="Model to use for the prediction.") +@click.option( + "--model", + type=str, + default="ensemble", + show_default=True, + help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files. Default: `ensemble`.", + callback=parse_model, +) @click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.") @click.option( "--method", @@ -327,6 +332,7 @@ def cli_modisco_seqlet_bed( type=str, show_default=True, default=None, + callback=parse_genes, help="Genes to predict. If not provided, all genes will be predicted.", ) @click.option( @@ -431,15 +437,6 @@ def cli_modisco( # seqlet thresholds seqlet_motif_trim_threshold: float = 0.2, ): - if model in ["0", "1", "2", "3"]: - model = int(model) - - if isinstance(device, str) and device.isdigit(): - device = int(device) - - if genes is not None: - genes = genes.split(",") - modisco( output_prefix=output_prefix, tasks=tasks, diff --git a/src/decima/cli/predict_genes.py b/src/decima/cli/predict_genes.py index 70a39d8..32a5934 100644 --- a/src/decima/cli/predict_genes.py +++ b/src/decima/cli/predict_genes.py @@ -8,6 +8,7 @@ import click from pathlib import Path +from decima.cli.callback import parse_model, parse_genes, validate_save_replicates from decima.tools.inference import predict_gene_expression @@ -17,14 +18,16 @@ "--genes", type=str, default=None, - help="List of genes to predict. Default: None (all genes). If provided, only these genes will be predicted.", + callback=parse_genes, + help="Comma-separated list of genes to predict. Default: None (all genes). If provided, only these genes will be predicted.", ) @click.option( "-m", "--model", type=str, default="ensemble", - help="Path to the model checkpoint: `0`, `1`, `2`, `3`, `ensemble` or `path/to/model.ckpt`.", + callback=parse_model, + help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to checkpoint files", ) @click.option( "--metadata", @@ -45,7 +48,8 @@ @click.option( "--save-replicates", is_flag=True, - help="Save the replicates in the output parquet file. Default: False.", + callback=validate_save_replicates, + help="Save the replicates in the output h5ad file. Default: False. 
Only supported for ensemble models.", ) @click.option( "--float-precision", @@ -66,18 +70,6 @@ def cli_predict_genes( save_replicates, float_precision, ): - if model in ["0", "1", "2", "3"]: - model = int(model) - - if isinstance(device, str) and device.isdigit(): - device = int(device) - - if genes is not None: - genes = genes.split(",") - - if save_replicates and (model != "ensemble"): - raise ValueError("`--save-replicates` is only supported for ensemble model (`--model ensemble`).") - ad = predict_gene_expression( genes=genes, model=model, diff --git a/src/decima/cli/query_cell.py b/src/decima/cli/query_cell.py index 75fcc77..a6841c8 100644 --- a/src/decima/cli/query_cell.py +++ b/src/decima/cli/query_cell.py @@ -23,7 +23,10 @@ @click.command() @click.argument("query", default="") -def cli_query_cell(query=""): +@click.option( + "--metadata-anndata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file." +) +def cli_query_cell(query="", metadata_anndata=None): """ Query a cell using query string @@ -39,7 +42,7 @@ def cli_query_cell(query=""): ... """ - result = DecimaResult.load() + result = DecimaResult.load(metadata_anndata) df = result.cell_metadata if query != "": diff --git a/src/decima/cli/vep.py b/src/decima/cli/vep.py index 3d99167..795e74e 100644 --- a/src/decima/cli/vep.py +++ b/src/decima/cli/vep.py @@ -22,6 +22,7 @@ import click from decima.constants import DECIMA_CONTEXT_SIZE +from decima.cli.callback import parse_model, validate_save_replicates from decima.utils.dataframe import ensemble_predictions from decima.vep import predict_variant_effect @@ -46,7 +47,8 @@ "--model", type=str, default="ensemble", - help="Model to use for variant effect prediction either replicate number or path to the model.", + callback=parse_model, + help="`0`, `1`, `2`, `3`, `ensemble` or a path or a comma-separated list of paths to safetensor files to perform variant effect prediction. Default: `ensemble`.", ) @click.option( "--metadata", @@ -88,7 +90,8 @@ @click.option( "--save-replicates", is_flag=True, - help="Save the replicates in the output parquet file. Default: False.", + callback=validate_save_replicates, + help="Save the replicates in the output parquet file. Default: False. Only supported for ensemble models.", ) @click.option( "--disable-reference-cache", @@ -145,18 +148,9 @@ def cli_predict_variant_effect( """ reference_cache = not disable_reference_cache - if model in ["0", "1", "2", "3"]: # replicate index - model = int(model) - - if isinstance(device, str) and device.isdigit(): - device = int(device) - if include_cols: include_cols = include_cols.split(",") - if save_replicates and (model != "ensemble"): - raise ValueError("`--save-replicates` is only supported for ensemble model (`--model ensemble`).") - predict_variant_effect( variants, output_pq=output_pq, diff --git a/src/decima/core/result.py b/src/decima/core/result.py index 3c74abf..2c14d3b 100644 --- a/src/decima/core/result.py +++ b/src/decima/core/result.py @@ -208,11 +208,16 @@ def _pad_gene_metadata(self, gene_meta: pd.Series, padding: int = 0) -> pd.Serie gene_meta["gene_mask_end"] = gene_meta["gene_mask_end"] + padding return gene_meta - def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padding: int = 0) -> torch.Tensor: + def prepare_one_hot( + self, gene: str, variants: Optional[List[Dict]] = None, padding: int = 0, genome: str = "hg38" + ) -> torch.Tensor: """Prepare one-hot encoding for a gene. 
Args: gene: Gene name + variants: Optional list of variant dictionaries to inject into the sequence + padding: Amount of padding to add on both sides of the sequence + genome: Genome name or path to the genome fasta file. Default: "hg38" Returns: torch.Tensor: One-hot encoding of the gene @@ -221,10 +226,11 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padd gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding) if variants is None: - seq = intervals_to_strings(gene_meta, genome="hg38") + seq = intervals_to_strings(gene_meta, genome=genome) gene_start, gene_end = gene_meta.gene_mask_start, gene_meta.gene_mask_end else: - seq, (gene_start, gene_end) = prepare_seq_alt_allele(gene_meta, variants) + # Todo: fix for case where genome is not hg38 + seq, (gene_start, gene_end) = prepare_seq_alt_allele(gene_meta, variants, genome=genome) mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2)) mask[0, gene_start:gene_end] += 1 @@ -232,12 +238,13 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padd return strings_to_one_hot(seq), mask - def gene_sequence(self, gene: str, stranded: bool = True) -> str: + def gene_sequence(self, gene: str, stranded: bool = True, genome: str = "hg38") -> str: """Get sequence for a gene. Args: gene: Gene name stranded: Whether to return stranded sequence + genome: Genome name or path to the genome fasta file. Default: "hg38" Returns: str: Sequence for the gene @@ -250,7 +257,7 @@ def gene_sequence(self, gene: str, stranded: bool = True) -> str: gene_meta = self.gene_metadata.loc[gene] if not stranded: gene_meta = {"chrom": gene_meta.chrom, "start": gene_meta.start, "end": gene_meta.end} - return intervals_to_strings(gene_meta, genome="hg38") + return intervals_to_strings(gene_meta, genome=genome) def attributions( self, @@ -263,6 +270,7 @@ def attributions( min_seqlet_len: int = 4, max_seqlet_len: int = 25, additional_flanks: int = 0, + genome: str = "hg38", ): """Get attributions for a specific gene. @@ -272,15 +280,18 @@ def attributions( off_tasks: List of cells to use as off task transform: Attribution transform method method: Method to use for attribution analysis available options: "saliency", "inputxgradient", "integratedgradients". - n_peaks: Number of peaks to find - min_dist: Minimum distance between peaks + threshold: Threshold for attribution analysis + min_seqlet_len: Minimum length for seqlet calling + max_seqlet_len: Maximum length for seqlet calling + additional_flanks: Additional flanks for seqlet calling + genome: Genome to use for attribution analysis default is "hg38". Can be genome name or path to custom genome fasta file. 
Returns: Attribution: Container with inputs, predictions, attribution scores and TSS position """ tasks, off_tasks = self.query_tasks(tasks, off_tasks) - one_hot_seq, gene_mask = self.prepare_one_hot(gene) + one_hot_seq, gene_mask = self.prepare_one_hot(gene, genome=genome) inputs = torch.vstack([one_hot_seq, gene_mask]) attrs = ( diff --git a/src/decima/data/dataset.py b/src/decima/data/dataset.py index 0aaa46c..0c1a341 100644 --- a/src/decima/data/dataset.py +++ b/src/decima/data/dataset.py @@ -84,19 +84,30 @@ def __init__( self.gene_index = index_genes(self.h5_file, key=self.key) self.n_seqs = len(self.gene_index) - # Setup - self.dataset = h5py.File(self.h5_file, "r") + # Setup - Open file and cache data needed for worker processes + self.dataset = None + self._is_closed = False + self._open_file() self.extract_tasks(ad) self.predict = False self.n_alleles = 1 + def _open_file(self): + """Open the HDF5 file. This will be called in each worker process.""" + if self.dataset is None or self._is_closed: + self.dataset = h5py.File(self.h5_file, "r") + self._is_closed = False + def __len__(self): return self.n_seqs * self.n_augmented def close(self): - self.dataset.close() + if self.dataset is not None and not self._is_closed: + self.dataset.close() + self._is_closed = True def extract_tasks(self, ad=None): + self._open_file() tasks = np.array(self.dataset["tasks"]).astype(str) if ad is not None: assert np.all(tasks == ad.obs_names) @@ -105,6 +116,7 @@ def extract_tasks(self, ad=None): self.tasks = pd.DataFrame(index=tasks) def extract_seq(self, idx): + self._open_file() seq = self.dataset["sequences"][idx] seq = indices_to_one_hot(seq) # 4, L mask = self.dataset["masks"][[idx]] # 1, L @@ -113,6 +125,7 @@ def extract_seq(self, idx): return torch.Tensor(seq) def extract_label(self, idx): + self._open_file() return torch.Tensor(self.dataset["labels"][idx]) def __getitem__(self, idx): @@ -144,6 +157,7 @@ class GeneDataset(Dataset): max_seq_shift: Maximum sequence shift. seed: Seed for the random number generator. augment_mode: Augmentation mode. + genome: Name of the genome Returns: Dataset: Dataset for gene expression prediction. 
@@ -176,9 +190,11 @@ def __init__( max_seq_shift=0, seed=0, augment_mode="random", + genome="hg38", ): super().__init__() + self.genome = genome self.result = DecimaResult.load(metadata_anndata) self.genes = genes or list(self.result.genes) self.gene_mask_starts = self.result.gene_metadata.loc[self.genes, "gene_mask_start"].values @@ -209,7 +225,7 @@ def __len__(self): def __getitem__(self, idx): seq_idx, augment_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented)) - seq, mask = self.result.prepare_one_hot(self.genes[seq_idx], padding=self.max_seq_shift) + seq, mask = self.result.prepare_one_hot(self.genes[seq_idx], padding=self.max_seq_shift, genome=self.genome) inputs = torch.vstack([seq, mask]) inputs = self.augmenter(seq=inputs, idx=augment_idx) @@ -594,10 +610,12 @@ def __init__( max_distance=float("inf"), model_name=None, reference_cache=True, + genome="hg38", ): super().__init__() self.reference_cache = reference_cache + self.genome = genome self.result = DecimaResult.load(metadata_anndata) self.variants = self._overlap_genes( @@ -848,7 +866,7 @@ def __len__(self): return self.n_seqs * self.n_augmented * self.n_alleles def validate_allele_seq(self, gene, variant): - seq = self.result.gene_sequence(gene) + seq = self.result.gene_sequence(gene, genome=self.genome) pos = variant.rel_pos ref_match = seq[pos : pos + len(variant.ref)] == variant.ref_tx alt_match = seq[pos : pos + len(variant.alt)] == variant.alt_tx @@ -882,6 +900,7 @@ def __getitem__(self, idx): variant.gene, variants=[{"chrom": variant.chrom, "pos": variant.pos, "ref": variant.ref, "alt": variant.alt}], padding=self.max_seq_shift, + genome=self.genome, ) allele = seq[:, rel_pos : rel_pos + len(variant.alt)] allele_tx = variant.alt_tx @@ -896,6 +915,7 @@ def __getitem__(self, idx): variant.gene, variants=[{"chrom": variant.chrom, "pos": variant.pos, "ref": variant.alt, "alt": variant.ref}], padding=self.max_seq_shift, + genome=self.genome, ) allele = seq[:, rel_pos : rel_pos + len(variant.ref)] allele_tx = variant.ref_tx diff --git a/src/decima/data/write_hdf5.py b/src/decima/data/write_hdf5.py index 7194411..ccd3c1d 100644 --- a/src/decima/data/write_hdf5.py +++ b/src/decima/data/write_hdf5.py @@ -5,6 +5,14 @@ def write_hdf5(file, ad, pad=0, genome="hg38"): + """Write AnnData object to HDF5 file. + + Args: + file: Path to the HDF5 file to write + ad: AnnData object containing the data + pad: Amount of padding to add. Defaults to 0 + genome: Genome name or path to the genome fasta file. Defaults to "hg38" + """ # Calculate seq_len seq_len = get_unique_length(ad.var) diff --git a/src/decima/hub/__init__.py b/src/decima/hub/__init__.py index f19fc80..b6e7611 100644 --- a/src/decima/hub/__init__.py +++ b/src/decima/hub/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Union, Optional +from typing import Union, Optional, List import warnings import wandb from pathlib import Path @@ -17,7 +17,7 @@ def login_wandb(): wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), relogin=True, anonymous="must", timeout=0) -def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None): +def load_decima_model(model: Union[str, int, List[str]] = 0, device: Optional[str] = None): """Load a pre-trained Decima model from wandb or local path. 
Args: @@ -25,6 +25,7 @@ def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None): - int: Replicate number (0-3) - str: Model name on wandb - str: Path to local model checkpoint + - List: list of local model checkpoints device: Device to load the model on. If None, automatically selects the best available device. Returns: @@ -35,27 +36,36 @@ """ if isinstance(model, LightningModel): return model + elif model == "ensemble": - return EnsembleLightningModel( - [ - load_decima_model(0, device), - load_decima_model(1, device), - load_decima_model(2, device), - load_decima_model(3, device), - ] - ) + return EnsembleLightningModel([load_decima_model(i, device) for i in range(4)]) + + elif isinstance(model, List): + if len(model) == 1: + return load_decima_model(model[0], device) + else: + return EnsembleLightningModel([load_decima_model(path, device) for path in model]) + + elif model in {0, 1, 2, 3}: + model_name = f"rep{model}" + + # Load directly from a path elif isinstance(model, str): if Path(model).exists(): - return LightningModel.load_safetensor(model, device=device) + if model.endswith("ckpt"): + return LightningModel.load_from_checkpoint(model, map_location=device) + else: + return LightningModel.load_safetensor(model, device=device) else: model_name = model - elif model in {0, 1, 2, 3}: - model_name = f"rep{model}" + else: raise ValueError( - f"Invalid model: {model} it need to be a string of model_name on wandb " - "or an integer of replicate number {0, 1, 2, 3}, or a path to a local model" + f"Invalid model: {model}. It needs to be either a model name on wandb, " + "an integer replicate number {0, 1, 2, 3}, a path to a local model, or a list of paths." ) + + # If left with a model name, load from environment/wandb if model_name.upper() in os.environ: if Path(os.environ[model_name.upper()]).exists(): return LightningModel.load_safetensor(os.environ[model_name.upper()], device=device) diff --git a/src/decima/hub/download.py index 02c0de0..9e2fe72 100644 --- a/src/decima/hub/download.py +++ b/src/decima/hub/download.py @@ -1,33 +1,97 @@ +from pathlib import Path +from typing import Union import logging import genomepy +from grelu.resources import get_artifact from decima.hub import login_wandb, load_decima_model, load_decima_metadata logger = logging.getLogger("decima") -def download_hg38(): +def cache_hg38(): """Download hg38 genome from UCSC.""" logger.info("Downloading hg38 genome...") genomepy.install_genome(provider="url", name="http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz") -def download_decima_weights(): +def cache_decima_weights(): """Download pre-trained Decima model weights from wandb.""" logger.info("Downloading Decima model weights...") for rep in range(4): load_decima_model(rep) -def download_decima_metadata(): +def cache_decima_metadata(): """Download pre-trained Decima model data from wandb.""" logger.info("Downloading Decima metadata...") load_decima_metadata() -def download_decima_data(): +def cache_decima_data(): """Download all required data for Decima.""" login_wandb() - download_hg38() - download_decima_weights() - download_decima_metadata() + cache_hg38() + cache_decima_weights() + cache_decima_metadata() + + +def download_decima_weights(model_name: Union[str, int], download_dir: str): + """Download pre-trained Decima model weights from wandb. + + Args: + model_name: Model name or replicate number. 
+ download_dir: Directory to download the model weights. + + Returns: + Path to the downloaded model weights. + """ + if "ensemble" == model_name: + return [download_decima_weights(model, download_dir) for model in range(4)] + + if model_name in {0, 1, 2, 3}: + model_name = f"rep{model_name}" + + download_dir = Path(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Downloading Decima model weights for {model_name} to {download_dir / f'{model_name}.safetensors'}") + + art = get_artifact(model_name, project="decima") + art.download(str(download_dir)) + return download_dir / f"{model_name}.safetensors" + + +def download_decima_metadata(download_dir: str): + """Download pre-trained Decima model data from wandb. + + Args: + download_dir: Directory to download the metadata. + + Returns: + Path to the downloaded metadata. + """ + art = get_artifact("metadata", project="decima") + download_dir = Path(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Downloading Decima metadata to {download_dir / 'metadata.h5ad'}.") + + art.download(str(download_dir)) + return download_dir / "metadata.h5ad" + + +def download_decima(download_dir: str): + """Download all required data for Decima. + + Args: + download_dir: Directory to download the model weights and metadata. + + Returns: + Path to the downloaded directory containing the model weights and metadata. + """ + download_dir = Path(download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Downloading Decima model weights and metadata to {download_dir}:") + + download_decima_weights("ensemble", download_dir) + download_decima_metadata(download_dir) + return download_dir diff --git a/src/decima/interpret/attributions.py b/src/decima/interpret/attributions.py index 7e4450a..10388f9 100644 --- a/src/decima/interpret/attributions.py +++ b/src/decima/interpret/attributions.py @@ -67,8 +67,7 @@ def predict_save_attributions( device: Optional[str] = None, genome: str = "hg38", ): - """ - Generate and save attribution analysis results for a gene. + """Generate and save attribution analysis results for a gene. Args: output_prefix: Prefix for the output files where attribution results will be saved. @@ -88,6 +87,9 @@ def predict_save_attributions( device: Device to use for attribution analysis (e.g. 'cuda', 'cpu'). If not provided, the best available device will be used automatically. genome: Genome to use for attribution analysis default is "hg38". Can be genome name or path to custom genome fasta file. + Returns: + Path to the attribution file. + Examples: >>> predict_save_attributions( ... output_prefix="output_prefix", @@ -117,6 +119,33 @@ def predict_save_attributions( ... genome="hg38", ... 
) """ + if (model == "ensemble") or isinstance(model, (list, tuple)): + if model == "ensemble": + models = [0, 1, 2, 3] + else: + models = model + return [ + predict_save_attributions( + output_prefix=(str(output_prefix) + "_{model}").format(model=idx), + tasks=tasks, + off_tasks=off_tasks, + model=model, + metadata_anndata=metadata_anndata, + method=method, + transform=transform, + batch_size=batch_size, + genes=genes, + seqs=seqs, + top_n_markers=top_n_markers, + bigwig=bigwig, + correct_grad_bigwig=correct_grad_bigwig, + num_workers=num_workers, + device=device, + genome=genome, + ) + for idx, model in enumerate(models) + ] + output_prefix = Path(output_prefix) output_prefix.parent.mkdir(parents=True, exist_ok=True) @@ -155,7 +184,7 @@ def predict_save_attributions( raise ValueError(f"Invalid type for seqs: {type(seqs)}. Must be a path to fasta file or pd.DataFrame.") else: dataset = GeneDataset( - genes=_get_genes(result, genes, top_n_markers, tasks, off_tasks), metadata_anndata=result + genes=_get_genes(result, genes, top_n_markers, tasks, off_tasks), metadata_anndata=result, genome=genome ) genes_batch = list(chunked(dataset.genes, batch_size)) @@ -170,8 +199,9 @@ def predict_save_attributions( num_workers=num_workers, ) + output_path = Path(output_prefix).with_suffix(".attributions.h5") with AttributionWriter( - path=Path(output_prefix).with_suffix(".attributions.h5"), + path=output_path, genes=dataset.genes, model_name=attributer.model.name, metadata_anndata=result, @@ -199,6 +229,8 @@ def predict_save_attributions( f.write(f">{i}\n{seq}\n") Faidx(fasta_path, build_index=True) + return output_path + def recursive_seqlet_calling( output_prefix: str, @@ -219,8 +251,7 @@ def recursive_seqlet_calling( custom_genome: bool = False, meme_motif_db: str = "hocomoco_v13", ): - """ - Recursive seqlet calling for attribution analysis. + """Recursive seqlet calling for attribution analysis. Args: output_prefix: Prefix for the output files where seqlet calling results will be saved. 
@@ -365,34 +396,22 @@ def predict_attributions_seqlet_calling( output_prefix = Path(output_prefix) output_prefix.parent.mkdir(parents=True, exist_ok=True) - if model == "ensemble": - attrs_output_prefix = str(output_prefix) + "_{model}" - models = [0, 1, 2, 3] - attributions = [ - Path(attrs_output_prefix.format(model=model)).with_suffix(".attributions.h5") for model in models - ] - else: - attrs_output_prefix = output_prefix - models = [model] - attributions = output_prefix.with_suffix(".attributions.h5").as_posix() - - for model in models: - predict_save_attributions( - output_prefix=str(attrs_output_prefix).format(model=model), - genes=genes, - seqs=seqs, - tasks=tasks, - off_tasks=off_tasks, - model=model, - metadata_anndata=metadata_anndata, - method=method, - transform=transform, - num_workers=num_workers, - batch_size=batch_size, - top_n_markers=top_n_markers, - device=device, - genome=genome, - ) + attributions_paths = predict_save_attributions( + output_prefix=output_prefix, + genes=genes, + seqs=seqs, + tasks=tasks, + off_tasks=off_tasks, + model=model, + metadata_anndata=metadata_anndata, + method=method, + transform=transform, + num_workers=num_workers, + batch_size=batch_size, + top_n_markers=top_n_markers, + device=device, + genome=genome, + ) custom_genome = False if seqs is not None: @@ -401,7 +420,7 @@ def predict_attributions_seqlet_calling( recursive_seqlet_calling( output_prefix=output_prefix, - attributions=attributions, + attributions=attributions_paths, metadata_anndata=metadata_anndata, genes=genes, tasks=tasks, diff --git a/src/decima/interpret/modisco.py b/src/decima/interpret/modisco.py index 0f9a449..3a4ddc2 100644 --- a/src/decima/interpret/modisco.py +++ b/src/decima/interpret/modisco.py @@ -79,7 +79,7 @@ def predict_save_modisco_attributions( ... tasks="cell_type == 'classical monocyte'", ... ) """ - predict_save_attributions( + return predict_save_attributions( output_prefix=output_prefix, tasks=tasks, off_tasks=off_tasks, @@ -533,37 +533,26 @@ def modisco( seqlet_motif_trim_threshold: Seqlet motif trim threshold. 
""" output_prefix = Path(output_prefix) + output_prefix.parent.mkdir(parents=True, exist_ok=True) - if model == "ensemble": - attrs_output_prefix = str(output_prefix) + "_{model}" - models = [0, 1, 2, 3] - attributions = [ - Path(attrs_output_prefix.format(model=model)).with_suffix(".attributions.h5") for model in models - ] - else: - attrs_output_prefix = output_prefix - models = [model] - attributions = [output_prefix.with_suffix(".attributions.h5").as_posix()] - - for model in models: - predict_save_modisco_attributions( - output_prefix=str(attrs_output_prefix).format(model=model), - tasks=tasks, - off_tasks=off_tasks, - model=model, - metadata_anndata=metadata_anndata, - genes=genes, - top_n_markers=top_n_markers, - method=method, - batch_size=batch_size, - correct_grad_bigwig=correct_grad, - device=device, - num_workers=num_workers, - genome=genome, - ) + attributions_paths = predict_save_modisco_attributions( + output_prefix=output_prefix, + tasks=tasks, + off_tasks=off_tasks, + model=model, + metadata_anndata=metadata_anndata, + genes=genes, + top_n_markers=top_n_markers, + method=method, + batch_size=batch_size, + correct_grad_bigwig=correct_grad, + device=device, + num_workers=num_workers, + genome=genome, + ) modisco_patterns( output_prefix=output_prefix, - attributions=attributions, + attributions=attributions_paths, tasks=tasks, off_tasks=off_tasks, tss_distance=tss_distance, diff --git a/src/decima/model/lightning.py b/src/decima/model/lightning.py index e21eca9..058c5a6 100644 --- a/src/decima/model/lightning.py +++ b/src/decima/model/lightning.py @@ -3,7 +3,6 @@ """ import json -from datetime import datetime from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -47,15 +46,15 @@ class LightningModel(pl.LightningModule): Wrapper for predictive sequence models Args: + name: Name of the model which be used in the generated results; thus, ensure unique name for each model and replicate. model_params: Dictionary of parameters specifying model architecture train_params: Dictionary specifying training parameters data_params: Dictionary specifying parameters of the training data. This is empty by default and will be filled at the time of training. - name: Name of the model. """ - def __init__(self, model_params: dict, train_params: dict = {}, data_params: dict = {}, name: str = "") -> None: + def __init__(self, name: str, model_params: dict, train_params: dict = {}, data_params: dict = {}) -> None: super().__init__() self.name = name @@ -210,16 +209,14 @@ def parse_logger(self) -> str: """ Parses the name of the logger supplied in train_params. 
""" - if "name" not in self.train_params: - self.train_params["name"] = datetime.now().strftime("%Y_%d_%m_%H_%M") if self.train_params["logger"] == "wandb": logger = WandbLogger( - name=self.train_params["name"], + name=self.name, log_model=True, save_dir=self.train_params["save_dir"], ) elif self.train_params["logger"] == "csv": - logger = CSVLogger(name=self.train_params["name"], save_dir=self.train_params["save_dir"]) + logger = CSVLogger(name=self.name, save_dir=self.train_params["save_dir"]) else: raise NotImplementedError return logger @@ -525,8 +522,9 @@ def load_safetensor(cls, path: str, device: str = "cpu"): class EnsembleLightningModel(LightningModel): - def __init__(self, models: List[LightningModel]): + def __init__(self, models: List[LightningModel], name="ensemble"): super().__init__( + name=name, model_params=models[0].model_params, train_params=models[0].train_params, data_params=models[0].data_params, @@ -636,7 +634,7 @@ def __init__( model_params: dict, train_params: dict = {}, data_params: dict = {}, - name: str = "", + name: str = "fix-gene-mask", ): super().__init__( model_params=model_params, diff --git a/src/decima/tools/inference.py b/src/decima/tools/inference.py index 0504444..2e5aa32 100644 --- a/src/decima/tools/inference.py +++ b/src/decima/tools/inference.py @@ -25,11 +25,12 @@ def predict_gene_expression( model (str, optional): Model to use for prediction. Defaults to 'ensemble'. metadata_anndata (str, optional): Path to the metadata anndata file. Defaults to None. device (str, optional): Device to use for prediction. Defaults to None. - batch_size (int, optional): Batch size for prediction. Defaults to 8. + batch_size (int, optional): Batch size for prediction. Defaults to 1. num_workers (int, optional): Number of workers for prediction. Defaults to 4. max_seq_shift (int, optional): Maximum sequence shift for prediction. Defaults to 0. - genome (str, optional): Genome build for prediction. Defaults to 'hg38'. + genome (str, optional): Genome name or path to the genome fasta file. Defaults to 'hg38'. save_replicates (bool, optional): Save the replicates for prediction. Defaults to False. + float_precision (str, optional): Floating-point precision. Defaults to "32". Raises: ValueError: If the model is not 'ensemble' and save_replicates is True. @@ -41,13 +42,15 @@ def predict_gene_expression( device = get_compute_device(device) logger.info(f"Using device: {device} and genome: {genome} for prediction.") + logger.info("Making predictions") model = load_decima_model(model, device=device) - ds = GeneDataset(genes=genes, metadata_anndata=metadata_anndata, max_seq_shift=max_seq_shift) + ds = GeneDataset(genes=genes, metadata_anndata=metadata_anndata, max_seq_shift=max_seq_shift, genome=genome) preds = model.predict_on_dataset( ds, devices=device, batch_size=batch_size, num_workers=num_workers, float_precision=float_precision ) + logger.info("Creating anndata") X = None if ds.result.anndata.X is not None: X = ds.result.anndata.X.copy() @@ -64,31 +67,43 @@ def predict_gene_expression( for model, pred in zip(model.models, preds["ensemble_preds"]): ad.layers[f"preds_{model.name}"] = pred.T - if ad.X is not None: - ad.var["pearson"] = [np.corrcoef(ad.X[:, i], ad.layers["preds"][:, i])[0, 1] for i in range(ad.shape[1])] - ad.var["size_factor_pearson"] = [ - np.corrcoef(ad.X[:, i], ad.obs["size_factor"])[0, 1] for i in range(ad.shape[1]) - ] - print( - f"Mean Pearson Correlation per gene: True: {round(ad.var.pearson.mean(), 2)}. 
" - f"Size Factor: {round(ad.var.size_factor_pearson.mean(), 2)}." - ) - - for dataset in ad.var.dataset.unique(): - key = f"{dataset}_pearson" - ad.obs[key] = [ - np.corrcoef( - ad[i, ad.var.dataset == dataset].X, - ad[i, ad.var.dataset == dataset].layers["preds"], - )[0, 1] - for i in range(ad.shape[0]) - ] - print(f"Mean Pearson Correlation per pseudobulk over {dataset} genes: {round(ad.obs[key].mean(), 2)}") - else: - del ad.var["pearson"] - del ad.var["size_factor_pearson"] - - for dataset in ad.var.dataset.unique(): - del ad.obs[f"{dataset}_pearson"] - + logger.info("Evaluating performance") + evaluate_gene_expression_predictions(ad) return ad + + +def evaluate_gene_expression_predictions(ad): + assert ad.X is not None, "ad.X is required for evaluation." + assert ad.layers["preds"] is not None, "ad.layers['preds'] is required for evaluation." + + n_pbs = ad.shape[0] + n_genes = ad.shape[1] + truth = ad.X + preds = ad.layers["preds"] + + # Compute Pearson correlation per gene + ad.var["pearson"] = [np.corrcoef(truth[:, i], preds[:, i])[0, 1] for i in range(n_genes)] + + if "size_factor" not in ad.obs.columns: + ad.obs["size_factor"] = ad.X.sum(1) + + ad.var["size_factor_pearson"] = [np.corrcoef(truth[:, i], ad.obs["size_factor"])[0, 1] for i in range(n_genes)] + + # compute correlations per pseudobulk + for dataset in ad.var.dataset.unique(): + in_dataset = ad.var.dataset == dataset + + key = f"{dataset}_pearson" + ad.obs[key] = [np.corrcoef(truth[i, in_dataset], preds[i, in_dataset])[0, 1] for i in range(n_pbs)] + + # Compute averages + mean_per_gene = ad.var.loc[in_dataset, "pearson"].mean() + mean_per_gene_sf = ad.var.loc[in_dataset, "size_factor_pearson"].mean() + mean_per_pb = ad.obs[key].mean() + + # Report results + print(f"Performance on genes in the {dataset} dataset.") + print(f"Mean Pearson Correlation per gene: Mean: {mean_per_gene:.2f}.") + print(f"Mean Pearson Correlation per gene using size factor (baseline): {mean_per_gene_sf:.2f}.") + print(f"Mean Pearson Correlation per pseudobulk: {mean_per_pb: .2f}") + print("") diff --git a/src/decima/utils/__init__.py b/src/decima/utils/__init__.py index 78d2f42..a5ce9b7 100644 --- a/src/decima/utils/__init__.py +++ b/src/decima/utils/__init__.py @@ -5,6 +5,16 @@ def _get_on_off_tasks(result: "DecimaResult", tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None): + """Get on and off tasks for attribution analysis. + + Args: + result: DecimaResult object containing cell metadata + tasks: List of task names or query string to filter tasks. If None, all tasks will be used. + off_tasks: List of off task names or query string to filter off tasks. + + Returns: + tuple: (tasks, off_tasks) as lists of task names + """ if tasks is None: tasks = result.cell_metadata.index.tolist() elif isinstance(tasks, str): @@ -22,6 +32,18 @@ def _get_genes( tasks: Optional[List[str]] = None, off_tasks: Optional[List[str]] = None, ): + """Get list of genes for analysis. + + Args: + result: DecimaResult object containing gene metadata + genes: List of gene names. If None, genes will be determined from other parameters. + top_n_markers: Number of top marker genes to select. If None, uses genes parameter. + tasks: List of task names for finding marker genes. + off_tasks: List of off task names for finding marker genes. 
+ + Returns: + List[str]: List of gene names to analyze + """ if (top_n_markers is not None) and (genes is None): all_genes = ( result.marker_zscores(tasks=tasks, off_tasks=off_tasks) @@ -56,4 +78,8 @@ def get_compute_device(device: Optional[str] = None) -> torch.device: """ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" + + elif isinstance(device, str) and device.isdigit(): + device = int(device) + return torch.device(device) diff --git a/src/decima/utils/inject.py b/src/decima/utils/inject.py index 1c6eed3..d09f099 100644 --- a/src/decima/utils/inject.py +++ b/src/decima/utils/inject.py @@ -17,9 +17,10 @@ class SeqBuilder: end: end position anchor: anchor position track: track positions shifts due to indels. + genome: Genome name or path to the genome fasta file. Defaults to "hg38". """ - def __init__(self, chrom: str, start: int, end: int, anchor: int, track: List[int] = None): + def __init__(self, chrom: str, start: int, end: int, anchor: int, track: List[int] = None, genome: str = "hg38"): self.chrom = chrom self.start = start self.end = end @@ -28,6 +29,7 @@ def __init__(self, chrom: str, start: int, end: int, anchor: int, track: List[in self.start_shift = 0 # how much interval is shifted to the left upstream self.end_shift = 0 # how much interval is shifted to the right downstream self.shifts = {pos: 0 for pos in track or list()} + self.genome = genome @staticmethod def _split_variant(variant, pos): @@ -139,7 +141,7 @@ def _construct(self) -> Generator[str, None, None]: start = self.start + self.start_shift end = self.end + self.end_shift - seq = intervals_to_strings({"chrom": self.chrom, "start": start, "end": end}, genome="hg38") + seq = intervals_to_strings({"chrom": self.chrom, "start": start, "end": end}, genome=self.genome) start += 1 # 0 based to 1 based start variants = sorted(self.variants, key=lambda x: x["pos"]) @@ -170,7 +172,7 @@ def concat(self) -> str: return "".join(self._construct()) -def prepare_seq_alt_allele(gene: GeneMetadata, variants: List[Dict]): +def prepare_seq_alt_allele(gene: GeneMetadata, variants: List[Dict], genome: str = "hg38"): """ Prepare the sequence and alt allele for a gene. @@ -200,7 +202,12 @@ def prepare_seq_alt_allele(gene: GeneMetadata, variants: List[Dict]): anchor = gene.gene_end if gene.strand == "-" else gene.gene_start builder = SeqBuilder( - chrom=gene.chrom, start=gene.start, end=gene.end, anchor=anchor, track=[gene.gene_start, gene.gene_end] + chrom=gene.chrom, + start=gene.start, + end=gene.end, + anchor=anchor, + track=[gene.gene_start, gene.gene_end], + genome=genome, ) for variant in variants: builder.inject(variant) diff --git a/src/decima/utils/sequence.py b/src/decima/utils/sequence.py index c809672..5459582 100644 --- a/src/decima/utils/sequence.py +++ b/src/decima/utils/sequence.py @@ -6,15 +6,15 @@ def prepare_mask_gene(gene_start, gene_end, padding=0): - """Mask a gene sequence with a padding. + """Prepare gene mask tensor for gene regions. Args: - gene_start: Start of the gene in the decima context window. - gene_end: End of the gene in the decima context window. - padding: Padding to add to the gene mask + gene_start: Start position of the gene + gene_end: End position of the gene + padding: Amount of padding to add on both sides. 
Defaults to 0 Returns: - torch.Tensor: Masked gene sequence with shape (1, DECIMA_CONTEXT_SIZE + padding * 2) + torch.Tensor: Gene mask tensor with 1s in gene region and 0s elsewhere """ mask = np.zeros(shape=(1, DECIMA_CONTEXT_SIZE + padding * 2)) mask[0, gene_start:gene_end] += 1 diff --git a/src/decima/vep/__init__.py b/src/decima/vep/__init__.py index 0aedb57..39eb33c 100644 --- a/src/decima/vep/__init__.py +++ b/src/decima/vep/__init__.py @@ -41,14 +41,18 @@ def _predict_variant_effect( tasks (str, optional): Tasks to predict. Defaults to None. model (int, optional): Model to use. Defaults to 0. metadata_anndata (str, optional): Path to anndata file. Defaults to None. - batch_size (int, optional): Batch size. Defaults to 8. + batch_size (int, optional): Batch size. Defaults to 1. num_workers (int, optional): Number of workers. Defaults to 16. - device (str, optional): Device to use. Defaults to "cpu". + device (str, optional): Device to use. Defaults to None. include_cols (list, optional): Columns to include in the output. Defaults to None. gene_col (str, optional): Column name for gene names. Defaults to None. distance_type (str, optional): Type of distance. Defaults to "tss". min_distance (float, optional): Minimum distance from the end of the gene. Defaults to 0 (inclusive). max_distance (float, optional): Maximum distance from the TSS. Defaults to inf (exclusive). + genome (str, optional): Genome name or path to the genome fasta file. Defaults to "hg38". + save_replicates (bool, optional): Save the replicates in the output. Defaults to False. + reference_cache (bool, optional): Whether to use reference cache. Defaults to True. + float_precision (str, optional): Floating-point precision. Defaults to "32". Returns: pd.DataFrame: DataFrame with variant effect predictions @@ -70,6 +74,7 @@ def _predict_variant_effect( max_distance=max_distance, model_name=model.name, reference_cache=reference_cache, + genome=genome, ) except ValueError as e: if str(e).startswith("NoOverlapError"): @@ -115,7 +120,7 @@ def predict_variant_effect( df_variant: Union[pd.DataFrame, str], output_pq: Optional[str] = None, tasks: Optional[Union[str, List[str]]] = None, - model: Union[int, str] = "ensemble", + model: Union[int, str, List[str]] = "ensemble", metadata_anndata: Optional[str] = None, chunksize: int = 10_000, batch_size: int = 1, @@ -134,21 +139,24 @@ def predict_variant_effect( """Predict variant effect and save to parquet Args: - df_variant (pd.DataFrame): DataFrame with variant information - output_path (str): Path to save the parquet file + df_variant (pd.DataFrame or str): DataFrame with variant information or path to variant file + output_pq (str, optional): Path to save the parquet file. Defaults to None. tasks (str, optional): Tasks to predict. Defaults to None. - model (int, optional): Model to use. Defaults to 0. + model (int, optional): Model to use. Defaults to "ensemble". metadata_anndata (str, optional): Path to anndata file. Defaults to None. chunksize (int, optional): Number of variants to predict in each chunk. Defaults to 10_000. - batch_size (int, optional): Batch size. Defaults to 8. + batch_size (int, optional): Batch size. Defaults to 1. num_workers (int, optional): Number of workers. Defaults to 16. - device (str, optional): Device to use. Defaults to "cpu". + device (str, optional): Device to use. Defaults to None. include_cols (list, optional): Columns to include in the output. Defaults to None. gene_col (str, optional): Column name for gene names. Defaults to None. 
distance_type (str, optional): Type of distance. Defaults to "tss". min_distance (float, optional): Minimum distance from the end of the gene. Defaults to 0 (inclusive). max_distance (float, optional): Maximum distance from the TSS. Defaults to inf (exclusive). - genome (str, optional): Genome build. Defaults to "hg38". + genome (str, optional): Genome name or path to the genome fasta file. Defaults to "hg38". + save_replicates (bool, optional): Save the replicates in the output. Defaults to False. + reference_cache (bool, optional): Whether to use reference cache. Defaults to True. + float_precision (str, optional): Floating-point precision. Defaults to "32". """ logger = logging.getLogger("decima") device = get_compute_device(device) diff --git a/tests/conftest.py b/tests/conftest.py index 6743b03..27ec0c8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from decima.constants import DECIMA_CONTEXT_SIZE from decima.hub import login_wandb -from decima.hub.download import download_hg38 +from decima.hub.download import cache_hg38 fasta_file = "tests/data/seqs.fasta" @@ -39,7 +39,7 @@ def pytest_collection_modifyitems(config, items): login_wandb() -download_hg38() +cache_hg38() device = "cpu" diff --git a/tests/scripts/generate_seq_fasta.py b/tests/scripts/generate_seq_fasta.py index 3361d94..a36a5fa 100644 --- a/tests/scripts/generate_seq_fasta.py +++ b/tests/scripts/generate_seq_fasta.py @@ -7,7 +7,6 @@ df = list() -# for i in ['CD68', 'SPI1', 'CD14']: for i in ['CD68', 'SPI1']: seq, _ = result.prepare_one_hot(i) seq = one_hot_to_strings(seq) diff --git a/tests/test_cli.py b/tests/test_cli.py index 5be29e0..60ba05c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -16,9 +16,30 @@ def test_cli_main(): @pytest.mark.long_running -def test_cli_download(): +def test_cli_cache(): runner = CliRunner() - result = runner.invoke(main, ["download"]) + result = runner.invoke(main, ["cache"]) + assert result.exit_code == 0 + + +@pytest.mark.long_running +def test_cli_download(tmp_path): + runner = CliRunner() + result = runner.invoke(main, ["download", "--download-dir", str(tmp_path)]) + assert result.exit_code == 0 + + +@pytest.mark.long_running +def test_cli_download_weights(tmp_path): + runner = CliRunner() + result = runner.invoke(main, ["download-weights", "--download-dir", str(tmp_path)]) + assert result.exit_code == 0 + + +@pytest.mark.long_running +def test_cli_download_metadata(tmp_path): + runner = CliRunner() + result = runner.invoke(main, ["download-metadata", "--download-dir", str(tmp_path)]) assert result.exit_code == 0 diff --git a/tests/test_interpret_attribution.py b/tests/test_interpret_attribution.py index 2f7f59f..96014e3 100644 --- a/tests/test_interpret_attribution.py +++ b/tests/test_interpret_attribution.py @@ -7,10 +7,12 @@ import numpy as np import pandas as pd from grelu.sequence.format import strings_to_one_hot + +from decima.constants import DECIMA_CONTEXT_SIZE +from decima.hub.download import download_decima_weights, download_decima_metadata from captum.attr import Saliency, InputXGradient, IntegratedGradients from decima.core.attribution import Attribution from decima import predict_attributions_seqlet_calling -from decima.constants import DECIMA_CONTEXT_SIZE from decima.interpret.attributer import DecimaAttributer, get_attribution_method from decima.interpret.attributions import predict_save_attributions, recursive_seqlet_calling, plot_attributions @@ -209,6 +211,33 @@ def test_predict_save_attributions_single_gene(tmp_path): assert 
((plot_dir / "SPI1_seqlogos").is_dir()) +@pytest.mark.long_running +def test_predict_save_attributions_single_gene_list_models(tmp_path): + # download models + download_decima_weights(0, str(tmp_path)) + download_decima_weights(1, str(tmp_path)) + download_decima_metadata(str(tmp_path)) + + output_prefix = tmp_path / "SPI1" + predict_attributions_seqlet_calling( + output_prefix=output_prefix, + genes=["SPI1"], + metadata_anndata=str(tmp_path / "metadata.h5ad"), + tasks="cell_type == 'classical monocyte'", + model=[ + str(tmp_path / "rep0.safetensors"), + str(tmp_path / "rep1.safetensors"), + ], + device=device + ) + assert (output_prefix.with_suffix(".seqlets.bed")).exists() + assert Path(str(output_prefix) + "_0.attributions.h5").exists() + assert Path(str(output_prefix) + "_1.attributions.h5").exists() + assert (output_prefix.with_suffix(".motifs.tsv")).exists() + assert Path(str(output_prefix) + "_0.warnings.qc.log").exists() + assert Path(str(output_prefix) + "_1.warnings.qc.log").exists() + + @pytest.mark.long_running def test_predict_save_attributions_single_gene_saliency(tmp_path): output_prefix = tmp_path / "SPI1" diff --git a/tests/test_vep.py b/tests/test_vep.py index 1e4b17f..650ffd0 100644 --- a/tests/test_vep.py +++ b/tests/test_vep.py @@ -6,6 +6,7 @@ from scipy.stats import pearsonr from decima.core.result import DecimaResult +from decima.hub import load_decima_model from decima.data.dataset import VariantDataset from decima.model.metrics import WarningType from decima.vep import _predict_variant_effect, predict_variant_effect @@ -328,7 +329,7 @@ def test_predict_variant_effect_vcf_ensemble_replicates(tmp_path): assert output_file.exists() df_saved = pd.read_parquet(output_file) - assert df_saved.shape == (12, 44294) + assert df_saved.shape == (12, 14 + 8856 * 5) cells = list(df_saved.columns[14:8870]) average_preds = np.mean([ @@ -336,3 +337,33 @@ def test_predict_variant_effect_vcf_ensemble_replicates(tmp_path): for i in range(4) ], axis=0) np.testing.assert_allclose(df_saved[cells].values, average_preds, rtol=1e-5) + + +@pytest.mark.long_running +def test_predict_variant_effect_vcf_ensemble_replicates_model_list(tmp_path): + output_file = tmp_path / "test_predictions.parquet" + + models = [ + load_decima_model(0, device), + load_decima_model(1, device) + ] + + predict_variant_effect( + "tests/data/test.vcf", + output_pq=str(output_file), + model=models, + device=device, + max_distance=20000, + save_replicates=True, + ) + assert output_file.exists() + + df_saved = pd.read_parquet(output_file) + assert df_saved.shape == (12, 14 + 8856 * 3) + + cells = list(df_saved.columns[14:8870]) + average_preds = np.mean([ + df_saved[[f"{cell}_v1_rep{i}" for cell in cells]].values + for i in range(2) + ], axis=0) + np.testing.assert_allclose(df_saved[cells].values, average_preds, rtol=1e-5)
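Editorial note (not part of the patch): the list-of-checkpoints loading added in this PR can be exercised end to end in the same way as the new test test_predict_save_attributions_single_gene_list_models. A short usage sketch follows; the weights/ and out/ directories are placeholders.

from decima.hub.download import download_decima_weights, download_decima_metadata
from decima import predict_attributions_seqlet_calling

# Download two replicates and the metadata anndata to a local directory
rep0 = download_decima_weights(0, "weights")     # -> weights/rep0.safetensors
rep1 = download_decima_weights(1, "weights")     # -> weights/rep1.safetensors
metadata = download_decima_metadata("weights")   # -> weights/metadata.h5ad

# Run attribution + seqlet calling with the two local checkpoints as a two-member ensemble
predict_attributions_seqlet_calling(
    output_prefix="out/SPI1",
    genes=["SPI1"],
    metadata_anndata=str(metadata),
    tasks="cell_type == 'classical monocyte'",
    model=[str(rep0), str(rep1)],
)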
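Editorial note (not part of the patch): the CLI hunks above import parse_model, parse_genes and validate_save_replicates from decima.cli.callback, but that new module is not included in this diff. A minimal sketch of what those Click callbacks might look like, inferred from the option help strings and from the inline parsing logic the patch removes; the names match the imports, the bodies are assumptions rather than the actual implementation.

import click


def parse_genes(ctx, param, value):
    # Split a comma-separated --genes string into a list; pass None through unchanged
    if value is None:
        return None
    return value.split(",")


def parse_model(ctx, param, value):
    # Map '0'-'3' to replicate integers; split comma-separated checkpoint paths into a list
    if value in {"0", "1", "2", "3"}:
        return int(value)
    if isinstance(value, str) and "," in value:
        return value.split(",")
    return value


def validate_save_replicates(ctx, param, value):
    # Simplified check: the real implementation may defer this until --model has been parsed
    model = ctx.params.get("model", "ensemble")
    if value and not (model == "ensemble" or isinstance(model, list)):
        raise click.BadParameter("--save-replicates is only supported for ensemble models.")
    return value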