6 changes: 3 additions & 3 deletions README.md
@@ -12,12 +12,12 @@

Please refer to the [documentation][link-docs]. In particular, the

-- [API documentation][link-api]
+- [API documentation][link-api]

and the tutorials:

-- [Paired integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/paired_integration_cite-seq.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/paired_integration_cite-seq.ipynb)
-- [Trimodal integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/trimodal_integration.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/trimodal_integration.ipynb)
+- [Paired integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/paired_integration_cite-seq.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/paired_integration_cite-seq.ipynb)
+- [Trimodal integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/trimodal_integration.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/trimodal_integration.ipynb)

## Installation

18 changes: 9 additions & 9 deletions docs/contributing.md
@@ -99,11 +99,11 @@ Specify `vX.X.X` as a tag name and create a release. For more information, see [

Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:

-- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
-- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension).
-- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
-- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
-- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
+- the [myst][] extension allows you to write documentation in markdown/Markedly Structured Text
+- [Numpy-style docstrings][numpydoc] (through the [napoleon][numpydoc-napoleon] extension)
+- Jupyter notebooks as tutorials through [myst-nb][] (see [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
+- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
+- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)

See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
@@ -120,10 +120,10 @@ repository.

#### Hints

-- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
-  if you do so can sphinx automatically create a link to the external documentation.
-- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
-  the `nitpick_ignore` list in `docs/conf.py`
+- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
+  if you do so can sphinx automatically create a link to the external documentation.
+- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
+  the `nitpick_ignore` list in `docs/conf.py`.
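
Both hints touch `docs/conf.py`. As a minimal sketch of what such entries look like (the package names and the ignored reference below are illustrative, not taken from this repository's actual config):

```python
# docs/conf.py -- illustrative excerpt, not this repository's actual config.

# One intersphinx entry per external package you cross-reference; with this
# in place, a reference such as :class:`anndata.AnnData` resolves to a live
# link into that package's hosted documentation.
intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "anndata": ("https://anndata.readthedocs.io/en/stable/", None),
    "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
}

# Suppress nitpicky build failures for references that cannot be resolved
# and are outside your control (hypothetical example entry).
nitpick_ignore = [
    ("py:class", "some_package.SomeUnresolvableClass"),
]
```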

#### Building the docs locally

3 changes: 2 additions & 1 deletion src/multigrate/model/_multivae.py
@@ -291,6 +291,7 @@ def train(
        weight_decay: float = 1e-3,
        eps: float = 1e-08,
        early_stopping: bool = True,
+        early_stopping_patience: int = 50,
        save_best: bool = True,
        check_val_every_n_epoch: int | None = None,
        n_epochs_kl_warmup: int | None = None,
@@ -419,7 +420,7 @@ def train(
            early_stopping=early_stopping,
            check_val_every_n_epoch=check_val_every_n_epoch,
            early_stopping_monitor="reconstruction_loss_validation",
-            early_stopping_patience=50,
+            early_stopping_patience=early_stopping_patience,
            enable_checkpointing=True,
            **kwargs,
        )
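
For context: the patience that this hunk previously hard-coded to 50 now flows through from the `train()` signature. A hypothetical call, assuming a `MultiVAE` model already set up as in the tutorials (only the two early-stopping arguments are taken from this diff):

```python
# Sketch of the new knob in use; `model` is assumed to be a set-up
# multigrate MultiVAE instance.
model.train(
    early_stopping=True,          # monitors reconstruction_loss_validation
    early_stopping_patience=100,  # was fixed at 50 before this change
)
```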
15 changes: 9 additions & 6 deletions src/multigrate/module/_multivae_torch.py
@@ -181,11 +181,13 @@ def __init__(

        # assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
        # TODO: add check for multiple nb/zinb losses given
-        self.theta = None
+        self.theta = []
+        j = []
        for i, loss in enumerate(losses):
            if loss in ["nb", "zinb"]:
-                self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups))
-                break
+                self.theta.append(torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups)))
+            else:
+                self.theta.append([])

        # modality encoders
        cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0
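
One caveat worth noting about the new pattern: parameters kept in a plain Python list are not registered with the module, so they do not follow `.to(device)` or `state_dict()` automatically, which is presumably why the call sites below move each `theta` to the device explicitly. A sketch of the registered alternative using standard `torch.nn.ParameterList` (variable values are illustrative; the zero-sized placeholder is this sketch's own convention, not the file's):

```python
import torch

# Illustrative setup mirroring __init__ above.
input_dims = [2000, 134]   # e.g. RNA genes, protein features
losses = ["nb", "mse"]     # per-modality reconstruction losses
num_groups = 3             # number of dispersion groups

# A ParameterList registers each per-modality dispersion parameter with
# the module, so device moves, checkpointing, and the optimizer see them.
theta = torch.nn.ParameterList(
    torch.nn.Parameter(torch.randn(dim, num_groups))
    if loss in ("nb", "zinb")
    # Zero-sized placeholder keeps modality indices aligned for
    # modalities that have no dispersion parameter.
    else torch.nn.Parameter(torch.empty(0), requires_grad=False)
    for dim, loss in zip(input_dims, losses)
)
```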
@@ -307,6 +309,7 @@ def _h_to_x(self, h, i):
        return x

    def _product_of_experts(self, mus, logvars, masks):
+        # print(mus, logvars, masks)
        vars = torch.exp(logvars)
        masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
        mus_joint = torch.sum(mus * masks / vars, dim=1)
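
The hunk cuts off here, so for readers unfamiliar with the construction: for Gaussian experts, the product density is again Gaussian, with joint precision equal to the sum of the experts' precisions and joint mean equal to the precision-weighted mean. A self-contained sketch of that computation under the same masking convention (the method's remaining lines are collapsed in this diff, so the closing steps below are the standard reconstruction, not necessarily the file's exact code):

```python
import torch

def product_of_experts(mus, logvars, masks):
    """Masked Gaussian product of experts (sketch).

    mus, logvars: (batch, n_experts, latent_dim)
    masks:        (batch, n_experts) float tensor, 1.0 where a
                  modality is observed and 0.0 where it is missing.
    """
    vars_ = torch.exp(logvars)
    masks = masks.unsqueeze(-1).expand_as(vars_)
    # Joint precision: sum of per-expert precisions over observed experts.
    precision_joint = torch.sum(masks / vars_, dim=1)
    vars_joint = 1.0 / precision_joint
    # Joint mean: precision-weighted average of the expert means.
    mus_joint = torch.sum(mus * masks / vars_, dim=1) * vars_joint
    return mus_joint, torch.log(vars_joint)
```

Many PoE implementations also include a standard-normal prior expert, which simply adds 1 to the joint precision (and guards against the all-masked case); whether this file does so is not visible in the hunk.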
Expand Down Expand Up @@ -657,7 +660,7 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
                dec_mean = r
                size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                dispersion = torch.exp(dispersion)
                nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1)
                nb_loss = loss_coefs[str(i)] * nb_loss
                dec_mean, dec_dropout = r
                dec_mean = dec_mean.squeeze()
                dec_dropout = dec_dropout.squeeze()
-                size_factor_view = size_factor.unsqueeze(1).expand(dec_mean.size(0), dec_mean.size(1))
+                size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                dispersion = torch.exp(dispersion)
                zinb_loss = torch.sum(
                    ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x),
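
To make the changed lookup concrete: each `self.theta[i]` has shape `(n_features, n_groups)`, so its transpose indexed by the per-cell group label yields a `(batch, n_features)` dispersion matrix aligned with the decoded means. A standalone sketch with illustrative shapes (`NegativeBinomial` here is the scvi-tools distribution this module already uses; the stand-in counts are fabricated for the example only):

```python
import torch
from scvi.distributions import NegativeBinomial

batch, n_features, n_groups = 4, 2000, 3

theta = torch.randn(n_features, n_groups)      # one dispersion column per group
group = torch.tensor([0, 2, 1, 0])             # per-cell group label
dec_mean = torch.rand(batch, n_features) * 10  # decoded NB means

# theta.T is (n_groups, n_features); indexing by `group` selects the right
# row per cell, broadcasting to (batch, n_features). exp() keeps it positive.
dispersion = torch.exp(theta.T[group])

# Per-cell NB log-likelihood, summed over features as in the loss above.
counts = torch.poisson(dec_mean)  # stand-in data for the sketch
nb_ll = NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(counts).sum(dim=-1)
```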