diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f1c3368..f37e5404 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## NEW FUNCTIONALITY * Add new method: CellMapper, which is a k-NN based approach to map cells across representations and can be used for label projection. Two versions are included here, one based on PCA or CCA embeddings (`linear`) and one based on an scvi embedding (`scvi`) (PR #22) +* Add MLflow-based methods: Geneformer, scGPT, scVI, TranscriptFormer, and UCE for label projection using pre-trained foundation models (PR #28) ## MAJOR CHANGES diff --git a/src/methods/geneformer_mlflow/config.vsh.yaml b/src/methods/geneformer_mlflow/config.vsh.yaml new file mode 100644 index 00000000..506a32da --- /dev/null +++ b/src/methods/geneformer_mlflow/config.vsh.yaml @@ -0,0 +1,50 @@ +__merge__: ../../api/base_method.yaml + +name: geneformer_mlflow +label: Geneformer (MLflow) +summary: Geneformer is a foundational transformer model pretrained on a large-scale corpus of single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology (MLflow) +description: | + Geneformer is a context-aware, attention-based deep learning model pretrained + on a large-scale corpus of single-cell transcriptomes to enable + context-specific predictions in settings with limited data in network biology. + + This version uses a pre-trained MLflow model. A kNN classifier is trained on + embeddings for the training data and used to predict labels for the test + data. It does not use the built-in Geneformer classifier. +references: + doi: + - 10.1038/s41586-023-06139-9 + - 10.1101/2024.08.16.608180 +links: + documentation: https://geneformer.readthedocs.io/en/latest/index.html + repository: https://huggingface.co/ctheodoris/Geneformer + +info: + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the Geneformer model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. 
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/geneformer_mlflow/requirements.txt b/src/methods/geneformer_mlflow/requirements.txt new file mode 100644 index 00000000..21bec26b --- /dev/null +++ b/src/methods/geneformer_mlflow/requirements.txt @@ -0,0 +1,540 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --output-file=/tmp/tmpmz65ifid/requirements_pip_final.txt requirements.in +# +absl-py==2.3.1 + # via tensorboard +accelerate==1.10.0 + # via peft +accumulation-tree==0.6.4 + # via tdigest +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via + # mlflow + # optuna +anndata==0.10.9 + # via + # -r requirements.in + # geneformer + # scanpy +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via omegaconf +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via + # aiohttp + # jsonschema + # referencing +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +click==8.2.1 + # via + # flask + # loompy + # mlflow-skinny + # ray + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +colorlog==6.9.0 + # via optuna +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +datasets==4.0.0 + # via geneformer +dill==0.3.8 + # via + # datasets + # multiprocess +docker==7.1.0 + # via mlflow +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # datasets + # huggingface-hub + # ray + # torch + # transformers +flask==3.1.1 + # via mlflow +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +geneformer @ git+https://huggingface.co/ctheodoris/Geneformer@69e6887 + # via -r requirements.in +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +grpcio==1.74.0 + # via tensorboard +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # loompy + # scanpy +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via + # accelerate + # datasets + # peft + # tokenizers + # transformers +idna==3.10 + # via + # anyio + # requests + # yarl +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +itsdangerous==2.2.0 + # via flask +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +jsonschema==4.25.0 + # via ray +jsonschema-specifications==2025.4.1 + # via jsonschema +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +llvmlite==0.44.0 + # via + # numba + # pynndescent +loompy==3.0.8 + # via geneformer +mako==1.3.10 + # via alembic 
+markdown==3.8.2 + # via tensorboard +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # geneformer + # mlflow + # scanpy + # seaborn +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via ray +multidict==6.6.4 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # loompy + # pynndescent + # scanpy + # umap-learn +numpy==2.2.6 + # via + # accelerate + # anndata + # contourpy + # datasets + # geneformer + # h5py + # loompy + # matplotlib + # mlflow + # numba + # numpy-groupies + # optuna + # pandas + # patsy + # peft + # scanpy + # scikit-learn + # scipy + # seaborn + # statsmodels + # tensorboard + # transformers + # umap-learn +numpy-groupies==0.11.3 + # via loompy +nvidia-cublas-cu12==12.8.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.8.90 + # via torch +nvidia-cuda-nvrtc-cu12==12.8.93 + # via torch +nvidia-cuda-runtime-cu12==12.8.90 + # via torch +nvidia-cudnn-cu12==9.10.2.21 + # via torch +nvidia-cufft-cu12==11.3.3.83 + # via torch +nvidia-cufile-cu12==1.13.1.3 + # via torch +nvidia-curand-cu12==10.3.9.90 + # via torch +nvidia-cusolver-cu12==11.7.3.90 + # via torch +nvidia-cusparse-cu12==12.5.8.93 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.7.1 + # via torch +nvidia-nccl-cu12==2.27.3 + # via torch +nvidia-nvjitlink-cu12==12.8.93 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.8.90 + # via torch +omegaconf==2.3.0 + # via -r requirements.in +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +optuna==4.4.0 + # via + # geneformer + # optuna-integration +optuna-integration==4.4.0 + # via geneformer +packaging==25.0 + # via + # accelerate + # anndata + # datasets + # geneformer + # gunicorn + # huggingface-hub + # matplotlib + # mlflow-skinny + # optuna + # peft + # ray + # scanpy + # statsmodels + # tensorboard + # transformers +pandas==2.3.1 + # via + # anndata + # datasets + # geneformer + # mlflow + # scanpy + # seaborn + # statsmodels +patsy==1.0.1 + # via + # scanpy + # statsmodels +peft==0.17.0 + # via geneformer +pillow==11.3.0 + # via + # matplotlib + # tensorboard +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # ray + # tensorboard +psutil==7.0.0 + # via + # accelerate + # peft +pyarrow==20.0.0 + # via + # datasets + # geneformer + # mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pyparsing==3.2.3 + # via matplotlib +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytz==2025.2 + # via + # geneformer + # pandas +pyudorandom==1.0.0 + # via tdigest +pyyaml==6.0.2 + # via + # accelerate + # datasets + # huggingface-hub + # mlflow-skinny + # omegaconf + # optuna + # peft + # ray + # transformers +ray==2.48.0 + # via geneformer +referencing==0.36.2 + # via + # jsonschema + # jsonschema-specifications +regex==2025.7.34 + # via transformers 
+requests==2.32.4 + # via + # databricks-sdk + # datasets + # docker + # huggingface-hub + # mlflow-skinny + # ray + # transformers +rpds-py==0.27.0 + # via + # jsonschema + # referencing +rsa==4.9.1 + # via google-auth +safetensors==0.6.2 + # via + # accelerate + # peft + # transformers +scanpy==1.11.4 + # via geneformer +scikit-learn==1.7.1 + # via + # geneformer + # mlflow + # pynndescent + # scanpy + # umap-learn +scipy==1.16.1 + # via + # anndata + # geneformer + # loompy + # mlflow + # pynndescent + # scanpy + # scikit-learn + # statsmodels + # umap-learn +seaborn==0.13.2 + # via + # geneformer + # scanpy +session-info2==0.2 + # via scanpy +six==1.17.0 + # via python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow + # optuna +sqlparse==0.5.3 + # via mlflow-skinny +starlette==0.47.2 + # via fastapi +statsmodels==0.14.5 + # via + # geneformer + # scanpy +sympy==1.14.0 + # via torch +tdigest==0.5.2.2 + # via geneformer +tensorboard==2.20.0 + # via geneformer +tensorboard-data-server==0.7.2 + # via tensorboard +threadpoolctl==3.6.0 + # via scikit-learn +tokenizers==0.21.4 + # via transformers +torch==2.8.0 + # via + # accelerate + # geneformer + # peft +tqdm==4.67.1 + # via + # datasets + # geneformer + # huggingface-hub + # optuna + # peft + # scanpy + # transformers + # umap-learn +transformers==4.49.0 + # via + # -r requirements.in + # geneformer + # peft +triton==3.4.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # fastapi + # graphene + # huggingface-hub + # mlflow-skinny + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # pydantic + # pydantic-core + # referencing + # scanpy + # sqlalchemy + # starlette + # torch + # typing-inspection +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +umap-learn==0.5.9.post2 + # via scanpy +urllib3==2.5.0 + # via + # docker + # requests +uvicorn==0.35.0 + # via mlflow-skinny +werkzeug==3.1.3 + # via + # flask + # tensorboard +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/src/methods/geneformer_mlflow/script.py b/src/methods/geneformer_mlflow/script.py new file mode 100644 index 00000000..e0356c06 --- /dev/null +++ b/src/methods/geneformer_mlflow/script.py @@ -0,0 +1,109 @@ +import os +import sys +import anndata as ad +import mlflow +import numpy as np + +## VIASH START +par = { + "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", + "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", + "output": "output.h5ad", + "model": "resources_test/.../model", +} +meta = {"name": "geneformer_mlflow"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable # noqa: E402 +from mlflow import train_classifier, classify # noqa: E402 +from unpack import unpack_directory # noqa: E402 + +print("====== Geneformer (MLflow model) ======", flush=True) + +n_processors = meta.get("cpus") or os.cpu_count() +print(f"Available processors: {n_processors}", flush=True) + +print("\n>>> Reading training data...", flush=True) +print(f"Training H5AD file: '{par['input_train']}'", flush=True) +input_train = ad.read_h5ad(par["input_train"]) +print(input_train, flush=True) + +if input_train.uns["dataset_organism"] != "homo_sapiens": + 
exit_non_applicable( + f"Geneformer (MLflow) can only be used with human data " + f'(dataset_organism == "{input_train.uns["dataset_organism"]}")' + ) + +print("\n>>> Unpacking model...", flush=True) +model_dir, model_temp = unpack_directory(par["model"]) + +print("\n>>> Loading model...", flush=True) +model = mlflow.pyfunc.load_model(model_dir) +print(model, flush=True) + + +def process_geneformer_adata(adata): + """Add required columns for Geneformer model.""" + adata.obs["cell_idx"] = np.arange(adata.n_obs) + adata.obs["n_counts"] = adata.X.sum(axis=1) + + +# Train classifier on training data +classifier = train_classifier( + input_train, + model, + layers=["counts"], + var={"feature_id": "ensembl_id"}, + model_params={"nproc": n_processors}, + process_adata=process_geneformer_adata, +) + +# Free memory - no longer need training data +del input_train + +print("\n>>> Reading test data...", flush=True) +print(f"Test H5AD file: '{par['input_test']}'", flush=True) +input_test = ad.read_h5ad(par["input_test"]) +print(input_test, flush=True) + +# Store metadata before classifying +dataset_id = input_test.uns["dataset_id"] +normalization_id = input_test.uns["normalization_id"] + +# Classify test data +predictions = classify( + input_test, + model, + classifier, + layers=["counts"], + var={"feature_id": "ensembl_id"}, + model_params={"nproc": n_processors}, + process_adata=process_geneformer_adata, +) + +# Free memory - no longer need test data +del input_test + +print(predictions.value_counts(), flush=True) + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs={"label_pred": predictions}, + uns={ + "method_id": meta["name"], + "dataset_id": dataset_id, + "normalization_id": normalization_id, + }, +) +print(output, flush=True) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary files...", flush=True) +if model_temp is not None: + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/methods/scgpt_mlflow/config.vsh.yaml b/src/methods/scgpt_mlflow/config.vsh.yaml new file mode 100644 index 00000000..5ab238f9 --- /dev/null +++ b/src/methods/scgpt_mlflow/config.vsh.yaml @@ -0,0 +1,48 @@ +__merge__: ../../api/base_method.yaml + +name: scgpt_mlflow +label: scGPT (MLflow) +summary: scGPT is a foundation model for single-cell biology (MLflow model) +description: | + scGPT is a foundation model for single-cell biology based on a generative + pre-trained transformer and trained on a repository of over 33 million cells. + + This version uses a pre-trained MLflow model. A kNN classifier is trained on + embeddings for the training data and used to predict labels for the test + data. +references: + doi: + - 10.1038/s41592-024-02201-0 +links: + documentation: https://scgpt.readthedocs.io/en/latest/ + repository: https://github.com/bowang-lab/scGPT + +info: + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the scGPT model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. 
+ required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/scgpt_mlflow/requirements.txt b/src/methods/scgpt_mlflow/requirements.txt new file mode 100644 index 00000000..2ad53dc3 --- /dev/null +++ b/src/methods/scgpt_mlflow/requirements.txt @@ -0,0 +1,684 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmp7yfkiop2/requirements_initial.txt +absl-py==2.3.1 + # via + # chex + # ml-collections + # optax + # orbax + # orbax-checkpoint +aiofiles==24.1.0 + # via orbax-checkpoint +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via + # datasets + # fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via mlflow +anndata==0.10.9 + # via + # -r requirements.in + # mudata + # scanpy + # scib + # scvi-tools +annotated-types==0.7.0 + # via pydantic +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via aiohttp +blinker==1.9.0 + # via flask +cached-property==2.0.1 + # via orbax +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +cell-gears==0.0.2 + # via scgpt +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +chex==0.1.90 + # via + # optax + # scvi-tools +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +datasets==2.14.4 + # via scgpt +dcor==0.6 + # via cell-gears +decorator==5.2.1 + # via ipython +deprecated==1.2.18 + # via scib +dill==0.3.7 + # via + # datasets + # multiprocess +docker==7.1.0 + # via mlflow +docrep==0.3.2 + # via scvi-tools +et-xmlfile==2.0.0 + # via openpyxl +etils==1.13.0 + # via + # orbax + # orbax-checkpoint +exceptiongroup==1.3.0 + # via + # anndata + # anyio + # ipython +executing==2.2.0 + # via stack-data +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # huggingface-hub + # torch + # triton +flask==3.1.1 + # via mlflow +flax==0.10.7 + # via scvi-tools +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.7.0 + # via + # datasets + # etils + # huggingface-hub + # pytorch-lightning + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scanpy + # scib + # scvi-tools +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via datasets +humanize==4.12.3 + # via orbax-checkpoint +idna==3.10 + # via + # anyio + # requests + # yarl +igraph==0.11.9 + # via + # leidenalg + # scib +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via + # etils + # orbax +ipython==8.27.0 + # via -r requirements.in +itsdangerous==2.2.0 + # via 
flask +jax==0.6.2 + # via + # chex + # flax + # numpyro + # optax + # orbax + # orbax-checkpoint + # scvi-tools +jaxlib==0.6.2 + # via + # chex + # jax + # numpyro + # optax + # orbax + # scvi-tools +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # dcor + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +leidenalg==0.10.2 + # via + # scgpt + # scib +lightning-utilities==0.15.2 + # via + # pytorch-lightning + # torchmetrics +llvmlite==0.44.0 + # via + # numba + # pynndescent + # scib +mako==1.3.10 + # via alembic +markdown-it-py==4.0.0 + # via rich +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # mlflow + # scanpy + # scib + # seaborn +matplotlib-inline==0.1.7 + # via ipython +mdurl==0.1.2 + # via markdown-it-py +ml-collections==1.1.0 + # via scvi-tools +ml-dtypes==0.5.3 + # via + # jax + # jaxlib + # tensorstore +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via + # flax + # orbax + # orbax-checkpoint +mudata==0.3.2 + # via scvi-tools +multidict==6.6.4 + # via + # aiohttp + # yarl +multipledispatch==1.0.0 + # via numpyro +multiprocess==0.70.15 + # via datasets +natsort==8.4.0 + # via + # anndata + # scanpy +nest-asyncio==1.6.0 + # via + # orbax + # orbax-checkpoint +networkx==3.4.2 + # via + # cell-gears + # scanpy + # torch +numba==0.61.2 + # via + # dcor + # pynndescent + # scanpy + # scgpt + # scib + # umap-learn +numpy==1.26.4 + # via + # anndata + # cell-gears + # chex + # contourpy + # datasets + # dcor + # h5py + # jax + # jaxlib + # matplotlib + # ml-dtypes + # mlflow + # numba + # numpyro + # optax + # orbax + # orbax-checkpoint + # pandas + # patsy + # pyro-ppl + # pytorch-lightning + # scanpy + # scib + # scikit-learn + # scikit-misc + # scipy + # scvi-tools + # seaborn + # statsmodels + # tensorstore + # torchmetrics + # torchtext + # treescope + # umap-learn +numpyro==0.19.0 + # via scvi-tools +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.18.1 + # via torch +nvidia-nvjitlink-cu12==12.9.86 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +openpyxl==3.1.5 + # via scvi-tools +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +opt-einsum==3.4.0 + # via + # jax + # pyro-ppl +optax==0.2.5 + # via + # flax + # scvi-tools +orbax==0.1.7 + # via scgpt +orbax-checkpoint==0.11.21 + # via flax +packaging==25.0 + # via + # anndata + # datasets + # gunicorn + # huggingface-hub + # lightning-utilities + # matplotlib + # mlflow-skinny + # pytorch-lightning + # scanpy + # statsmodels + # torchmetrics +pandas==2.3.1 + # via + # anndata + # cell-gears + # datasets + # mlflow + # scanpy + # scgpt + # scib + # scvi-tools + # seaborn + # statsmodels 
+parso==0.8.4 + # via jedi +patsy==1.0.1 + # via + # scanpy + # statsmodels +pexpect==4.9.0 + # via ipython +pillow==11.3.0 + # via matplotlib +prompt-toolkit==3.0.51 + # via ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # orbax-checkpoint +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==20.0.0 + # via + # datasets + # mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pydot==4.0.1 + # via scib +pygments==2.19.2 + # via + # ipython + # rich +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pyparsing==3.2.3 + # via + # matplotlib + # pydot +pyro-api==0.1.2 + # via pyro-ppl +pyro-ppl==1.9.1 + # via scvi-tools +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytorch-lightning==1.9.5 + # via scvi-tools +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # datasets + # flax + # huggingface-hub + # ml-collections + # mlflow-skinny + # orbax + # orbax-checkpoint + # pytorch-lightning +requests==2.32.4 + # via + # databricks-sdk + # datasets + # docker + # huggingface-hub + # mlflow-skinny + # torchdata + # torchtext +rich==14.1.0 + # via + # flax + # scvi-tools +rsa==4.9.1 + # via google-auth +scanpy==1.11.4 + # via + # cell-gears + # scgpt + # scib +scgpt==0.2.1 + # via -r requirements.in +scib==1.1.7 + # via scgpt +scikit-learn==1.7.1 + # via + # cell-gears + # mlflow + # pynndescent + # scanpy + # scib + # scvi-tools + # umap-learn +scikit-misc==0.5.1 + # via + # scgpt + # scib +scipy==1.12.0 + # via + # -r requirements.in + # anndata + # dcor + # jax + # jaxlib + # mlflow + # pynndescent + # scanpy + # scib + # scikit-learn + # scvi-tools + # statsmodels + # umap-learn +scvi-tools==0.20.3 + # via scgpt +seaborn==0.13.2 + # via + # scanpy + # scib +session-info2==0.2 + # via scanpy +setuptools==80.9.0 + # via lightning-utilities +simplejson==3.20.1 + # via orbax-checkpoint +six==1.17.0 + # via + # docrep + # python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow +sqlparse==0.5.3 + # via mlflow-skinny +stack-data==0.6.3 + # via ipython +starlette==0.47.2 + # via fastapi +statsmodels==0.14.5 + # via scanpy +sympy==1.14.0 + # via torch +tensorstore==0.1.76 + # via + # flax + # orbax + # orbax-checkpoint +texttable==1.7.0 + # via igraph +threadpoolctl==3.6.0 + # via scikit-learn +tomli==2.2.1 + # via alembic +toolz==1.0.0 + # via chex +torch==2.1.2 + # via + # cell-gears + # pyro-ppl + # pytorch-lightning + # scgpt + # scvi-tools + # torchdata + # torchmetrics + # torchtext +torchdata==0.7.1 + # via torchtext +torchmetrics==1.8.1 + # via + # pytorch-lightning + # scvi-tools +torchtext==0.16.2 + # via scgpt +tqdm==4.67.1 + # via + # cell-gears + # datasets + # huggingface-hub + # numpyro + # pyro-ppl + # pytorch-lightning + # scanpy + # scvi-tools + # torchtext + # umap-learn +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +treescope==0.1.10 + # via flax +triton==2.1.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # chex + # etils + # exceptiongroup + # fastapi + # flax + # graphene + # huggingface-hub + # ipython + # lightning-utilities + # mlflow-skinny + # multidict + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # orbax + # orbax-checkpoint + # pydantic + # pydantic-core 
+ # pytorch-lightning + # scanpy + # scgpt + # sqlalchemy + # starlette + # torch + # typing-inspection + # uvicorn +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +umap-learn==0.5.9.post2 + # via + # scanpy + # scgpt + # scib +urllib3==2.5.0 + # via + # docker + # requests + # torchdata +uvicorn==0.35.0 + # via mlflow-skinny +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.1.3 + # via flask +wrapt==1.17.3 + # via deprecated +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via + # etils + # importlib-metadata diff --git a/src/methods/scgpt_mlflow/script.py b/src/methods/scgpt_mlflow/script.py new file mode 100644 index 00000000..58978ef4 --- /dev/null +++ b/src/methods/scgpt_mlflow/script.py @@ -0,0 +1,95 @@ +import sys +import anndata as ad +import mlflow + +## VIASH START +par = { + "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", + "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", + "output": "output.h5ad", + "model": "resources_test/.../model", +} +meta = {"name": "scgpt_mlflow"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable # noqa: E402 +from mlflow import train_classifier, classify # noqa: E402 +from unpack import unpack_directory # noqa: E402 + +print("====== scGPT (MLflow model) ======", flush=True) + +print("\n>>> Reading training data...", flush=True) +print(f"Training H5AD file: '{par['input_train']}'", flush=True) +input_train = ad.read_h5ad(par["input_train"]) +print(input_train, flush=True) + +if input_train.uns["dataset_organism"] != "homo_sapiens": + exit_non_applicable( + f"scGPT (MLflow) can only be used with human data " + f'(dataset_organism == "{input_train.uns["dataset_organism"]}")' + ) + +print("\n>>> Unpacking model...", flush=True) +model_dir, model_temp = unpack_directory(par["model"]) + +print("\n>>> Loading model...", flush=True) +model = mlflow.pyfunc.load_model(model_dir) +print(model, flush=True) + +# Train classifier on training data +classifier = train_classifier( + input_train, + model, + layers=["counts"], + var={"feature_name": "feature_name"}, + model_params={"gene_col": "feature_name"}, +) + +# Free memory - no longer need training data +del input_train + +print("\n>>> Reading test data...", flush=True) +print(f"Test H5AD file: '{par['input_test']}'", flush=True) +input_test = ad.read_h5ad(par["input_test"]) +print(input_test, flush=True) + +# Store metadata before classifying +dataset_id = input_test.uns["dataset_id"] +normalization_id = input_test.uns["normalization_id"] + +# Classify test data +predictions = classify( + input_test, + model, + classifier, + layers=["counts"], + var={"feature_name": "feature_name"}, + model_params={"gene_col": "feature_name"}, +) + +# Free memory - no longer need test data +del input_test + +print(predictions.value_counts(), flush=True) + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs={"label_pred": predictions}, + uns={ + "method_id": meta["name"], + "dataset_id": dataset_id, + "normalization_id": normalization_id, + }, +) +print(output, flush=True) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary files...", flush=True) +if model_temp is not None: + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/methods/scprint/script.py 
b/src/methods/scprint/script.py index 753643a5..f44d2b74 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -135,7 +135,7 @@ # the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device model = model.to("cuda" if torch.cuda.is_available() else "cpu") -n_cores = max(16, len(os.sched_getaffinity(0))) +n_cores = min(len(os.sched_getaffinity(0)), 24) print(f"Using {n_cores} worker cores") embedder = scprint.tasks.Embedder( diff --git a/src/methods/scvi_mlflow/config.vsh.yaml b/src/methods/scvi_mlflow/config.vsh.yaml new file mode 100644 index 00000000..7ebc838a --- /dev/null +++ b/src/methods/scvi_mlflow/config.vsh.yaml @@ -0,0 +1,49 @@ +__merge__: ../../api/base_method.yaml + +name: scvi_mlflow +label: scVI (MLflow) +summary: scVI combines a variational autoencoder with a hierarchical Bayesian model (MLflow model) +description: | + scVI combines a variational autoencoder with a hierarchical Bayesian model. + It uses the negative binomial distribution to describe gene expression of + each cell, conditioned on unobserved factors and the batch variable. + + This version uses a pre-trained MLflow model. A kNN classifier is trained on + embeddings for the training data and used to predict labels for the test + data. +references: + doi: + - 10.1038/s41592-018-0229-2 +links: + repository: https://github.com/scverse/scvi-tools + documentation: https://docs.scvi-tools.org/en/stable/user_guide/models/scvi.html + +info: + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the scVI model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/scvi_mlflow/requirements.txt b/src/methods/scvi_mlflow/requirements.txt new file mode 100644 index 00000000..c3c79df5 --- /dev/null +++ b/src/methods/scvi_mlflow/requirements.txt @@ -0,0 +1,459 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmp6b02zuzi/requirements_initial.txt +absl-py==2.3.1 + # via + # chex + # ml-collections + # optax + # orbax-checkpoint +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +alembic==1.16.4 + # via mlflow +anndata==0.10.8 + # via + # -r requirements.in + # mudata + # scvi-tools +annotated-types==0.7.0 + # via pydantic +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via aiohttp +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +chex==0.1.90 + # via optax +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +docker==7.1.0 + # via mlflow +docrep==0.3.2 + # via scvi-tools +etils==1.13.0 + # via orbax-checkpoint 
+fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # torch + # triton +flask==3.1.1 + # via mlflow +flax==0.10.4 + # via scvi-tools +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.7.0 + # via + # etils + # lightning + # pytorch-lightning + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scvi-tools +humanize==4.12.3 + # via orbax-checkpoint +idna==3.10 + # via + # anyio + # requests + # yarl +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via etils +itsdangerous==2.2.0 + # via flask +jax==0.4.33 + # via + # -r requirements.in + # chex + # flax + # numpyro + # optax + # orbax-checkpoint + # scvi-tools +jaxlib==0.4.33 + # via + # -r requirements.in + # chex + # jax + # numpyro + # optax + # orbax-checkpoint + # scvi-tools +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via scikit-learn +kiwisolver==1.4.9 + # via matplotlib +lightning==2.5.2 + # via scvi-tools +lightning-utilities==0.15.2 + # via + # lightning + # pytorch-lightning + # torchmetrics +mako==1.3.10 + # via alembic +markdown-it-py==4.0.0 + # via rich +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via mlflow +mdurl==0.1.2 + # via markdown-it-py +ml-collections==1.1.0 + # via scvi-tools +ml-dtypes==0.5.3 + # via + # jax + # jaxlib + # tensorstore +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +msgpack==1.1.1 + # via + # flax + # orbax-checkpoint +mudata==0.3.2 + # via scvi-tools +multidict==6.6.4 + # via + # aiohttp + # yarl +multipledispatch==1.0.0 + # via numpyro +natsort==8.4.0 + # via anndata +nest-asyncio==1.6.0 + # via orbax-checkpoint +networkx==3.5 + # via torch +numpy==1.26.4 + # via + # anndata + # chex + # contourpy + # flax + # h5py + # jax + # jaxlib + # matplotlib + # ml-dtypes + # mlflow + # numpyro + # optax + # orbax-checkpoint + # pandas + # pyro-ppl + # scikit-learn + # scipy + # scvi-tools + # tensorstore + # torchmetrics + # treescope +numpyro==0.19.0 + # via scvi-tools +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +opentelemetry-api==1.36.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.36.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.57b0 + # via opentelemetry-sdk +opt-einsum==3.4.0 + # via + # jax + # pyro-ppl +optax==0.2.5 + # via + # flax + # scvi-tools +orbax-checkpoint==0.6.4 + # via flax +packaging==25.0 + # via + # 
anndata + # gunicorn + # lightning + # lightning-utilities + # matplotlib + # mlflow-skinny + # pytorch-lightning + # torchmetrics +pandas==2.2.3 + # via + # -r requirements.in + # anndata + # mlflow + # scvi-tools +pillow==11.3.0 + # via matplotlib +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.1 + # via + # mlflow-skinny + # orbax-checkpoint +pyarrow==20.0.0 + # via mlflow +pyasn1==0.6.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.4.2 + # via google-auth +pydantic==2.11.7 + # via + # fastapi + # mlflow-skinny +pydantic-core==2.33.2 + # via pydantic +pygments==2.19.2 + # via rich +pyparsing==3.2.3 + # via matplotlib +pyro-api==0.1.2 + # via pyro-ppl +pyro-ppl==1.9.1 + # via scvi-tools +python-dateutil==2.9.0.post0 + # via + # graphene + # matplotlib + # pandas +pytorch-lightning==2.5.2 + # via lightning +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # flax + # lightning + # ml-collections + # mlflow-skinny + # orbax-checkpoint + # pytorch-lightning +requests==2.32.4 + # via + # databricks-sdk + # docker + # mlflow-skinny +rich==14.1.0 + # via + # flax + # scvi-tools +rsa==4.9.1 + # via google-auth +scikit-learn==1.7.1 + # via + # mlflow + # scvi-tools +scipy==1.16.1 + # via + # anndata + # jax + # jaxlib + # mlflow + # scikit-learn + # scvi-tools +scvi-tools==1.1.6.post2 + # via -r requirements.in +setuptools==80.9.0 + # via lightning-utilities +six==1.17.0 + # via + # docrep + # python-dateutil +smmap==5.0.2 + # via gitdb +sniffio==1.3.1 + # via anyio +sqlalchemy==2.0.43 + # via + # alembic + # mlflow +sqlparse==0.5.3 + # via mlflow-skinny +starlette==0.47.2 + # via fastapi +sympy==1.13.1 + # via torch +tensorstore==0.1.76 + # via + # flax + # orbax-checkpoint +threadpoolctl==3.6.0 + # via scikit-learn +toolz==1.0.0 + # via chex +torch==2.5.1 + # via + # -r requirements.in + # lightning + # pyro-ppl + # pytorch-lightning + # scvi-tools + # torchmetrics +torchmetrics==1.8.1 + # via + # lightning + # pytorch-lightning + # scvi-tools +tqdm==4.67.1 + # via + # lightning + # numpyro + # pyro-ppl + # pytorch-lightning + # scvi-tools +treescope==0.1.10 + # via flax +triton==3.1.0 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # alembic + # anyio + # chex + # etils + # fastapi + # flax + # graphene + # lightning + # lightning-utilities + # mlflow-skinny + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions + # orbax-checkpoint + # pydantic + # pydantic-core + # pytorch-lightning + # sqlalchemy + # starlette + # torch + # typing-inspection +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +urllib3==2.5.0 + # via + # docker + # requests +uvicorn==0.35.0 + # via mlflow-skinny +werkzeug==3.1.3 + # via flask +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via + # etils + # importlib-metadata diff --git a/src/methods/scvi_mlflow/script.py b/src/methods/scvi_mlflow/script.py new file mode 100644 index 00000000..d587cf20 --- /dev/null +++ b/src/methods/scvi_mlflow/script.py @@ -0,0 +1,99 @@ +import sys +import anndata as ad +import mlflow + +## VIASH START +par = { + "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", + "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", + "output": "output.h5ad", + "model": "resources_test/.../model", +} +meta = {"name": "scvi_mlflow"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable # noqa: E402 +from mlflow import train_classifier, classify # 
noqa: E402 +from unpack import unpack_directory # noqa: E402 + +print("====== scVI (MLflow model) ======", flush=True) + +print("\n>>> Reading training data...", flush=True) +print(f"Training H5AD file: '{par['input_train']}'", flush=True) +input_train = ad.read_h5ad(par["input_train"]) +print(input_train, flush=True) + +if input_train.uns["dataset_organism"] == "homo_sapiens": + organism = "human" +elif input_train.uns["dataset_organism"] == "mus_musculus": + organism = "mouse" +else: + exit_non_applicable( + f"scVI (MLflow) can only be used with human or mouse data " + f'(dataset_organism == "{input_train.uns["dataset_organism"]}")' + ) + +print("\n>>> Unpacking model...", flush=True) +model_dir, model_temp = unpack_directory(par["model"]) + +print(f"\n>>> Loading {organism} model...", flush=True) +model = mlflow.pyfunc.load_model(model_dir, model_config={"organism": organism}) +print(model, flush=True) + +# Train classifier on training data +classifier = train_classifier( + input_train, + model, + layers=["counts"], + obs=["batch"], + var={"feature_id": "feature_id"}, +) + +# Free memory - no longer need training data +del input_train + +print("\n>>> Reading test data...", flush=True) +print(f"Test H5AD file: '{par['input_test']}'", flush=True) +input_test = ad.read_h5ad(par["input_test"]) +print(input_test, flush=True) + +# Store metadata before classifying +dataset_id = input_test.uns["dataset_id"] +normalization_id = input_test.uns["normalization_id"] + +# Classify test data +predictions = classify( + input_test, + model, + classifier, + layers=["counts"], + obs=["batch"], + var={"feature_id": "feature_id"}, +) + +# Free memory - no longer need test data +del input_test + +print(predictions.value_counts(), flush=True) + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs={"label_pred": predictions}, + uns={ + "method_id": meta["name"], + "dataset_id": dataset_id, + "normalization_id": normalization_id, + }, +) +print(output, flush=True) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary files...", flush=True) +if model_temp is not None: + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/methods/transcriptformer_mlflow/config.vsh.yaml b/src/methods/transcriptformer_mlflow/config.vsh.yaml new file mode 100644 index 00000000..faa2a572 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/config.vsh.yaml @@ -0,0 +1,53 @@ +__merge__: ../../api/base_method.yaml + +name: transcriptformer_mlflow +label: TranscriptFormer (MLflow) +summary: Context-aware representations of single-cell transcriptomes (MLflow model) +description: | + TranscriptFormer is designed to learn rich, context-aware representations of + single-cell transcriptomes while jointly modeling genes and transcripts using + a novel generative architecture. + + It is a family of generative foundation models representing a cross-species + generative cell atlas trained on up to 112 million cells spanning 1.53 billion + years of evolution across 12 species. + + This version uses a pre-trained MLflow model. A kNN classifier is trained on + embeddings for the training data and used to predict labels for the test + data. 
+references: + doi: + - 10.1101/2025.04.25.650731 +links: + documentation: https://github.com/czi-ai/transcriptformer#readme + repository: https://github.com/czi-ai/transcriptformer + +info: + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the TranscriptFormer model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/transcriptformer_mlflow/requirements.txt b/src/methods/transcriptformer_mlflow/requirements.txt new file mode 100644 index 00000000..70d923d1 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/requirements.txt @@ -0,0 +1,338 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o requirements.txt +aiobotocore==2.23.0 + # via s3fs +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.13 + # via + # aiobotocore + # fsspec + # s3fs +aioitertools==0.12.0 + # via aiobotocore +aiosignal==1.3.2 + # via aiohttp +anndata==0.11.4 + # via + # cellxgene-census + # scanpy + # somacore + # tiledbsoma + # transcriptformer +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +array-api-compat==1.12.0 + # via anndata +attrs==25.3.0 + # via + # aiohttp + # somacore + # tiledbsoma +boto3==1.38.27 + # via transcriptformer +botocore==1.38.27 + # via + # aiobotocore + # boto3 + # s3transfer +cellxgene-census==1.17.0 + # via transcriptformer +certifi==2025.6.15 + # via requests +charset-normalizer==3.4.2 + # via requests +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +filelock==3.18.0 + # via + # torch + # triton +fonttools==4.58.4 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec==2025.5.1 + # via + # pytorch-lightning + # s3fs + # torch +h5py==3.14.0 + # via + # anndata + # scanpy + # transcriptformer +hydra-core==1.3.2 + # via transcriptformer +idna==3.10 + # via + # requests + # yarl +iniconfig==2.1.0 + # via pytest +jinja2==3.1.6 + # via torch +jmespath==1.0.1 + # via + # aiobotocore + # boto3 + # botocore +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.8 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +lightning-utilities==0.14.3 + # via + # pytorch-lightning + # torchmetrics +llvmlite==0.44.0 + # via + # numba + # pynndescent +markupsafe==3.0.2 + # via jinja2 +matplotlib==3.10.3 + # via + # scanpy + # seaborn +more-itertools==10.7.0 + # via tiledbsoma +mpmath==1.3.0 + # via sympy +multidict==6.6.0 + # via + # aiobotocore + # aiohttp + # yarl +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # pynndescent + # scanpy + # umap-learn +numpy==2.2.6 + # via + # anndata + # cellxgene-census + # contourpy + # h5py + # matplotlib + # numba + # pandas + # patsy + # scanpy + # scikit-learn + # scipy + # seaborn + # shapely + # somacore + # statsmodels + # tiledbsoma + # torchmetrics + # transcriptformer + # umap-learn +nvidia-cublas-cu12==12.4.5.8 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch 
+nvidia-cuda-cupti-cu12==12.4.127 + # via torch +nvidia-cuda-nvrtc-cu12==12.4.127 + # via torch +nvidia-cuda-runtime-cu12==12.4.127 + # via torch +nvidia-cudnn-cu12==9.1.0.70 + # via torch +nvidia-cufft-cu12==11.2.1.3 + # via torch +nvidia-curand-cu12==10.3.5.147 + # via torch +nvidia-cusolver-cu12==11.6.1.9 + # via torch +nvidia-cusparse-cu12==12.3.1.170 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-ml-py==12.575.51 + # via pynvml +nvidia-nccl-cu12==2.21.5 + # via torch +nvidia-nvjitlink-cu12==12.4.127 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.4.127 + # via torch +omegaconf==2.3.0 + # via hydra-core +packaging==25.0 + # via + # anndata + # hydra-core + # lightning-utilities + # matplotlib + # pytest + # pytorch-lightning + # scanpy + # statsmodels + # torchmetrics +pandas==2.3.0 + # via + # anndata + # scanpy + # seaborn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer +patsy==1.0.1 + # via + # scanpy + # statsmodels +pillow==11.2.1 + # via matplotlib +pluggy==1.6.0 + # via pytest +propcache==0.3.2 + # via + # aiohttp + # yarl +psutil==7.0.0 + # via transcriptformer +pyarrow==20.0.0 + # via + # somacore + # tiledbsoma +pyarrow-hotfix==0.7 + # via somacore +pygments==2.19.2 + # via pytest +pynndescent==0.5.13 + # via + # scanpy + # umap-learn +pynvml==12.0.0 + # via transcriptformer +pyparsing==3.2.3 + # via matplotlib +pytest==8.4.1 + # via transcriptformer +python-dateutil==2.9.0.post0 + # via + # aiobotocore + # botocore + # matplotlib + # pandas +pytorch-lightning==2.5.2 + # via transcriptformer +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # omegaconf + # pytorch-lightning +requests==2.32.4 + # via cellxgene-census +s3fs==2025.5.1 + # via cellxgene-census +s3transfer==0.13.0 + # via boto3 +scanpy==1.11.2 + # via + # tiledbsoma + # transcriptformer +scikit-learn==1.7.0 + # via + # pynndescent + # scanpy + # umap-learn +scipy==1.16.0 + # via + # anndata + # pynndescent + # scanpy + # scikit-learn + # somacore + # statsmodels + # tiledbsoma + # transcriptformer + # umap-learn +seaborn==0.13.2 + # via scanpy +session-info2==0.1.2 + # via scanpy +setuptools==80.9.0 + # via lightning-utilities +shapely==2.1.1 + # via somacore +six==1.17.0 + # via python-dateutil +somacore==1.0.28 + # via tiledbsoma +statsmodels==0.14.4 + # via scanpy +sympy==1.13.1 + # via torch +threadpoolctl==3.6.0 + # via scikit-learn +tiledbsoma==1.17.0 + # via cellxgene-census +timeout-decorator==0.5.0 + # via transcriptformer +torch==2.5.1 + # via + # pytorch-lightning + # torchmetrics + # transcriptformer +torchmetrics==1.7.3 + # via pytorch-lightning +tqdm==4.67.1 + # via + # pytorch-lightning + # scanpy + # umap-learn +transcriptformer==0.3.0 + # via -r requirements.in +triton==3.1.0 + # via torch +typing-extensions==4.14.0 + # via + # cellxgene-census + # lightning-utilities + # pytorch-lightning + # scanpy + # somacore + # tiledbsoma + # torch +tzdata==2025.2 + # via pandas +umap-learn==0.5.7 + # via scanpy +urllib3==2.5.0 + # via + # botocore + # requests +wrapt==1.17.2 + # via aiobotocore +yarl==1.20.1 + # via aiohttp diff --git a/src/methods/transcriptformer_mlflow/script.py b/src/methods/transcriptformer_mlflow/script.py new file mode 100644 index 00000000..c3775d17 --- /dev/null +++ b/src/methods/transcriptformer_mlflow/script.py @@ -0,0 +1,104 @@ +import sys +import anndata as ad +import mlflow + +## VIASH START +par = { + "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", + "input_test": 
"resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", + "output": "output.h5ad", + "model": "resources_test/.../model", +} +meta = {"name": "transcriptformer_mlflow"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable # noqa: E402 +from mlflow import train_classifier, classify # noqa: E402 +from unpack import unpack_directory # noqa: E402 + +print("====== TranscriptFormer (MLflow model) ======", flush=True) + +print("\n>>> Reading training data...", flush=True) +print(f"Training H5AD file: '{par['input_train']}'", flush=True) +input_train = ad.read_h5ad(par["input_train"]) +print(input_train, flush=True) + +if input_train.uns["dataset_organism"] != "homo_sapiens": + exit_non_applicable( + f"TranscriptFormer (MLflow) can only be used with human data " + f'(dataset_organism == "{input_train.uns["dataset_organism"]}")' + ) + +print("\n>>> Unpacking model...", flush=True) +model_dir, model_temp = unpack_directory(par["model"]) + +print("\n>>> Loading model...", flush=True) +model = mlflow.pyfunc.load_model(model_dir) +print(model, flush=True) + + +def process_transcriptformer_adata(adata): + """Add assay column if missing (TranscriptFormer requires it).""" + if "assay" not in adata.obs or adata.obs["assay"].isna().all(): + adata.obs["assay"] = "unknown" + + +# Train classifier on training data +classifier = train_classifier( + input_train, + model, + layers=["counts"], + obs=["assay"], + var={"feature_id": "ensembl_id"}, + process_adata=process_transcriptformer_adata, +) + +# Free memory - no longer need training data +del input_train + +print("\n>>> Reading test data...", flush=True) +print(f"Test H5AD file: '{par['input_test']}'", flush=True) +input_test = ad.read_h5ad(par["input_test"]) +print(input_test, flush=True) + +# Store metadata before classifying +dataset_id = input_test.uns["dataset_id"] +normalization_id = input_test.uns["normalization_id"] + +# Classify test data +predictions = classify( + input_test, + model, + classifier, + layers=["counts"], + obs=["assay"], + var={"feature_id": "ensembl_id"}, + process_adata=process_transcriptformer_adata, +) + +# Free memory - no longer need test data +del input_test + +print(predictions.value_counts(), flush=True) + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs={"label_pred": predictions}, + uns={ + "method_id": meta["name"], + "dataset_id": dataset_id, + "normalization_id": normalization_id, + }, +) +print(output, flush=True) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +print("\n>>> Cleaning up temporary files...", flush=True) +if model_temp is not None: + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/methods/uce_mlflow/config.vsh.yaml b/src/methods/uce_mlflow/config.vsh.yaml new file mode 100644 index 00000000..cd125645 --- /dev/null +++ b/src/methods/uce_mlflow/config.vsh.yaml @@ -0,0 +1,49 @@ +__merge__: ../../api/base_method.yaml + +name: uce_mlflow +label: UCE (MLflow) +summary: UCE offers a unified biological latent space that can represent any cell (MLflow model) +description: | + Universal Cell Embedding (UCE) is a single-cell foundation model that offers a + unified biological latent space that can represent any cell, regardless of + tissue or species. + + This version uses a pre-trained MLflow model. 
A kNN classifier is trained on + embeddings for the training data and used to predict labels for the test + data. +references: + doi: + - 10.1101/2023.11.28.568918 +links: + documentation: https://github.com/snap-stanford/UCE/blob/main/README.md + repository: https://github.com/snap-stanford/UCE + +info: + preferred_normalization: counts + +arguments: + - name: --model + type: file + description: | + An MLflow model URI for the UCE model. If it is a .zip or + .tar.gz file it will be extracted to a temporary directory. + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/exit_codes.py + - path: /src/utils/unpack.py + - path: /src/utils/mlflow.py + - path: requirements.txt + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1 + __merge__: /src/utils/mlflow_docker_setup.yaml + +runners: + - type: executable + - type: nextflow + directives: + label: [hightime, highmem, midcpu, gpu] diff --git a/src/methods/uce_mlflow/requirements.txt b/src/methods/uce_mlflow/requirements.txt new file mode 100644 index 00000000..b2f4227b --- /dev/null +++ b/src/methods/uce_mlflow/requirements.txt @@ -0,0 +1,366 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in -o /tmp/tmpg2ov1w_7/requirements_initial.txt +accelerate==0.34.2 + # via -r requirements.in +alembic==1.16.4 + # via mlflow +anndata==0.10.9 + # via + # -r requirements.in + # scanpy +annotated-types==0.7.0 + # via pydantic +antlr4-python3-runtime==4.9.3 + # via omegaconf +anyio==4.10.0 + # via starlette +array-api-compat==1.12.0 + # via anndata +blinker==1.9.0 + # via flask +cachetools==5.5.2 + # via + # google-auth + # mlflow-skinny +certifi==2025.8.3 + # via requests +charset-normalizer==3.4.3 + # via requests +click==8.2.1 + # via + # flask + # mlflow-skinny + # uvicorn +cloudpickle==3.1.1 + # via mlflow-skinny +contourpy==1.3.3 + # via matplotlib +cycler==0.12.1 + # via matplotlib +databricks-sdk==0.62.0 + # via mlflow-skinny +docker==7.1.0 + # via mlflow +fastapi==0.116.1 + # via mlflow-skinny +filelock==3.18.0 + # via + # huggingface-hub + # torch + # triton +flask==3.1.1 + # via mlflow +fonttools==4.59.0 + # via matplotlib +fsspec==2025.7.0 + # via + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via mlflow-skinny +google-auth==2.40.3 + # via databricks-sdk +graphene==3.4.3 + # via mlflow +graphql-core==3.2.6 + # via + # graphene + # graphql-relay +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.4 + # via sqlalchemy +gunicorn==23.0.0 + # via mlflow +h11==0.16.0 + # via uvicorn +h5py==3.14.0 + # via + # anndata + # scanpy +hf-xet==1.1.7 + # via huggingface-hub +huggingface-hub==0.34.4 + # via accelerate +idna==3.10 + # via + # anyio + # requests +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +itsdangerous==2.2.0 + # via flask +jinja2==3.1.6 + # via + # flask + # torch +joblib==1.5.1 + # via + # pynndescent + # scanpy + # scikit-learn +kiwisolver==1.4.9 + # via matplotlib +legacy-api-wrap==1.4.1 + # via scanpy +llvmlite==0.44.0 + # via + # numba + # pynndescent +mako==1.3.10 + # via alembic +markupsafe==3.0.2 + # via + # flask + # jinja2 + # mako + # werkzeug +matplotlib==3.10.5 + # via + # mlflow + # scanpy + # seaborn +mlflow==3.1.0 + # via -r requirements.in +mlflow-skinny==3.1.0 + # via mlflow +mpmath==1.3.0 + # via sympy +natsort==8.4.0 + # via + # anndata + # scanpy +networkx==3.5 + # via + # scanpy + # torch +numba==0.61.2 + # via + # pynndescent + # scanpy + # 
+    #   umap-learn
+numpy==1.26.4
+    # via
+    #   -r requirements.in
+    #   accelerate
+    #   anndata
+    #   contourpy
+    #   h5py
+    #   matplotlib
+    #   mlflow
+    #   numba
+    #   pandas
+    #   patsy
+    #   scanpy
+    #   scikit-learn
+    #   scipy
+    #   seaborn
+    #   statsmodels
+    #   umap-learn
+nvidia-cublas-cu12==12.1.3.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.1.105
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.1.105
+    # via torch
+nvidia-cuda-runtime-cu12==12.1.105
+    # via torch
+nvidia-cudnn-cu12==9.1.0.70
+    # via torch
+nvidia-cufft-cu12==11.0.2.54
+    # via torch
+nvidia-curand-cu12==10.3.2.106
+    # via torch
+nvidia-cusolver-cu12==11.4.5.107
+    # via torch
+nvidia-cusparse-cu12==12.1.0.106
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-nccl-cu12==2.20.5
+    # via torch
+nvidia-nvjitlink-cu12==12.9.86
+    # via
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+nvidia-nvtx-cu12==12.1.105
+    # via torch
+omegaconf==2.3.0
+    # via -r requirements.in
+opentelemetry-api==1.36.0
+    # via
+    #   mlflow-skinny
+    #   opentelemetry-sdk
+    #   opentelemetry-semantic-conventions
+opentelemetry-sdk==1.36.0
+    # via mlflow-skinny
+opentelemetry-semantic-conventions==0.57b0
+    # via opentelemetry-sdk
+packaging==25.0
+    # via
+    #   accelerate
+    #   anndata
+    #   gunicorn
+    #   huggingface-hub
+    #   matplotlib
+    #   mlflow-skinny
+    #   scanpy
+    #   statsmodels
+pandas==2.2.3
+    # via
+    #   -r requirements.in
+    #   anndata
+    #   mlflow
+    #   scanpy
+    #   seaborn
+    #   statsmodels
+patsy==1.0.1
+    # via
+    #   scanpy
+    #   statsmodels
+pillow==11.3.0
+    # via matplotlib
+protobuf==6.31.1
+    # via mlflow-skinny
+psutil==7.0.0
+    # via accelerate
+pyarrow==20.0.0
+    # via mlflow
+pyasn1==0.6.1
+    # via
+    #   pyasn1-modules
+    #   rsa
+pyasn1-modules==0.4.2
+    # via google-auth
+pydantic==2.11.7
+    # via
+    #   fastapi
+    #   mlflow-skinny
+pydantic-core==2.33.2
+    # via pydantic
+pynndescent==0.5.13
+    # via
+    #   scanpy
+    #   umap-learn
+pyparsing==3.2.3
+    # via matplotlib
+python-dateutil==2.9.0.post0
+    # via
+    #   graphene
+    #   matplotlib
+    #   pandas
+pytz==2025.2
+    # via pandas
+pyyaml==6.0.2
+    # via
+    #   accelerate
+    #   huggingface-hub
+    #   mlflow-skinny
+    #   omegaconf
+requests==2.32.4
+    # via
+    #   databricks-sdk
+    #   docker
+    #   huggingface-hub
+    #   mlflow-skinny
+rsa==4.9.1
+    # via google-auth
+safetensors==0.6.2
+    # via accelerate
+scanpy==1.10.2
+    # via -r requirements.in
+scikit-learn==1.7.1
+    # via
+    #   mlflow
+    #   pynndescent
+    #   scanpy
+    #   umap-learn
+scipy==1.14.1
+    # via
+    #   -r requirements.in
+    #   anndata
+    #   mlflow
+    #   pynndescent
+    #   scanpy
+    #   scikit-learn
+    #   statsmodels
+    #   umap-learn
+seaborn==0.13.2
+    # via scanpy
+session-info==1.0.1
+    # via scanpy
+six==1.17.0
+    # via python-dateutil
+smmap==5.0.2
+    # via gitdb
+sniffio==1.3.1
+    # via anyio
+sqlalchemy==2.0.43
+    # via
+    #   alembic
+    #   mlflow
+sqlparse==0.5.3
+    # via mlflow-skinny
+starlette==0.47.2
+    # via fastapi
+statsmodels==0.14.5
+    # via scanpy
+stdlib-list==0.11.1
+    # via session-info
+sympy==1.14.0
+    # via torch
+threadpoolctl==3.6.0
+    # via scikit-learn
+torch==2.4.1
+    # via
+    #   -r requirements.in
+    #   accelerate
+tqdm==4.66.5
+    # via
+    #   -r requirements.in
+    #   huggingface-hub
+    #   scanpy
+    #   umap-learn
+triton==3.0.0
+    # via torch
+typing-extensions==4.14.1
+    # via
+    #   alembic
+    #   anyio
+    #   fastapi
+    #   graphene
+    #   huggingface-hub
+    #   mlflow-skinny
+    #   opentelemetry-api
+    #   opentelemetry-sdk
+    #   opentelemetry-semantic-conventions
+    #   pydantic
+    #   pydantic-core
+    #   sqlalchemy
+    #   starlette
+    #   torch
+    #   typing-inspection
+typing-inspection==0.4.1
+    # via pydantic
+tzdata==2025.2
+    # via pandas
+umap-learn==0.5.9.post2
+    # via scanpy
+urllib3==1.26.6
+    # via
+    #   -r requirements.in
+    #   docker
+    #   requests
+uvicorn==0.35.0
+    # via mlflow-skinny
+werkzeug==3.1.3
+    # via flask
+zipp==3.23.0
+    # via importlib-metadata
diff --git a/src/methods/uce_mlflow/script.py b/src/methods/uce_mlflow/script.py
new file mode 100644
index 00000000..bb3716d6
--- /dev/null
+++ b/src/methods/uce_mlflow/script.py
@@ -0,0 +1,93 @@
+import sys
+import anndata as ad
+import mlflow
+
+## VIASH START
+par = {
+    "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad",
+    "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad",
+    "output": "output.h5ad",
+    "model": "resources_test/.../model",
+}
+meta = {"name": "uce_mlflow"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from exit_codes import exit_non_applicable  # noqa: E402
+from mlflow import train_classifier, classify  # noqa: E402
+from unpack import unpack_directory  # noqa: E402
+
+print("====== UCE (MLflow model) ======", flush=True)
+
+print("\n>>> Reading training data...", flush=True)
+print(f"Training H5AD file: '{par['input_train']}'", flush=True)
+input_train = ad.read_h5ad(par["input_train"])
+print(input_train, flush=True)
+
+if input_train.uns["dataset_organism"] != "homo_sapiens":
+    exit_non_applicable(
+        f"UCE (MLflow) can only be used with human data "
+        f'(dataset_organism == "{input_train.uns["dataset_organism"]}")'
+    )
+
+print("\n>>> Unpacking model...", flush=True)
+model_dir, model_temp = unpack_directory(par["model"])
+
+print("\n>>> Loading model...", flush=True)
+model = mlflow.pyfunc.load_model(model_dir)
+print(model, flush=True)
+
+# Train classifier on training data
+classifier = train_classifier(
+    input_train,
+    model,
+    layers=["counts"],
+    var={"feature_name": "feature_name"},
+)
+
+# Free memory - no longer need training data
+del input_train
+
+print("\n>>> Reading test data...", flush=True)
+print(f"Test H5AD file: '{par['input_test']}'", flush=True)
+input_test = ad.read_h5ad(par["input_test"])
+print(input_test, flush=True)
+
+# Store metadata before classifying
+dataset_id = input_test.uns["dataset_id"]
+normalization_id = input_test.uns["normalization_id"]
+
+# Classify test data
+predictions = classify(
+    input_test,
+    model,
+    classifier,
+    layers=["counts"],
+    var={"feature_name": "feature_name"},
+)
+
+# Free memory - no longer need test data
+del input_test
+
+print(predictions.value_counts(), flush=True)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs={"label_pred": predictions},
+    uns={
+        "method_id": meta["name"],
+        "dataset_id": dataset_id,
+        "normalization_id": normalization_id,
+    },
+)
+print(output, flush=True)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary files...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+
+print("\n>>> Done!", flush=True)
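Both of these scripts delegate the heavy lifting to src/utils/mlflow.py (next diff). The classification step itself is ordinary kNN label transfer in embedding space; a self-contained sketch with random vectors standing in for foundation-model embeddings:

import numpy as np
import pandas as pd
import sklearn.neighbors

rng = np.random.default_rng(0)
embedding_train = rng.normal(size=(100, 16))  # stand-in for model embeddings
labels_train = rng.choice(["B cell", "T cell", "NK cell"], size=100)
embedding_test = rng.normal(size=(20, 16))

# Same classifier family and default n_neighbors as the helpers below
classifier = sklearn.neighbors.KNeighborsClassifier(n_neighbors=5)
classifier.fit(embedding_train, labels_train.astype(str))
predictions = pd.Series(classifier.predict(embedding_test))
print(predictions.value_counts())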
diff --git a/src/utils/mlflow.py b/src/utils/mlflow.py
new file mode 100644
index 00000000..9aa95996
--- /dev/null
+++ b/src/utils/mlflow.py
@@ -0,0 +1,201 @@
+"""
+Common utilities for MLflow-based methods.
+"""
+import os
+import tempfile
+
+import anndata as ad
+import pandas as pd
+import sklearn.neighbors
+
+
+def create_temp_h5ad(
+    adata, layers=None, obs=None, var=None, obsm=None, varm=None, uns=None
+):
+    """
+    Create a temporary H5AD file with specified data from an AnnData object.
+
+    Args:
+        adata: Input AnnData object
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        obsm: List of obsm keys to include
+        varm: List of varm keys to include
+        uns: List of uns keys to include
+
+    Returns:
+        tuple: (h5ad_file, input_adata) where h5ad_file is the NamedTemporaryFile and
+            input_adata is the created AnnData object
+    """
+    # Extract X from layers or use X directly
+    if layers and len(layers) > 0:
+        X = adata.layers[layers[0]].copy()
+    else:
+        X = adata.X.copy()
+
+    # Create new AnnData
+    input_adata = ad.AnnData(X=X)
+
+    # Set var_names
+    input_adata.var_names = adata.var_names
+
+    # Add obs columns
+    if obs:
+        for obs_key in obs:
+            if obs_key in adata.obs:
+                input_adata.obs[obs_key] = adata.obs[obs_key].values
+
+    # Add var columns (with optional renaming)
+    if var:
+        for old_name, new_name in var.items():
+            if old_name in adata.var:
+                input_adata.var[new_name] = adata.var[old_name].values
+
+    # Add obsm
+    if obsm:
+        for obsm_key in obsm:
+            if obsm_key in adata.obsm:
+                input_adata.obsm[obsm_key] = adata.obsm[obsm_key].copy()
+
+    # Add varm
+    if varm:
+        for varm_key in varm:
+            if varm_key in adata.varm:
+                input_adata.varm[varm_key] = adata.varm[varm_key].copy()
+
+    # Add uns
+    if uns:
+        for uns_key in uns:
+            if uns_key in adata.uns:
+                input_adata.uns[uns_key] = adata.uns[uns_key]
+
+    # Write to temp file
+    h5ad_file = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False)
+    input_adata.write(h5ad_file.name)
+
+    return h5ad_file, input_adata
+
+
+def embed(adata, model, layers=None, obs=None, var=None, model_params=None, process_adata=None):
+    """
+    Embed data using an MLflow model.
+
+    Args:
+        adata: Input AnnData object to embed
+        model: Loaded MLflow model
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        model_params: Optional dict of parameters to pass to model.predict()
+        process_adata: Optional function to process input_adata before writing (e.g., to add defaults)
+
+    Returns:
+        np.ndarray: Embeddings for the input data
+    """
+    print("Writing temporary input H5AD file...", flush=True)
+    h5ad_file, input_adata = create_temp_h5ad(adata, layers=layers, obs=obs, var=var)
+
+    # Apply any post-processing to input_adata
+    if process_adata:
+        process_adata(input_adata)
+
+    print(f"Temporary H5AD file: '{h5ad_file.name}'", flush=True)
+    print(input_adata, flush=True)
+
+    # Re-write the file after processing
+    input_adata.write(h5ad_file.name)
+
+    print("Running model...", flush=True)
+    input_df = pd.DataFrame({"input_uri": [h5ad_file.name]})
+    if model_params:
+        embedding = model.predict(input_df, params=model_params)
+    else:
+        embedding = model.predict(input_df)
+
+    # Clean up
+    h5ad_file.close()
+    os.unlink(h5ad_file.name)
+
+    return embedding
+def train_classifier(
+    train_adata,
+    model,
+    layers=None,
+    obs=None,
+    var=None,
+    model_params=None,
+    process_adata=None,
+    n_neighbors=5,
+):
+    """
+    Embed training data and train a kNN classifier.
+
+    Args:
+        train_adata: Training AnnData object with labels
+        model: Loaded MLflow model
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        model_params: Optional dict of parameters to pass to model.predict()
+        process_adata: Optional function to process input_adata before writing (e.g., to add defaults)
+        n_neighbors: Number of neighbors for kNN classifier
+
+    Returns:
+        sklearn.neighbors.KNeighborsClassifier: Trained classifier
+    """
+    # Embed training data
+    print("\n>>> Embedding training data...", flush=True)
+    embedding_train = embed(
+        train_adata, model, layers=layers, obs=obs, var=var,
+        model_params=model_params, process_adata=process_adata
+    )
+
+    # Train kNN classifier
+    print("\n>>> Training kNN classifier...", flush=True)
+    classifier = sklearn.neighbors.KNeighborsClassifier(n_neighbors=n_neighbors)
+    classifier.fit(embedding_train, train_adata.obs["label"].astype(str))
+
+    return classifier
+
+
+def classify(
+    test_adata,
+    model,
+    classifier,
+    layers=None,
+    obs=None,
+    var=None,
+    model_params=None,
+    process_adata=None,
+):
+    """
+    Embed test data and classify using a trained classifier.
+
+    Args:
+        test_adata: Test AnnData object to predict
+        model: Loaded MLflow model
+        classifier: Trained sklearn classifier
+        layers: List of layer names to include (e.g., ["counts"])
+        obs: List of obs column names to include (e.g., ["batch"])
+        var: Dict mapping var column names to new names (e.g., {"feature_id": "ensembl_id"})
+        model_params: Optional dict of parameters to pass to model.predict()
+        process_adata: Optional function to process input_adata before writing (e.g., to add defaults)
+
+    Returns:
+        pd.Series: Predicted labels for test data
+    """
+    # Embed test data
+    print("\n>>> Embedding test data...", flush=True)
+    embedding_test = embed(
+        test_adata, model, layers=layers, obs=obs, var=var,
+        model_params=model_params, process_adata=process_adata
+    )
+
+    # Classify
+    print("\n>>> Classifying test data...", flush=True)
+    predictions = classifier.predict(embedding_test)
+
+    return pd.Series(predictions, index=test_adata.obs_names)
diff --git a/src/utils/mlflow_docker_setup.yaml b/src/utils/mlflow_docker_setup.yaml
new file mode 100644
index 00000000..aa03e9a7
--- /dev/null
+++ b/src/utils/mlflow_docker_setup.yaml
@@ -0,0 +1,14 @@
+- type: docker
+  add: https://astral.sh/uv/0.7.19/install.sh /uv-installer.sh
+  run: sh /uv-installer.sh && rm /uv-installer.sh
+  env: PATH="/root/.local/bin/:$PATH"
+- type: docker
+  run: uv venv --python 3.11 /opt/venv
+- type: docker
+  env:
+    - VIRTUAL_ENV=/opt/venv
+    - PATH="/opt/venv/bin:$PATH"
+  add: requirements.txt /requirements.txt
+  run: uv pip install -r /requirements.txt && uv pip install mlflow==3.1.0
+- type: docker
+  run: uv pip install git+https://github.com/openproblems-bio/core#subdirectory=packages/python/openproblems
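Taken together, a hedged end-to-end sketch of how a script drives the helpers in src/utils/mlflow.py (the model path and H5AD files are placeholders, and `mlflow_utils` is a hypothetical import name standing in for however the helper module is made importable in your environment):

import anndata as ad
import mlflow  # the MLflow package itself, for loading the pyfunc model

import mlflow_utils  # hypothetical name for src/utils/mlflow.py

model = mlflow.pyfunc.load_model("/path/to/unpacked/model")  # placeholder path
train = ad.read_h5ad("train.h5ad")  # placeholder file with a "counts" layer and obs["label"]
test = ad.read_h5ad("test.h5ad")    # placeholder file with a "counts" layer

classifier = mlflow_utils.train_classifier(train, model, layers=["counts"], n_neighbors=5)
predictions = mlflow_utils.classify(test, model, classifier, layers=["counts"])
print(predictions.value_counts())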
diff --git a/src/utils/unpack.py b/src/utils/unpack.py
new file mode 100644
index 00000000..443aa39f
--- /dev/null
+++ b/src/utils/unpack.py
@@ -0,0 +1,43 @@
+import os
+import tarfile
+import tempfile
+import zipfile
+
+def unpack_directory(directory):
+    """
+    Unpack a directory to a temporary location (if needed)
+
+    Args:
+        directory (str): Path to a directory, .zip, or .tar.gz file.
+
+    Returns:
+        tuple: (unpacked_directory (str), temp_directory (TemporaryDirectory or None))
+            unpacked_directory: Path to the unpacked directory.
+            temp_directory: TemporaryDirectory object if a temp dir was created, else None.
+    """
+    print(f"Unpacking directory: '{directory}'", flush=True)
+
+    if os.path.isdir(directory):
+        print(f"Returning provided directory: '{directory}'", flush=True)
+        temp_directory = None
+        unpacked_directory = directory
+    else:
+        temp_directory = tempfile.TemporaryDirectory()
+        unpacked_directory = temp_directory.name
+
+        if zipfile.is_zipfile(directory):
+            print("Extracting .zip...", flush=True)
+            with zipfile.ZipFile(directory, "r") as zip_file:
+                zip_file.extractall(unpacked_directory)
+        elif tarfile.is_tarfile(directory) and directory.endswith(".tar.gz"):
+            print("Extracting .tar.gz...", flush=True)
+            with tarfile.open(directory, "r:gz") as tar_file:
+                tar_file.extractall(unpacked_directory)
+            unpacked_directory = os.path.join(unpacked_directory, os.listdir(unpacked_directory)[0])
+        else:
+            raise ValueError(
+                "The 'directory' argument should be a directory, a .zip file or a .tar.gz file"
+            )
+        print(f"Extracted to '{unpacked_directory}'", flush=True)
+
+    return (unpacked_directory, temp_directory)
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 3680c6ad..abd8eb49 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -80,24 +80,30 @@ dependencies:
   - name: control_methods/majority_vote
   - name: control_methods/random_labels
   - name: control_methods/true_labels
-  - name: methods/geneformer
+  - name: methods/cellmapper_linear
+  - name: methods/cellmapper_scvi
+  - name: methods/geneformer_mlflow
   - name: methods/knn
   - name: methods/logistic_regression
   - name: methods/mlp
   - name: methods/naive_bayes
   - name: methods/scanvi
   - name: methods/scanvi_scarches
+  - name: methods/scgpt_mlflow
   - name: methods/scgpt_finetuned
   - name: methods/scgpt_zeroshot
   - name: methods/scimilarity
   - name: methods/scimilarity_knn
   - name: methods/scprint
+  - name: methods/scvi_mlflow
   - name: methods/seurat_transferdata
   - name: methods/singler
-  - name: methods/xgboost
+  - name: methods/transcriptformer_mlflow
   - name: methods/uce
-  - name: methods/cellmapper_linear
-  - name: methods/cellmapper_scvi
+  - name: methods/uce_mlflow
+  - name: methods/xgboost
+  # always fails:
+  # - name: methods/geneformer
   - name: metrics/accuracy
   - name: metrics/f1
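Returning to src/utils/unpack.py above: a quick round-trip check of its .zip branch (the archive name and contents are illustrative, and the import assumes src/utils is on sys.path):

import os
import tempfile
import zipfile

from unpack import unpack_directory

# Build a toy .zip standing in for a packed MLflow model
workdir = tempfile.mkdtemp()
zip_path = os.path.join(workdir, "model.zip")
with zipfile.ZipFile(zip_path, "w") as zip_file:
    zip_file.writestr("MLmodel", "flavors: {}\n")

unpacked_directory, temp_directory = unpack_directory(zip_path)
print(os.listdir(unpacked_directory))  # ["MLmodel"]
if temp_directory is not None:
    temp_directory.cleanup()  # mirrors the cleanup step in the scripts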
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index 26bc2733..52398574 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -11,7 +11,10 @@ methods = [
   majority_vote,
   random_labels,
   true_labels,
-  geneformer,
+  // geneformer,
+  geneformer_mlflow.run(
+    args: [model: file("s3://openproblems-work/cache/geneformer-mlflow-model.zip")]
+  ),
   knn,
   logistic_regression,
   mlp,
@@ -20,10 +23,12 @@
   scanvi_scarches,
   cellmapper_linear,
   cellmapper_scvi,
-
   scgpt_finetuned.run(
     args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
   ),
+  scgpt_mlflow.run(
+    args: [model: file("s3://openproblems-work/cache/scgpt-mlflow-model.zip")]
+  ),
   scgpt_zeroshot.run(
     args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
   ),
@@ -34,11 +39,20 @@
     args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
   ),
   scprint,
+  scvi_mlflow.run(
+    args: [model: file("s3://openproblems-work/cache/scvi-mlflow-model.zip")]
+  ),
   seurat_transferdata,
   singler,
+  transcriptformer_mlflow.run(
+    args: [model: file("s3://openproblems-work/cache/transcriptformer-mlflow-model.zip")]
+  ),
   uce.run(
     args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
   ),
+  uce_mlflow.run(
+    args: [model: file("s3://openproblems-work/cache/uce-mlflow-model.zip")]
+  ),
   xgboost
 ]
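For reference, each *_mlflow method writes a deliberately minimal prediction AnnData, as in the scripts above; a sketch of its shape with toy values (the ids and labels here are placeholders, not benchmark output):

import anndata as ad
import pandas as pd

output = ad.AnnData(
    obs=pd.DataFrame({"label_pred": ["B cell", "T cell"]}, index=["cell_0", "cell_1"]),
    uns={
        "method_id": "uce_mlflow",
        "dataset_id": "example_dataset",      # toy value
        "normalization_id": "example_norm",   # toy value
    },
)
output.write_h5ad("output.h5ad", compression="gzip")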