Fine-tuning Borzoi to create a Decima model¶
import glob
import anndata
import scanpy as sc
import pandas as pd
import bioframe as bf
import os

outdir = "."
ad_file_path = os.path.join(outdir, "data.h5ad")
h5_file_path = os.path.join(outdir, "data.h5")
1. Load the input anndata file¶
The input anndata file must have the shape (pseudobulks × genes).
ad = sc.read(ad_file_path)
ad
AnnData object with n_obs × n_vars = 50 × 988
    obs: 'cell_type', 'tissue', 'disease', 'study'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset'
    uns: 'log1p'
.obs should be a dataframe with a unique index per pseudobulk. You can also include other columns with metadata about the pseudobulks, e.g. cell type, tissue, disease, study, number of cells, total counts.
Note that the original Decima model does NOT separate pseudobulks by sample, i.e. different samples from the same cell type, tissue, disease and study were merged. We also recommend filtering out pseudobulks with few cells or low read counts, as sketched below.
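For instance, if your .obs carried per-pseudobulk quality columns, a filter could look like the following sketch. The column names n_cells and total_counts and the thresholds are illustrative assumptions and are not present in this demo dataset.

```python
# Hypothetical filter: 'n_cells', 'total_counts' and the thresholds are
# illustrative assumptions, not columns of this demo dataset.
keep = (ad.obs["n_cells"] >= 50) & (ad.obs["total_counts"] >= 1e5)
ad = ad[keep].copy()
```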
ad.obs.head()
| | cell_type | tissue | disease | study |
|---|---|---|---|---|
| pseudobulk_0 | ct_0 | t_0 | d_0 | st_0 |
| pseudobulk_1 | ct_0 | t_0 | d_1 | st_0 |
| pseudobulk_2 | ct_0 | t_0 | d_2 | st_1 |
| pseudobulk_3 | ct_0 | t_0 | d_0 | st_1 |
| pseudobulk_4 | ct_0 | t_0 | d_1 | st_2 |
.var should be a dataframe with a unique index per gene. The index can be the gene name or Ensembl ID, as long as it is unique. The other essential columns are chrom, start, end and strand (the gene coordinates).
You can also include further columns with metadata about the genes, e.g. Ensembl ID or gene type. A quick way to validate these requirements is sketched below.
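A minimal sanity check (a sketch, not part of the Decima API) that the required columns are present and well-formed:

```python
# Check that .var satisfies the requirements described above.
required = ["chrom", "start", "end", "strand"]
missing = [c for c in required if c not in ad.var.columns]
assert not missing, f"missing .var columns: {missing}"
assert ad.var.index.is_unique, ".var index (gene names/IDs) must be unique"
assert ad.var.strand.isin(["+", "-"]).all(), "strand must be '+' or '-'"
```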
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28320920 | 28845208 | + | 28484760 | 29009048 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 39145337 | 39669625 | - | 38981497 | 39505785 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 77807946 | 78332234 | - | 77644106 | 78168394 | 524288 | 163840 | 524288 | train |
| gene_3 | chr8 | 143094013 | 143618301 | - | 142930173 | 143454461 | 524288 | 163840 | 524288 | val |
| gene_4 | chr16 | 1775288 | 2299576 | - | 1611448 | 2135736 | 524288 | 163840 | 524288 | train |
.X should contain the total counts per pseudobulk and gene. These should be non-negative integers. (The demo matrix shown below already contains transformed values, which is why scanpy warns in the next step.)
ad.X[:5, :5]
array([[0.       , 7.2155137, 7.3277392, 0.       , 7.2698054],
       [7.1914983, 7.3387527, 0.       , 7.2105823, 7.180787 ],
       [7.045969 , 7.2056117, 7.15802  , 7.289302 , 7.282388 ],
       [7.2008514, 0.       , 7.2667375, 7.321583 , 7.2398143],
       [7.2582483, 6.723016 , 0.       , 0.       , 7.3626666]],
      dtype=float32)
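Before normalizing, it can be worth verifying that .X really holds raw counts. The sketch below (not part of the Decima API) flags matrices that look already transformed; on this demo data it would print the warning.

```python
import numpy as np
import scipy.sparse as sp

# Densify if sparse, then check for non-negative, integer-valued counts.
X = ad.X.toarray() if sp.issparse(ad.X) else ad.X
assert (X >= 0).all(), ".X must be non-negative"
if not np.allclose(X, np.round(X)):
    print("Warning: .X does not look like raw integer counts")
```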
2. Normalize and log-transform data¶
We first transform the counts to log(CPM + 1) values, where CPM stands for Counts Per Million.
sc.pp.normalize_total(ad, target_sum=1e6)
sc.pp.log1p(ad)
WARNING: adata.X seems to be already log-transformed.
ad.X[:5, :5]
array([[0.       , 7.2337112, 7.2491336, 0.       , 7.241202 ],
       [7.2420583, 7.262313 , 0.       , 7.244706 , 7.2405686],
       [7.207595 , 7.229983 , 7.223361 , 7.2415223, 7.240574 ],
       [7.2279363, 0.       , 7.237038 , 7.2445517, 7.233329 ],
       [7.2675843, 7.191038 , 0.       , 0.       , 7.281858 ]],
      dtype=float32)
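As a minimal sketch of what these two scanpy calls compute (assuming the input were raw counts): each row is scaled to a total of one million and then natural-log-transformed.

```python
import numpy as np

counts = np.array([[10., 0., 5.], [3., 7., 0.]])
cpm = counts / counts.sum(axis=1, keepdims=True) * 1e6  # counts per million
log_cpm = np.log1p(cpm)  # log(CPM + 1), matching normalize_total + log1p
print(log_cpm.round(2))
```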
3. Create intervals surrounding genes¶
Decima is trained on 524,288 bp sequences surrounding the genes. Therefore, we have to take the given gene coordinates and extend them to create intervals of this length.
from decima.data.preprocess import var_to_intervals
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28320920 | 28845208 | + | 28484760 | 29009048 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 39145337 | 39669625 | - | 38981497 | 39505785 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 77807946 | 78332234 | - | 77644106 | 78168394 | 524288 | 163840 | 524288 | train |
| gene_3 | chr8 | 143094013 | 143618301 | - | 142930173 | 143454461 | 524288 | 163840 | 524288 | val |
| gene_4 | chr16 | 1775288 | 2299576 | - | 1611448 | 2135736 | 524288 | 163840 | 524288 | train |
First, we copy the start and end columns to gene_start and gene_end. We also create a new column gene_length.
ad.var["gene_start"] = ad.var.start.tolist()
+ad.var["gene_end"] = ad.var.end.tolist()
+ad.var["gene_length"] = ad.var["gene_end"] - ad.var["gene_start"]
+ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28320920 | 28845208 | + | 28320920 | 28845208 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 39145337 | 39669625 | - | 39145337 | 39669625 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 77807946 | 78332234 | - | 77807946 | 78332234 | 524288 | 163840 | 524288 | train |
| gene_3 | chr8 | 143094013 | 143618301 | - | 143094013 | 143618301 | 524288 | 163840 | 524288 | val |
| gene_4 | chr16 | 1775288 | 2299576 | - | 1775288 | 2299576 | 524288 | 163840 | 524288 | train |
Now, we extend the gene coordinates to create enclosing intervals:
ad = var_to_intervals(ad, chr_end_pad=10000, genome="hg38")
# Replace the genome name if necessary
The interval size is 524288 bases. Of these, 163840 will be upstream of the gene start and 360448 will be downstream of the gene start.
0 intervals extended beyond the chromosome start and have been shifted
7 intervals extended beyond the chromosome end and have been shifted
7 intervals did not extend far enough upstream of the TSS and have been dropped
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28157080 | 28681368 | + | 28320920 | 28845208 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 39309177 | 39833465 | - | 39145337 | 39669625 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 77971786 | 78496074 | - | 77807946 | 78332234 | 524288 | 163840 | 524288 | train |
| gene_3 | chr8 | 143257853 | 143782141 | - | 143094013 | 143618301 | 524288 | 163840 | 524288 | val |
| gene_4 | chr16 | 1939128 | 2463416 | - | 1775288 | 2299576 | 524288 | 163840 | 524288 | train |
The start and end columns now contain the start and end coordinates of the 524,288 bp intervals. A quick way to verify the interval geometry is sketched below.
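A short sanity check (a sketch): every interval should be 524,288 bp long, and unshifted intervals should place the gene start 163,840 bp from the interval's 5' end on the gene's strand (intervals shifted at chromosome ends may deviate).

```python
import numpy as np

# All intervals have the expected length.
assert ((ad.var.end - ad.var.start) == 524288).all()

# Distance from the interval's 5' end (in gene orientation) to the gene start.
offset = np.where(ad.var.strand == "+",
                  ad.var.gene_start - ad.var.start,
                  ad.var.end - ad.var.gene_end)
print("Fraction with the standard 163,840 bp offset:", (offset == 163840).mean())
```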
4. Split genes into training, validation and test sets¶
We load the coordinates of the genomic regions used to train Borzoi:
splits_file = "https://raw.githubusercontent.com/calico/borzoi/main/data/sequences_human.bed.gz"
# Replace "human" with "mouse" for mm10 splits
splits = pd.read_table(splits_file, header=None, names=["chrom", "start", "end", "fold"])
splits.head()
| | chrom | start | end | fold |
|---|---|---|---|---|
| 0 | chr4 | 82524421 | 82721029 | fold0 |
| 1 | chr13 | 18604798 | 18801406 | fold0 |
| 2 | chr2 | 189923408 | 190120016 | fold0 |
| 3 | chr10 | 59875743 | 60072351 | fold0 |
| 4 | chr1 | 117109467 | 117306075 | fold0 |
Now, we overlap our gene intervals with these regions:
overlaps = bf.overlap(ad.var.reset_index(names="gene"), splits, how="left")
# bioframe suffixes columns from the second dataframe with "_", hence "fold_"
overlaps = overlaps[["gene", "fold_"]].drop_duplicates().astype(str)
overlaps.head()
| | gene | fold_ |
|---|---|---|
| 0 | gene_0 | fold5 |
| 15 | gene_1 | fold0 |
| 30 | gene_2 | fold0 |
| 44 | gene_3 | fold4 |
| 58 | gene_4 | fold0 |
Based on the overlap, we divide our gene intervals into training, validation and test sets.
test_genes = overlaps.gene[overlaps.fold_ == "fold3"].tolist()
val_genes = overlaps.gene[overlaps.fold_ == "fold4"].tolist()
train_genes = set(overlaps.gene).difference(set(test_genes).union(val_genes))
We then add this information back to ad.var.
ad.var["dataset"] = "test"
ad.var.loc[ad.var.index.isin(val_genes), "dataset"] = "val"
ad.var.loc[ad.var.index.isin(train_genes), "dataset"] = "train"
ad.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset |
|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28157080 | 28681368 | + | 28320920 | 28845208 | 524288 | 163840 | 524288 | train |
| gene_1 | chr19 | 39309177 | 39833465 | - | 39145337 | 39669625 | 524288 | 163840 | 524288 | train |
| gene_2 | chr1 | 77971786 | 78496074 | - | 77807946 | 78332234 | 524288 | 163840 | 524288 | train |
| gene_3 | chr8 | 143257853 | 143782141 | - | 143094013 | 143618301 | 524288 | 163840 | 524288 | val |
| gene_4 | chr16 | 1939128 | 2463416 | - | 1775288 | 2299576 | 524288 | 163840 | 524288 | train |
ad.var.dataset.value_counts()
dataset
train    803
test      99
val       79
Name: count, dtype: int64
We have now divided the 981 genes remaining in our dataset into separate sets to be used for training, validation and testing. The sketch below verifies that the splits are disjoint.
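A quick consistency check (a sketch, not part of the Decima API):

```python
# Confirm the three gene sets are disjoint and together cover every gene.
sets = {k: set(ad.var.index[ad.var.dataset == k]) for k in ["train", "val", "test"]}
assert sets["train"].isdisjoint(sets["val"])
assert sets["train"].isdisjoint(sets["test"])
assert sets["val"].isdisjoint(sets["test"])
assert sum(len(s) for s in sets.values()) == ad.n_vars
```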
5. Save processed anndata¶
We save the processed anndata file containing these intervals and data splits.
ad.write_h5ad(ad_file_path)
6. Create an HDF5 file¶
To train Decima, we need to extract the genomic sequences for all the intervals and convert them to one-hot encoded format. We save these one-hot encoded inputs to an HDF5 file. The sketch below illustrates what one-hot encoding means.
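For illustration only (this is not Decima's internal implementation, and the A/C/G/T row order is an assumption about convention), one-hot encoding turns a DNA string into a 4 × L binary array:

```python
import numpy as np

def one_hot(seq: str) -> np.ndarray:
    """One-hot encode a DNA sequence; unknown bases (e.g. N) become all-zero columns."""
    lookup = {"A": 0, "C": 1, "G": 2, "T": 3}
    out = np.zeros((4, len(seq)), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        if base in lookup:
            out[lookup[base], i] = 1.0
    return out

print(one_hot("ACGTN"))
```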
from decima.data.write_hdf5 import write_hdf5
write_hdf5(file=h5_file_path, ad=ad, pad=5000, genome="hg38")
# Change the genome name if necessary
Writing metadata
Writing task indices
Writing genes array of shape: (981, 2)
Writing labels array of shape: (981, 50, 1)
Making gene masks
Writing mask array of shape: (981, 534288)
Encoding sequences
Writing sequence array of shape: (981, 534288)
Done!
7. Set training parameters¶
# Learning rate (default: 0.001)
lr = 5e-5

# Total-weight parameter for the loss function
total_weight = 1e-4

# Gradient accumulation steps
grad = 5

# Batch size (default: 4)
bs = 4

# Maximum sequence shift (default: 5000)
shift = 5000

# Number of epochs (default: 1)
epochs = 15

# Logger; change to "csv" to save logs locally
logger = "wandb"

# Number of dataloader workers (default: 16)
workers = 16
8. Generate training commands¶
cmds = []

for model in range(4):
    name = f"finetune_test_{model}"
    device = model

    cmd = (
        f"decima finetune --name {name} "
        + f"--model {model} --device {device} "
        + f"--matrix-file {ad_file_path} --h5-file {h5_file_path} "
        + f"--outdir {outdir} --learning-rate {lr} "
        + f"--loss-total-weight {total_weight} --gradient-accumulation {grad} "
        + f"--batch-size {bs} --max-seq-shift {shift} "
        + f"--epochs {epochs} --logger {logger} --num-workers {workers}"
    )
    cmds.append(cmd)

for cmd in cmds:
    print(cmd)
decima finetune --name finetune_test_0 --model 0 --device 0 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_1 --model 1 --device 1 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_2 --model 2 --device 2 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
decima finetune --name finetune_test_3 --model 3 --device 3 --matrix-file ./data.h5ad --h5-file ./data.h5 --outdir . --learning-rate 5e-05 --loss-total-weight 0.0001 --gradient-accumulation 5 --batch-size 4 --max-seq-shift 5000 --epochs 15 --logger wandb --num-workers 16
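One convenient way to launch these jobs (a sketch, not required by Decima) is to write them to a shell script, pinning one GPU per replicate:

```python
# Write each command to a script, assigning replicate i to GPU i.
with open("finetune_jobs.sh", "w") as f:
    for i, cmd in enumerate(cmds):
        f.write(f"CUDA_VISIBLE_DEVICES={i} {cmd} &\n")
    f.write("wait\n")
```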
# ! CUDA_VISIBLE_DEVICES=0 decima finetune \
#     --name finetune_test_0 \
#     --model 0 \
#     --device 0 \
#     --matrix-file {ad_file_path} \
#     --h5-file {h5_file_path} \
#     --outdir {outdir} \
#     --learning-rate {lr} \
#     --loss-total-weight {total_weight} \
#     --gradient-accumulation {grad} \
#     --batch-size {bs} \
#     --max-seq-shift {shift} \
#     --epochs {epochs} \
#     --logger {logger} \
#     --num-workers {workers}
# Uncomment if necessary
# import wandb
# wandb.login(host="https://genentech.wandb.io", anonymous="never", relogin=True)
9. Make and evaluate predictions using trained models¶
Using the training commands above, we trained model replicates. For demonstration, we pass the same checkpoint twice below to mimic an ensemble of two replicates. We can now use these models to predict gene expression:
checkpoint = glob.glob('lightning_logs/*/checkpoints/epoch=0-step=42.ckpt')[0]
print(checkpoint)

lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt
# Comma-separated list of model checkpoints
checkpoint_list = ",".join([checkpoint, checkpoint])
checkpoint_list

'lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt,lightning_logs/ie4tgmpg/checkpoints/epoch=0-step=42.ckpt'
! CUDA_VISIBLE_DEVICES=0 decima predict-genes \
    --output test_preds.h5ad \
    --model {checkpoint_list} \
    --metadata data.h5ad \
    --device 0 \
    --batch-size 8 \
    --num-workers 32 \
    --max_seq_shift 0 \
    --genome hg38 \
    --save-replicates
decima - INFO - Using device: cuda:0 and genome: hg38 for prediction.
decima - INFO - Making predictions
decima - INFO - Initializing weights from Borzoi model using wandb for replicate: 0
wandb: Downloading large artifact 'human_state_dict_fold0:latest', 709.30MB. 1 files...
wandb: 1 of 1 files downloaded.
GPU available: True (cuda), used: True
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|████████████████| 123/123 [03:29<00:00, 0.59it/s]
decima - INFO - Creating anndata
decima - INFO - Evaluating performance
Performance on genes in the train dataset.
Mean Pearson Correlation per gene: 0.00.
Mean Pearson Correlation per gene using size factor (baseline): 0.02.
Mean Pearson Correlation per pseudobulk: 0.00

Performance on genes in the val dataset.
Mean Pearson Correlation per gene: 0.04.
Mean Pearson Correlation per gene using size factor (baseline): 0.05.
Mean Pearson Correlation per pseudobulk: 0.02

Performance on genes in the test dataset.
Mean Pearson Correlation per gene: -0.02.
Mean Pearson Correlation per gene using size factor (baseline): 0.01.
Mean Pearson Correlation per pseudobulk: 0.01
We can open the output h5ad file to see the individual predictions and metrics.
ad_out = anndata.read_h5ad("test_preds.h5ad")
ad_out

AnnData object with n_obs × n_vars = 50 × 981
    obs: 'cell_type', 'tissue', 'disease', 'study', 'size_factor', 'train_pearson', 'val_pearson', 'test_pearson'
    var: 'chrom', 'start', 'end', 'strand', 'gene_start', 'gene_end', 'gene_length', 'gene_mask_start', 'gene_mask_end', 'dataset', 'pearson', 'size_factor_pearson'
    layers: 'preds', 'preds_finetune_test_0'
.layers['preds'] contains the average predictions across replicates, while the per-replicate layers (here, .layers['preds_finetune_test_0']) contain the predictions made by the individual models. You will also see that performance metrics have been added to both .obs and .var, as in the sketch below.
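For example, the per-gene metric in .var['pearson'] can be reproduced by correlating the ensemble predictions with the measured values (a sketch, assuming dense matrices and that .X holds the measured log-CPM values):

```python
import numpy as np

# Pick the first test-set gene and compare predictions against measurements.
gene = ad_out.var.index[ad_out.var.dataset == "test"][0]
obs = np.asarray(ad_out[:, gene].X).ravel()
pred = np.asarray(ad_out[:, gene].layers["preds"]).ravel()
print(np.corrcoef(obs, pred)[0, 1], ad_out.var.loc[gene, "pearson"])
```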
ad_out.obs.head()
| | cell_type | tissue | disease | study | size_factor | train_pearson | val_pearson | test_pearson |
|---|---|---|---|---|---|---|---|---|
| pseudobulk_0 | ct_0 | t_0 | d_0 | st_0 | 5193.049805 | 0.070174 | 0.214402 | 0.088188 |
| pseudobulk_1 | ct_0 | t_0 | d_1 | st_0 | 5137.830566 | -0.004344 | 0.058580 | -0.015836 |
| pseudobulk_2 | ct_0 | t_0 | d_2 | st_1 | 5198.248535 | 0.022892 | 0.212270 | -0.026279 |
| pseudobulk_3 | ct_0 | t_0 | d_0 | st_1 | 5204.543457 | 0.067001 | -0.053795 | 0.041648 |
| pseudobulk_4 | ct_0 | t_0 | d_1 | st_2 | 5056.311523 | 0.009684 | 0.001823 | 0.020882 |
ad_out.var.head()
| | chrom | start | end | strand | gene_start | gene_end | gene_length | gene_mask_start | gene_mask_end | dataset | pearson | size_factor_pearson |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| gene_0 | chr1 | 28157080 | 28681368 | + | 28320920 | 28845208 | 524288 | 163840 | 524288 | train | 0.042477 | -0.036051 |
| gene_1 | chr19 | 39309177 | 39833465 | - | 39145337 | 39669625 | 524288 | 163840 | 524288 | train | 0.041681 | -0.075098 |
| gene_2 | chr1 | 77971786 | 78496074 | - | 77807946 | 78332234 | 524288 | 163840 | 524288 | train | -0.070010 | 0.220900 |
| gene_3 | chr8 | 143257853 | 143782141 | - | 143094013 | 143618301 | 524288 | 163840 | 524288 | val | -0.104826 | 0.128605 |
| gene_4 | chr16 | 1939128 | 2463416 | - | 1775288 | 2299576 | 524288 | 163840 | 524288 | train | -0.082712 | -0.001255 |