Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10,657 changes: 10,657 additions & 0 deletions docs/tutorials/3-finetune.html

Large diffs are not rendered by default.

2,133 changes: 1,842 additions & 291 deletions docs/tutorials/3-finetune.ipynb

Large diffs are not rendered by default.

Binary file added docs/tutorials/data/data.h5ad
Binary file not shown.
5 changes: 4 additions & 1 deletion src/decima/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import click

from decima.cli.predict_genes import cli_predict_genes
from decima.cli.download import cli_download
from decima.cli.download import cli_cache, cli_download_weights, cli_download_metadata, cli_download
from decima.cli.attributions import (
cli_attributions,
cli_attributions_plot,
Expand Down Expand Up @@ -40,6 +40,9 @@ def main():


main.add_command(cli_predict_genes, name="predict-genes")
main.add_command(cli_cache, name="cache")
main.add_command(cli_download_weights, name="download-weights")
main.add_command(cli_download_metadata, name="download-metadata")
main.add_command(cli_download, name="download")
main.add_command(cli_query_cell, name="query-cell")
main.add_command(cli_attributions, name="attributions")
Expand Down
58 changes: 30 additions & 28 deletions src/decima/cli/attributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import click
from decima.cli.callback import parse_genes, parse_model, parse_attributions
from decima.interpret.attributions import (
plot_attributions,
predict_save_attributions,
Expand Down Expand Up @@ -47,6 +48,7 @@
type=str,
required=False,
default=0,
callback=parse_model,
help="Model to use for attribution analysis either replicate number or path to the model.",
show_default=True,
)
Expand Down Expand Up @@ -87,6 +89,7 @@
type=str,
required=False,
help="Comma-separated list of gene symbols or IDs to analyze.",
callback=parse_genes,
show_default=True,
)
@click.option(
Expand Down Expand Up @@ -149,16 +152,6 @@ def cli_attributions_predict(

└── {output_prefix}.attributions.bigwig # Genome browser track of attribution as bigwig file obtained with averaging the attribution scores across the genes for genomics coordinates.
"""

if model in ["0", "1", "2", "3"]: # replicate index
model = int(model)

if isinstance(device, str) and device.isdigit():
device = int(device)

if genes is not None:
genes = genes.split(",")

predict_save_attributions(
output_prefix=output_prefix,
tasks=tasks,
Expand All @@ -181,7 +174,14 @@ def cli_attributions_predict(

@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files")
@click.option("-g", "--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.")
@click.option(
"-g",
"--genes",
type=str,
required=False,
callback=parse_genes,
help="Comma-separated list of gene symbols or IDs to analyze.",
)
@click.option("--seqs", type=str, required=False, help="Path to a file containing sequences to analyze")
@click.option(
"--tasks",
Expand All @@ -197,6 +197,7 @@ def cli_attributions_predict(
type=str,
required=False,
default="ensemble",
callback=parse_model,
help="Model to use for attribution analysis either replicate number or path to the model.",
show_default=True,
)
Expand Down Expand Up @@ -288,12 +289,6 @@ def cli_attributions(

>>> decima attributions -o output_prefix --seqs tests/data/seqs.fasta --tasks "cell_type == 'classical monocyte'" --device 0
"""
if model in ["0", "1", "2", "3"]: # replicate index
model = int(model)

if isinstance(genes, str):
genes = genes.split(",")

predict_attributions_seqlet_calling(
output_prefix=output_prefix,
genes=genes,
Expand Down Expand Up @@ -321,7 +316,9 @@ def cli_attributions(

@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files")
@click.option("--attributions", type=str, required=True, help="Path to the attribution files")
@click.option(
"--attributions", type=str, callback=parse_attributions, required=True, help="Path to the attribution files"
)
@click.option(
"--tasks",
type=str,
Expand All @@ -333,7 +330,13 @@ def cli_attributions(
)
@click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.")
@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
@click.option("--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.")
@click.option(
"--genes",
type=str,
required=False,
callback=parse_genes,
help="Comma-separated list of gene symbols or IDs to analyze.",
)
@click.option(
"--top-n-markers",
type=int,
Expand Down Expand Up @@ -393,12 +396,6 @@ def cli_attributions_recursive_seqlet_calling(

>>> decima attributions-recursive-seqlet-calling --attributions attributions_0.h5,attributions_1.h5 -o output_prefix --genes SPI1
"""
if isinstance(attributions, str):
attributions = attributions.split(",")

if genes is not None:
genes = genes.split(",")

recursive_seqlet_calling(
output_prefix=output_prefix,
attributions=attributions,
Expand All @@ -422,7 +419,14 @@ def cli_attributions_recursive_seqlet_calling(

@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files")
@click.option("-g", "--genes", type=str, required=False, help="Comma-separated list of gene symbols or IDs to analyze.")
@click.option(
"-g",
"--genes",
type=str,
required=False,
callback=parse_genes,
help="Comma-separated list of gene symbols or IDs to analyze.",
)
@click.option("--metadata", type=click.Path(exists=True), default=None, help="Path to the metadata anndata file.")
@click.option("--tss-distance", type=int, required=False, default=None, help="TSS distance for attribution analysis.")
@click.option("--seqlogo-window", type=int, default=50, help="Window size for sequence logo plots")
Expand All @@ -449,8 +453,6 @@ def cli_attributions_plot(

>>> decima attributions-plot -o output_prefix -g SPI1
"""
genes = genes.split(",")

plot_attributions(
output_prefix=output_prefix,
genes=genes,
Expand Down
57 changes: 57 additions & 0 deletions src/decima/cli/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import click
from pathlib import Path


def parse_model(ctx, param, value):
    """Click callback that normalizes the value of a ``--model`` option.

    Args:
        ctx: Click context (not used).
        param: Click parameter being processed (not used).
        value: Raw option value.

    Returns:
        The value unchanged when it is ``None``, not a string, or the literal
        ``"ensemble"``; an ``int`` for the replicate strings ``"0"`` to ``"3"``;
        otherwise the list obtained by splitting the string on commas.

    Raises:
        click.ClickException: If one of the comma-separated paths does not exist.
    """
    if not isinstance(value, str):
        # Covers both ``None`` and already-converted values.
        return value
    if value == "ensemble":
        return value
    if value in ("0", "1", "2", "3"):
        return int(value)

    # Anything else is treated as one or more comma-separated checkpoint paths.
    model_paths = value.split(",")
    missing = next((candidate for candidate in model_paths if not Path(candidate).exists()), None)
    if missing is not None:
        raise click.ClickException(
            f"Model path {missing} does not exist. Check if the path is correct and the file exists."
        )
    return model_paths


def parse_genes(ctx, param, value):
    """Click callback that turns a comma-separated ``--genes`` string into a list.

    Whitespace around each entry is stripped and empty entries (for example
    from a trailing comma) are dropped, so ``"SPI1, CD68,"`` yields
    ``["SPI1", "CD68"]``.

    Args:
        ctx: Click context (not used).
        param: Click parameter being processed (not used).
        value: Raw option value; ``None`` when the option was not given.

    Returns:
        ``None`` when ``value`` is ``None``, otherwise a list of gene strings.

    Raises:
        ValueError: If ``value`` is neither ``None`` nor a string.
    """
    if value is None:
        return None
    if not isinstance(value, str):
        raise ValueError(f"Invalid genes: {value}. Genes should be a comma-separated list of gene names or None.")

    # A bare split would keep " CD68" or "" as gene names, which can never match.
    stripped = (token.strip() for token in value.split(","))
    return [gene for gene in stripped if gene]


def validate_save_replicates(ctx, param, value):
    """Click callback that checks ``--save-replicates`` against the chosen model.

    The flag is accepted when the already-parsed ``model`` parameter is
    ``"ensemble"`` or a list with more than one entry.

    Args:
        ctx: Click context; ``ctx.params`` is read for the parsed ``model``.
        param: Click parameter being processed (not used).
        value: Value of the flag.

    Returns:
        ``value`` unchanged when it is falsy or the check passes.

    Raises:
        ValueError: If the flag is set and the parsed model is neither
            ``"ensemble"`` nor a list of several models.
    """
    if not value:
        return value

    # Click fills ``ctx.params`` in processing order, and options given on the
    # command line are processed before those left at their default. ``model``
    # can therefore be absent here; the flag is then passed through unchecked
    # instead of failing with a KeyError.
    if "model" not in ctx.params:
        return value

    model = ctx.params["model"]
    is_multi_model = isinstance(model, list) and len(model) > 1
    if model == "ensemble" or is_multi_model:
        return value

    raise ValueError(
        "`--save-replicates` is only supported for ensemble models. Pass `ensemble` or list of models as the model argument."
    )


def parse_attributions(ctx, param, value):
    """Click callback that splits and checks a comma-separated ``--attributions`` value.

    Args:
        ctx: Click context (not used).
        param: Click parameter being processed (not used).
        value: Comma-separated attribution file paths, or ``None`` when the
            option was not given.

    Returns:
        ``None`` when ``value`` is ``None``, otherwise the list of paths.

    Raises:
        click.ClickException: If a path does not exist or does not end in ``.h5``.
    """
    # Return None for a missing value, as the other callbacks in this module do,
    # instead of failing on ``None.split``.
    if value is None:
        return None

    attribution_paths = value.split(",")
    for attribution_path in attribution_paths:
        if not Path(attribution_path).exists():
            raise click.ClickException(
                f"Attribution path {attribution_path} does not exist. Check if the path is correct and the file exists."
            )
        if not attribution_path.endswith(".h5"):
            raise click.ClickException(
                f"Attribution path {attribution_path} is not a h5 file. Check if the path is correct and the file is a h5 file."
            )
    return attribution_paths
52 changes: 47 additions & 5 deletions src/decima/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,56 @@
`decima download` is the main command for downloading the required data and model weights.

It includes subcommands for:
- Downloading the required data and model weights. `download`
- Caching the required data and model weights. `cache`
"""

import click
from decima.hub.download import download_decima_data
from decima.cli.callback import parse_model
from decima.hub.download import (
cache_decima_data,
download_decima_weights,
download_decima_metadata,
download_decima,
)


@click.command()
def cli_download():
"""Download all required data and model weights."""
download_decima_data()
def cli_cache():
"""Cache all required data and model weights."""
cache_decima_data()


@click.command()
@click.option(
    "--model",
    type=str,
    default="ensemble",
    help="Model to download. Default: ensemble.",
    callback=parse_model,
)
@click.option(
    "--download-dir",
    type=click.Path(),
    default=".",
    help="Directory to download the model weights. Default: current directory.",
)
def cli_download_weights(model, download_dir):
    """Download pre-trained Decima model weights."""
    # The docstring above is the text Click shows for ``--help``; keep it verbatim.
    # ``model`` arrives already normalized by the ``parse_model`` callback.
    target_dir = str(download_dir)
    download_decima_weights(model, target_dir)


@click.command()
@click.option(
    "--download-dir",
    default=".",
    type=click.Path(),
    help="Directory to download the metadata. Default: current directory.",
)
def cli_download_metadata(download_dir):
    """Download pre-trained Decima metadata."""
    # The docstring above is the text Click shows for ``--help``; keep it verbatim.
    target_dir = str(download_dir)
    download_decima_metadata(target_dir)


@click.command()
@click.option(
    "--download-dir",
    type=click.Path(),
    default=".",
    help="Directory to download the data. Default: current directory.",
)
def cli_download(download_dir):
    """Download model weights and metadata for Decima."""
    # The docstring above is the text Click shows for ``--help``; keep it verbatim.
    target_dir = str(download_dir)
    download_decima(target_dir)
26 changes: 22 additions & 4 deletions src/decima/cli/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,27 @@ def cli_finetune(
num_workers,
seed,
):
"""Finetune the Decima model."""
"""Finetune the Decima model.

Args:
name: Name of the run for logging and checkpointing
model: Model path or replication number (0-3)
device: Device to use for training. Default: "0"
matrix_file: Path to the matrix file containing training data
h5_file: Path to the H5 file containing sequences
outdir: Output directory path to save model checkpoints
learning_rate: Learning rate for training. Default: 0.001
loss_total_weight: Total weight parameter for the loss function
gradient_accumulation: Number of gradient accumulation steps
batch_size: Batch size for training. Default: 1
max_seq_shift: Maximum sequence shift for data augmentation. Default: 5000
gradient_clipping: Gradient clipping value. Default: 0.0 (disabled)
save_top_k: Number of best checkpoints to save. Default: 1
epochs: Number of training epochs. Default: 1
logger: Logger type to use. Default: "wandb"
num_workers: Number of data loading workers. Default: 16
seed: Random seed for reproducibility. Default: 0
"""
train_logger = logger
logger = logging.getLogger("decima")
logger.info(f"Data paths: matrix_file={matrix_file}, h5_file={h5_file}")
Expand All @@ -86,7 +106,6 @@ def cli_finetune(
device = int(device)

train_params = {
"name": name,
"batch_size": batch_size,
"num_workers": num_workers,
"devices": device,
Expand All @@ -97,7 +116,6 @@ def cli_finetune(
"total_weight": loss_total_weight,
"accumulate_grad_batches": gradient_accumulation,
"loss": "poisson_multinomial",
# "pairs": ad.uns["disease_pairs"].values,
"clip": gradient_clipping,
"save_top_k": save_top_k,
"pin_memory": True,
Expand All @@ -111,7 +129,7 @@ def cli_finetune(
logger.info(f"model_params: {model_params}")

logger.info("Initializing model")
model = LightningModel(model_params=model_params, train_params=train_params)
model = LightningModel(name=name, model_params=model_params, train_params=train_params)

if train_logger == "wandb":
logger.info("Connecting to wandb.")
Expand Down
Loading