From 1b8c8d1f8e3e0210fdf1addcadf6455dd2b019b5 Mon Sep 17 00:00:00 2001
From: Mariia Minaeva
Date: Sun, 27 Jul 2025 17:25:21 +0200
Subject: [PATCH 1/2] enabled several nb and poisson losses

---
 src/multigrate/model/_multivae.py        |  3 ++-
 src/multigrate/module/_multivae_torch.py | 15 +++++++++------
 2 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py
index 58a136c..d5b49ca 100644
--- a/src/multigrate/model/_multivae.py
+++ b/src/multigrate/model/_multivae.py
@@ -291,6 +291,7 @@ def train(
         weight_decay: float = 1e-3,
         eps: float = 1e-08,
         early_stopping: bool = True,
+        early_stopping_patience = 50,
         save_best: bool = True,
         check_val_every_n_epoch: int | None = None,
         n_epochs_kl_warmup: int | None = None,
@@ -419,7 +420,7 @@ def train(
             early_stopping=early_stopping,
             check_val_every_n_epoch=check_val_every_n_epoch,
             early_stopping_monitor="reconstruction_loss_validation",
-            early_stopping_patience=50,
+            early_stopping_patience=early_stopping_patience,
             enable_checkpointing=True,
             **kwargs,
         )
diff --git a/src/multigrate/module/_multivae_torch.py b/src/multigrate/module/_multivae_torch.py
index c91a7ee..e1d888d 100644
--- a/src/multigrate/module/_multivae_torch.py
+++ b/src/multigrate/module/_multivae_torch.py
@@ -181,11 +181,13 @@ def __init__(
 
         # assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
         # TODO: add check for multiple nb/zinb losses given
-        self.theta = None
+        self.theta = []
+        j = []
         for i, loss in enumerate(losses):
             if loss in ["nb", "zinb"]:
-                self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups))
-                break
+                self.theta.append(torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups)))
+            else:
+                self.theta.append([])
 
         # modality encoders
         cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0
@@ -307,6 +309,7 @@ def _h_to_x(self, h, i):
         return x
 
     def _product_of_experts(self, mus, logvars, masks):
+        #print(mus, logvars, masks)
         vars = torch.exp(logvars)
         masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
         mus_joint = torch.sum(mus * masks / vars, dim=1)
@@ -657,7 +660,7 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
                 dec_mean = r
                 size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                 dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                 dispersion = torch.exp(dispersion)
                 nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1)
                 nb_loss = loss_coefs[str(i)] * nb_loss
@@ -666,9 +669,9 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
                 dec_mean, dec_dropout = r
                 dec_mean = dec_mean.squeeze()
                 dec_dropout = dec_dropout.squeeze()
-                size_factor_view = size_factor.unsqueeze(1).expand(dec_mean.size(0), dec_mean.size(1))
+                size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                 dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                 dispersion = torch.exp(dispersion)
                 zinb_loss = torch.sum(
                     ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x),

From 02e83324a80c5188d3e4cfe171dc01d3d48fb5ad Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 27 Jul 2025 15:33:44 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 README.md                                |  6 +++---
 docs/contributing.md                     | 18 +++++++++---------
 src/multigrate/model/_multivae.py        |  2 +-
 src/multigrate/module/_multivae_torch.py |  2 +-
 4 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index 441d48c..5962a36 100644
--- a/README.md
+++ b/README.md
@@ -12,12 +12,12 @@
 
 Please refer to the [documentation][link-docs]. In particular, the
 
-- [API documentation][link-api] 
+- [API documentation][link-api]
 
 and the tutorials:
 
-- [Paired integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/paired_integration_cite-seq.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/paired_integration_cite-seq.ipynb) 
-- [Trimodal integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/trimodal_integration.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/trimodal_integration.ipynb) 
+- [Paired integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/paired_integration_cite-seq.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/paired_integration_cite-seq.ipynb)
+- [Trimodal integration and query-to-reference mapping](https://multigrate.readthedocs.io/en/latest/notebooks/trimodal_integration.html) [![Open In Colab][badge-colab]](https://colab.research.google.com/github/theislab/multigrate/blob/main/docs/notebooks/trimodal_integration.ipynb)
 
 ## Installation
 
diff --git a/docs/contributing.md b/docs/contributing.md
index c8c6c49..7f31443 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -99,11 +99,11 @@ Specify `vX.X.X` as a tag name and create a release. For more information, see [
 
 Please write documentation for new or changed features and use-cases. This project uses [sphinx][] with the following features:
 
-- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text 
-- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension). 
-- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) 
-- [Sphinx autodoc typehints][], to automatically reference annotated input and output types 
-- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) 
+- the [myst][] extension allows to write documentation in markdown/Markedly Structured Text
+- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension).
+- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
+- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
+- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
 
 See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information on how to write documentation.
 
@@ -120,10 +120,10 @@ repository.
 
 #### Hints
 
-- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only 
-  if you do so can sphinx automatically create a link to the external documentation. 
-- If building the documentation fails because of a missing link that is outside your control, you can add an entry to 
-  the `nitpick_ignore` list in `docs/conf.py` 
+- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. Only
+  if you do so can sphinx automatically create a link to the external documentation.
+- If building the documentation fails because of a missing link that is outside your control, you can add an entry to
+  the `nitpick_ignore` list in `docs/conf.py`
 
 #### Building the docs locally
 
diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py
index d5b49ca..39d3616 100644
--- a/src/multigrate/model/_multivae.py
+++ b/src/multigrate/model/_multivae.py
@@ -291,7 +291,7 @@ def train(
         weight_decay: float = 1e-3,
         eps: float = 1e-08,
         early_stopping: bool = True,
-        early_stopping_patience = 50,
+        early_stopping_patience=50,
         save_best: bool = True,
         check_val_every_n_epoch: int | None = None,
         n_epochs_kl_warmup: int | None = None,
diff --git a/src/multigrate/module/_multivae_torch.py b/src/multigrate/module/_multivae_torch.py
index e1d888d..fac8715 100644
--- a/src/multigrate/module/_multivae_torch.py
+++ b/src/multigrate/module/_multivae_torch.py
@@ -309,7 +309,7 @@ def _h_to_x(self, h, i):
         return x
 
     def _product_of_experts(self, mus, logvars, masks):
-        #print(mus, logvars, masks)
+        # print(mus, logvars, masks)
         vars = torch.exp(logvars)
         masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
         mus_joint = torch.sum(mus * masks / vars, dim=1)
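The substantive change in the first patch is that self.theta becomes a per-modality list, so more than one "nb"/"zinb" modality can carry its own group-specific dispersion parameters, looked up as self.theta[i] inside _calc_recon_loss. Below is a standalone sketch of that lookup with toy shapes; it is not part of the patch, and the feature counts, group count, and group labels are made-up placeholders.

import torch

# Toy setup: two modalities, both modelled with a negative binomial likelihood.
input_dims = [2000, 134]   # hypothetical feature counts per modality
num_groups = 3             # hypothetical number of groups/batches
losses = ["nb", "nb"]

# One dispersion block per modality, mirroring the patched __init__:
# a Parameter for nb/zinb modalities, an empty placeholder otherwise.
theta = [
    torch.nn.Parameter(torch.randn(d, num_groups)) if loss in ["nb", "zinb"] else []
    for d, loss in zip(input_dims, losses)
]

group = torch.tensor([[0], [2], [1]])  # group index for each of three cells

for i, loss in enumerate(losses):
    if loss in ["nb", "zinb"]:
        # Same lookup as the patched _calc_recon_loss: select each cell's row of
        # log-dispersions for modality i, then exponentiate.
        dispersion = torch.exp(theta[i].T[group.squeeze().long()])
        print(dispersion.shape)  # torch.Size([3, input_dims[i]])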
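The first patch also exposes early_stopping_patience as a train() argument instead of the previously hard-coded value of 50. A minimal usage sketch, assuming model is an already-constructed multigrate MultiVAE instance; the surrounding data setup and any other arguments are not shown or confirmed here.

# `model` is assumed to be an existing multigrate MultiVAE instance.
model.train(
    early_stopping=True,
    early_stopping_patience=100,  # was fixed at 50 before this patch
)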