Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,9 @@

## NEW FUNCTIONALITY

* Added `control_methods/true_labels` component (PR #5).

* Added `methods/logistic_regression` component (PR #5).

* Added `metrics/accuracy` component (PR #5).

## MAJOR CHANGES

* Updated `api` files (PR #5).

* Updated configs, components and CI to the latest Viash version (PR #8).

## MINOR CHANGES

* Updated `README.md` (PR #5).

## BUGFIXES

60 changes: 60 additions & 0 deletions src/methods/cellplm/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Viash component config for the CellPLM method.
# Inherits the common method interface (arguments, input/output slots) from
# the shared base_method API definition.
__merge__: ../../api/base_method.yaml

# A unique identifier for your component (required).
# Can contain only lowercase letters or underscores.
name: cellplm
# A relatively short label, used when rendering visualisations (required)
label: CellPLM
# A one sentence summary of how this method works (required). Used when
# rendering summary tables.
summary: "A foundation model pre-trained with cells as tokens."
# A multi-line description of how this component works (required). Used
# when rendering reference documentation.
description: |
  CellPLM is a pre-trained language model specifically designed for single-cell analysis that leverages the principles of natural language processing (NLP) to understand and process single-cell gene expression data.
# Bibliographic reference(s) for the method (preprint DOI).
references:
  doi:
    - 10.1101/2023.10.03.560734
links:
  # URL to the documentation for this method (required).
  documentation: https://github.com/OmicsML/CellPLM/tree/main/tutorials
  # URL to the code repository for this method (required).
  repository: https://github.com/OmicsML/CellPLM


# Benchmark metadata: CellPLM produces a cell embedding and expects raw
# counts as input normalization.
info:
  method_types: [embedding]
  preferred_normalization: counts

# Component-specific arguments (in addition to those inherited via __merge__).
arguments:
  - name: --model
    type: string
    description: String giving the CellPLM model to use
    # Only one pretrained checkpoint is currently supported.
    choices: ["20231027_85M"]
    default: "20231027_85M"
  # - name: --n_hvg
  #   type: integer
  #   default: 3000
  #   description: Number of highly variable genes to use.

# Files bundled with the component at runtime.
resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
    setup:
      - type: python
        pypi:
          - gdown
          # NOTE(review): scgpt appears unrelated to CellPLM — presumably it is
          # listed only to pull in shared transitive dependencies; confirm that
          # installing cellplm alone is insufficient.
          - scgpt # Install from PyPI to get dependencies
          - cellplm

# How the component can be executed: directly, or as a Nextflow module with
# mid-tier resource labels plus a GPU.
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
96 changes: 96 additions & 0 deletions src/methods/cellplm/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Compute CellPLM cell embeddings for an Open Problems benchmark dataset.

Reads a normalized AnnData input, downloads the pretrained CellPLM
checkpoint from Google Drive, runs the CellEmbeddingPipeline and writes
the embedding to ``obsm["X_emb"]`` of the output ``.h5ad`` file.
"""
import sys
import tempfile
import anndata as ad
import gdown
import torch

# Silence warnings globally before importing CellPLM, which is noisy at
# import time.
import warnings
warnings.filterwarnings("ignore")
from CellPLM.utils import set_seed
from CellPLM.pipeline.cell_embedding import CellEmbeddingPipeline

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.

par = {
    'input': 'resources_test/.../input.h5ad',
    'output': 'output.h5ad',
    "model": "20231027_85M",
}
meta = {
    'name': 'cellplm'
}
## VIASH END

# Make the bundled read_anndata_partial.py helper importable.
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

set_seed(24)  # fixed seed for reproducibility
PRETRAIN_VERSION = par['model']
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("\n>>> Reading input files...", flush=True)
print(f"Input H5AD file: '{par['input']}'", flush=True)
# Load only the slots the pipeline needs; X is set to the normalized layer.
adata = read_anndata(
    par['input'],
    X='layers/normalized',
    obs='obs',
    var='var',
    uns='uns'
)

# CellPLM checkpoints are trained on human data only.
if adata.uns["dataset_organism"] != "homo_sapiens":
    raise ValueError(
        f"CellPLM can only be used with human data "
        f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
    )

print(adata, flush=True)

print('Preprocess data', flush=True)
# ... preprocessing ...

print('Train model', flush=True)
# ... train model ...

# Public Google Drive folder holding the pretrained CellPLM checkpoints.
drive_path = "https://drive.google.com/drive/folders/1C2fVNEKX3plHnagaTwpuPW5tpwv1up9G?usp=sharing"
model_dir = tempfile.TemporaryDirectory()
try:
    print(f"Downloading from '{drive_path}'", flush=True)
    gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
    print(f"Model directory: '{model_dir.name}'", flush=True)

    pipeline = CellEmbeddingPipeline(
        pretrain_prefix=PRETRAIN_VERSION,  # pretrain checkpoint to load
        pretrain_directory=model_dir.name,
    )

    # Run inference on GPU when available, otherwise CPU.
    embedding = pipeline.predict(adata, device=DEVICE)
    embedding = embedding.cpu().numpy()
finally:
    # Always remove the downloaded checkpoint, even if inference failed.
    print("\n>>> Cleaning up temporary directories...", flush=True)
    model_dir.cleanup()

print('Generate predictions', flush=True)
# ... generate predictions ...

# Assemble the benchmark output: empty obs/var frames (index only) plus the
# embedding and the required uns metadata.
output = ad.AnnData(
    obs=adata.obs[[]],
    var=adata.var[[]],
    obsm={
        "X_emb": embedding,
    },
    uns={
        "dataset_id": adata.uns["dataset_id"],
        "normalization_id": adata.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
print(output)

output.write_h5ad(par['output'], compression='gzip')
Loading