diff --git a/scripts/run_benchmark/run_full_local.sh b/scripts/run_benchmark/run_full_local.sh index 20e434b3..b60940c9 100755 --- a/scripts/run_benchmark/run_full_local.sh +++ b/scripts/run_benchmark/run_full_local.sh @@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "transcriptformer_mlflow"]}' HERE # run the benchmark diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh index 85e39583..4b7bf15e 100755 --- a/scripts/run_benchmark/run_test_local.sh +++ b/scripts/run_benchmark/run_test_local.sh @@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "transcriptformer_mlflow"]}' HERE nextflow run . \ diff --git a/src/methods/transcriptformer_mlflow/config.vsh.yaml b/src/methods/transcriptformer_mlflow/config.vsh.yaml new file mode 100644 index 00000000..2d144c23 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/config.vsh.yaml @@ -0,0 +1,67 @@ +__merge__: ../../api/base_method.yaml + +name: transcriptformer_mlflow +label: TranscriptFormer (MLflow model) +summary: "Context-aware representations of single-cell transcriptomes by jointly modeling genes and transcripts" +description: | + TranscriptFormer is designed to learn rich, context-aware representations of + single-cell transcriptomes while jointly modeling genes and transcripts using + a novel generative architecture. + + It is a family of generative foundation models representing a cross-species + generative cell atlas trained on up to 112 million cells spanning 1.53 billion + years of evolution across 12 species. + + Here, we use a version packaged as an MLflow model. +references: + doi: + - 10.1101/2025.04.25.650731 +links: + documentation: https://github.com/czi-ai/transcriptformer#readme + repository: https://github.com/czi-ai/transcriptformer + +info: + method_types: [embedding] + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the transcriptformer model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + setup: + - type: docker + add: https://astral.sh/uv/0.7.19/install.sh /uv-installer.sh + run: sh /uv-installer.sh && rm /uv-installer.sh + env: PATH="/root/.local/bin/:$PATH" + - type: docker + run: uv venv --python 3.11 /opt/venv + - type: docker + env: + - VIRTUAL_ENV=/opt/venv + - PATH="/opt/venv/bin:$PATH" + add: requirements.txt /requirements.txt + run: uv pip install -r /requirements.txt + - type: docker + run: uv pip install mlflow==3.1.0 + - type: docker + run: uv pip install git+https://github.com/openproblems-bio/core#subdirectory=packages/python/openproblems + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/transcriptformer_mlflow/requirements.txt b/src/methods/transcriptformer_mlflow/requirements.txt new file mode 100644 index 00000000..70d923d1 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/requirements.txt @@ -0,0 +1,338 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o requirements.txt +aiobotocore==2.23.0 + # via s3fs +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.13 + # via + # aiobotocore + # fsspec + # s3fs +aioitertools==0.12.0 + # via aiobotocore +aiosignal==1.3.2 + # via aiohttp +anndata==0.11.4 + # via + # cellxgene-census + # scanpy + # somacore + # tiledbsoma + # transcriptformer +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via + # aiohttp + # somacore + # tiledbsoma +boto3==1.38.27 + # via transcriptformer +botocore==1.38.27 + # via + # aiobotocore + # boto3 + # s3transfer +cellxgene-census==1.17.0 + # via transcriptformer +certifi==2025.6.15 + # via requests +charset-normalizer==3.4.2 + # via requests +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +filelock==3.18.0 + # via + # torch + # triton +fonttools==4.58.4 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.5.1 + # via + # pytorch-lightning + # s3fs + # torch +h5py==3.14.0 + # via + # anndata + # scanpy + # transcriptformer +hydra-core==1.3.2 + # via transcriptformer +idna==3.10 + # via + # requests + # yarl +iniconfig==2.1.0 + # via pytest +jinja2==3.1.6 + # via torch +jmespath==1.0.1 + # via + # aiobotocore + # boto3 + # botocore +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.8 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +lightning-utilities==0.14.3 + # via + # pytorch-lightning + # torchmetrics +llvmlite==0.44.0 + # via + # numba + # pynndescent +markupsafe==3.0.2 + # via jinja2 +matplotlib==3.10.3 + # via + # scanpy + # seaborn +more-itertools==10.7.0 + # via tiledbsoma +mpmath==1.3.0 + # via sympy +multidict==6.6.0 + # via + # aiobotocore + # aiohttp + # yarl +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # pynndescent + # scanpy + # umap-learn +numpy==2.2.6 + # via + # anndata + # cellxgene-census + # contourpy + # h5py + # matplotlib + # numba + # pandas + # patsy + # scanpy + # scikit-learn + # scipy + # seaborn + # shapely + # somacore + # statsmodels + # tiledbsoma + # torchmetrics + # transcriptformer + # umap-learn +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-ml-py==12.575.51 + # via pynvml +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +omegaconf==2.3.0 + # via hydra-core +packaging==25.0 + # via + # anndata + # hydra-core + # lightning-utilities + # matplotlib + # pytest + # pytorch-lightning + # scanpy + # statsmodels + # torchmetrics +pandas==2.3.0 + # via + # anndata + # scanpy + # seaborn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer +patsy==1.0.1 + # via + # scanpy + # statsmodels +pillow==11.2.1 + # via matplotlib +pluggy==1.6.0 + # via pytest +propcache==0.3.2 + # via + # aiohttp + # yarl +psutil==7.0.0 + # via transcriptformer +pyarrow==20.0.0 + # via + # somacore + # tiledbsoma +pyarrow-hotfix==0.7 + # via somacore +pygments==2.19.2 + # via pytest +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pynvml==12.0.0 + # via transcriptformer +pyparsing==3.2.3 + # via matplotlib +pytest==8.4.1 + # via transcriptformer +python-dateutil==2.9.0.post0 + # via + # aiobotocore + # botocore + # matplotlib + # pandas +pytorch-lightning==2.5.2 + # via transcriptformer +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # omegaconf + # pytorch-lightning +requests==2.32.4 + # via cellxgene-census +s3fs==2025.5.1 + # via cellxgene-census +s3transfer==0.13.0 + # via boto3 +scanpy==1.11.2 + # via + # tiledbsoma + # transcriptformer +scikit-learn==1.7.0 + # via + # pynndescent + # scanpy + # umap-learn +scipy==1.16.0 + # via + # anndata + # pynndescent + # scanpy + # scikit-learn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer + # umap-learn +seaborn==0.13.2 + # via scanpy +session-info2==0.1.2 + # via scanpy +setuptools==80.9.0 + # via lightning-utilities +shapely==2.1.1 + # via somacore +six==1.17.0 + # via python-dateutil +somacore==1.0.28 + # via tiledbsoma +statsmodels==0.14.4 + # via scanpy +sympy==1.13.1 + # via torch +threadpoolctl==3.6.0 + # via scikit-learn +tiledbsoma==1.17.0 + # via cellxgene-census +timeout-decorator==0.5.0 + # via transcriptformer +torch==2.5.1 + # via + # pytorch-lightning + # torchmetrics + # transcriptformer +torchmetrics==1.7.3 + # via pytorch-lightning +tqdm==4.67.1 + # via + # pytorch-lightning + # scanpy + # umap-learn +transcriptformer==0.3.0 + # via -r requirements.in +triton==3.1.0 + # via torch +typing-extensions==4.14.0 + # via + # cellxgene-census + # lightning-utilities + # pytorch-lightning + # scanpy + # somacore + # tiledbsoma + # torch +tzdata==2025.2 + # via pandas +umap-learn==0.5.7 + # via scanpy +urllib3==2.5.0 + # via + # botocore + # requests +wrapt==1.17.2 + # via aiobotocore +yarl==1.20.1 + # via aiohttp diff --git a/src/methods/transcriptformer_mlflow/script.py b/src/methods/transcriptformer_mlflow/script.py new file mode 100644 index 00000000..b16806d3 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/script.py @@ -0,0 +1,112 @@ +import os +import sys +import tarfile +import tempfile +import zipfile + +import anndata as ad +import mlflow.pyfunc +import pandas as pd + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. +par = { + "input": "resources_test/.../input.h5ad", + "output": "output.h5ad", + "model": "resources_test/.../model", +} +meta = {"name": "transcriptformer_mlflow"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable +from read_anndata_partial import read_anndata + +print("====== TranscriptFormer (MLflow model) ======", flush=True) + +print("\n>>> Reading input files...", flush=True) +print(f"Input H5AD file: '{par['input']}'", flush=True) +adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") + +if adata.uns["dataset_organism"] != "homo_sapiens": + exit_non_applicable( + f"Transcriptformer can only be used with human data " + f'(dataset_organism == "{adata.uns["dataset_organism"]}")' + ) + +print(adata, flush=True) + +if os.path.isdir(par["model"]): + print("\n>>> Using model directory...", flush=True) + print(f"Directory path: '{par['model']}'", flush=True) + model_temp = None + model_dir = par["model"] +else: + model_temp = tempfile.TemporaryDirectory() + model_dir = model_temp.name + + if zipfile.is_zipfile(par["model"]): + print("\n>>> Extracting model from .zip...", flush=True) + print(f".zip path: '{par['model']}'", flush=True) + with zipfile.ZipFile(par["model"], "r") as zip_file: + zip_file.extractall(model_dir) + elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"): + print("\n>>> Extracting model from .tar.gz...", flush=True) + print(f".tar.gz path: '{par['model']}'", flush=True) + with tarfile.open(par["model"], "r:gz") as tar_file: + tar_file.extractall(model_dir) + model_dir = os.path.join(model_dir, os.listdir(model_dir)[0]) + else: + raise ValueError( + "The 'model' argument should be a directory a .zip file or a .tar.gz file" + ) + +print("\n>>> Loading model...", flush=True) +model = mlflow.pyfunc.load_model(model_dir) +print(model, flush=True) + +print("\n>>> Writing temporary input H5AD file...", flush=True) +input_adata = ad.AnnData( + X=adata.X.copy(), + var=adata.var.filter(items=["feature_id"]).rename( + columns={"feature_id": "ensembl_id"} + ), +) +input_adata.obs["assay"] = "unknown" # Avoid error if assay is missing +print(input_adata, flush=True) +h5ad_file = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) +print(f"Temporary H5AD file: '{h5ad_file.name}'", flush=True) +input_adata.write(h5ad_file.name) +del input_adata + +print("\n>>> Running model...", flush=True) +input_df = pd.DataFrame({"input_uri": [h5ad_file.name]}) +embedding = model.predict(input_df) + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs=adata.obs[[]], + var=adata.var[[]], + obsm={ + "X_emb": embedding, + }, + uns={ + "dataset_id": adata.uns["dataset_id"], + "normalization_id": adata.uns["normalization_id"], + "method_id": meta["name"], + }, +) +print(output) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary files...", flush=True) +if model_temp is not None: + model_temp.cleanup() +h5ad_file.close() +os.unlink(h5ad_file.name) + +print("\n>>> Done!", flush=True) diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 09905ad0..d9fe9504 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -106,6 +106,7 @@ dependencies: - name: methods/scimilarity - name: methods/scprint - name: methods/scvi + - name: methods/transcriptformer_mlflow - name: methods/uce # metrics - name: metrics/asw_batch diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index 6196f749..104485bd 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -40,6 +40,9 @@ methods = [ ), scprint, scvi, + transcriptformer_mlflow.run( + args: [model: file("s3://openproblems-work/cache/transcriptformer-mlflow-model.zip")] + ), uce.run( args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")] )