forked from maxxxzdn/eegVAE
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
1,402 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Learning Generative Factors of Neuroimaging Data with Variational auto-encoders | ||
|
||
Official implementation of: | ||
|
||
**Learning Generative Factors of Neuroimaging Data with Variational auto-encoders** | ||
Maksim Zhdanov*, Saskia Steinmann, Nico Hoffmann | ||
|
||
<img src="model_scheme.png" width="800"> | ||
|
||
**Abstract**: Neuroimaging techniques produce high-dimensional, stochastic data from which it might be challenging to extract high-level knowledge about the phenomena of interest. We address this challenge by applying the framework of generative modelling to 1) classify multiple pathologies, 2) recover neurological mechanisms of those pathologies in a data-driven manner and 3) learn robust representations of neuroimaging data. We illustrate the applicability of the proposed approach to identifying schizophrenia, either followed or not by auditory verbal hallucinations. We further demonstrate the ability of the framework to learn disease-related mechanisms that are consistent with current domain knowledge. We also compare the proposed framework with several benchmark approaches and indicate its advantages. | ||
} | ||
|
||
### Requirements | ||
* PyTorch 1.7.1 | ||
* mne 0.23 (for visualization) |
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.distributions as dist | ||
from ..fc_networks import (FCCondPrior_v2, FCEncoder, FCDecoder) | ||
|
||
|
||
class CVAE(nn.Module): | ||
""" | ||
Conditional VAE | ||
""" | ||
def __init__(self, z_dim, num_classes, | ||
use_cuda, mode, p_dropout, beta): | ||
super(CVAE, self).__init__() | ||
self.z_dim = z_dim | ||
self.num_classes = num_classes | ||
self.mode = mode | ||
self.beta = beta | ||
|
||
if mode == 'FC': | ||
self.encoder = FCEncoder(z_dim, p_dropout) | ||
self.decoder = FCDecoder(z_dim, p_dropout) | ||
self.cond_prior = FCCondPrior_v2(z_dim, self.num_classes) | ||
self.classifier = nn.Sequential( | ||
nn.Linear(z_dim, 16), | ||
nn.ReLU(), | ||
nn.Linear(16, num_classes), | ||
nn.Sigmoid(), | ||
) | ||
else: | ||
raise NotImplementedError("not implemented") | ||
|
||
if use_cuda: | ||
self.cuda() | ||
|
||
def elbo(self, x, y): | ||
""" | ||
Computes ELBO. | ||
""" | ||
bs = x.shape[0] | ||
# inference | ||
post_params = self.encoder(x) | ||
zc = dist.Normal(*post_params).rsample() | ||
qyzc = dist.Bernoulli(probs=self.classifier(zc)) | ||
log_qyzc = qyzc.log_prob(y).sum(dim=-1) | ||
|
||
# compute kl | ||
prior_params = self.cond_prior(y) | ||
kl = self.beta*compute_kl(*post_params, *prior_params) | ||
|
||
# compute log probs for x and y | ||
recon = self.decoder(zc) | ||
log_qyx = self.classifier_loss(x, y) | ||
log_pxz = self.img_log_likelihood(recon, x) | ||
|
||
# compute gradients only wrt to params of qyz, no propogating to qzx | ||
# see https://arxiv.org/abs/2006.10102 Appendix C.3.1 | ||
log_qyzc_ = dist.Bernoulli(probs=self.classifier(zc.detach())).log_prob(y).sum(dim=-1) | ||
w = torch.exp(log_qyzc_ - log_qyx) | ||
if self.mode == 'FC': | ||
elbo = (w * (log_pxz - kl - log_qyzc) + log_qyx).mean() | ||
else: | ||
elbo = (w * ((log_pxz - kl).mean(-1) - log_qyzc) + log_qyx).mean() | ||
return -elbo | ||
|
||
def classifier_loss(self, x, y, k=100): | ||
""" | ||
Computes the classifier loss. | ||
""" | ||
post_params = self.encoder(x) | ||
zc = dist.Normal(*post_params).rsample(torch.tensor([k])) | ||
if self.mode == 'FC': | ||
probs = self.classifier(zc.view(-1, self.z_dim)) | ||
else: | ||
probs = self.classifier(zc.view(-1, 61, self.z_dim)) | ||
d = dist.Bernoulli(probs=probs) | ||
y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes) | ||
lqy_z = d.log_prob(y).view(k, x.shape[0], self.num_classes).sum(dim=-1) | ||
lqy_x = torch.logsumexp(lqy_z, dim=0) - np.log(k) | ||
return lqy_x | ||
|
||
def reconstruct_img(self, x): | ||
""" | ||
Computes μ of p(x|z). | ||
""" | ||
return self.decoder(dist.Normal(*self.encoder(x)).sample())[0] | ||
|
||
def classifier_acc(self, x, y=None, k=1): | ||
""" | ||
Computes accuracy of classification. | ||
""" | ||
post_params = self.encoder(x) | ||
zc = dist.Normal(*post_params).rsample(torch.tensor([k])) | ||
if self.mode == 'EEG': | ||
probs = self.classifier(zc.view(-1, 61, self.z_dim)) | ||
else: | ||
probs = self.classifier(zc.view(-1, self.z_dim)) | ||
y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes) | ||
preds = torch.round(probs) | ||
acc = (preds.eq(y)).float().mean(0) | ||
return acc | ||
|
||
def img_log_likelihood(self, recon, xs): | ||
""" | ||
Computes log p(x|z). | ||
""" | ||
if self.mode == 'EEG': | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1)) | ||
else: | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1,-2)) | ||
|
||
def compute_kl(locs_q, scale_q, locs_p=None, scale_p=None): | ||
""" | ||
Computes the KL(q||p) | ||
""" | ||
if locs_p is None: | ||
locs_p = torch.zeros_like(locs_q) | ||
if scale_p is None: | ||
scale_p = torch.ones_like(scale_q) | ||
|
||
dist_q = dist.Normal(locs_q, scale_q) | ||
dist_p = dist.Normal(locs_p, scale_p) | ||
return dist.kl.kl_divergence(dist_q, dist_p).sum(dim=-1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.distributions as dist | ||
from ..fc_networks import (FCEncoder_v2, FCDecoder) | ||
|
||
|
||
class VAE(nn.Module): | ||
""" | ||
VAE + classification | ||
""" | ||
def __init__(self, z_dim, num_classes, | ||
use_cuda, mode, p_dropout, beta): | ||
super(VAE, self).__init__() | ||
self.z_dim = z_dim | ||
self.z_classify = num_classes | ||
self.z_style = z_dim - num_classes | ||
self.num_classes = num_classes | ||
self.mode = mode | ||
self.beta = beta | ||
|
||
if mode == 'FC': | ||
self.encoder = FCEncoder_v2(z_dim, p_dropout) | ||
self.decoder = FCDecoder(z_dim, p_dropout) | ||
self.classifier = nn.Sequential( | ||
nn.Linear(z_dim, z_dim//2), | ||
nn.ReLU(), | ||
nn.Linear(z_dim//2, self.num_classes), | ||
nn.Sigmoid(), | ||
) | ||
else: | ||
raise NotImplementedError("not implemented") | ||
|
||
if use_cuda: | ||
self.cuda() | ||
|
||
def elbo(self, x, y): | ||
""" | ||
Computes ELBO + classification loss. | ||
""" | ||
#μ, σ of q(z|x) | ||
post_params = self.encoder(x) | ||
z = dist.Normal(*post_params).rsample() | ||
#KL(q(z|x) || p(z)) | ||
kl = compute_kl(*post_params) | ||
#x ~ p(x|z) | ||
recon = self.decoder(z) | ||
#ELBO | ||
log_pxz = self.img_log_likelihood(recon, x) | ||
elbo = log_pxz - self.beta*kl | ||
#ELBO + classification | ||
y_pred = self.classifier(z) | ||
loss = -elbo + F.binary_cross_entropy(y_pred,y) | ||
return loss.mean() | ||
|
||
def reconstruct_img(self, x): | ||
""" | ||
Computes μ of p(x|z). | ||
""" | ||
return self.decoder(dist.Normal(*self.encoder(x)).sample())[0] | ||
|
||
def classifier_acc(self, x, y=None, k=1): | ||
""" | ||
Computes accuracy of classification. | ||
""" | ||
post_params = self.encoder(x) | ||
z = dist.Normal(*post_params).rsample() | ||
probs = self.classifier(z) | ||
preds = torch.round(probs) | ||
acc = (preds.eq(y)).float().mean(0) | ||
return acc | ||
|
||
def img_log_likelihood(self, recon, xs): | ||
""" | ||
Computes log p(x|z). | ||
""" | ||
if self.mode == 'FC': | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1,-2)) | ||
else: | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1)) | ||
|
||
def compute_kl(locs_q, scale_q, locs_p=None, scale_p=None): | ||
""" | ||
Computes the KL(q||p) | ||
""" | ||
if locs_p is None: | ||
locs_p = torch.zeros_like(locs_q) | ||
if scale_p is None: | ||
scale_p = torch.ones_like(scale_q) | ||
|
||
dist_q = dist.Normal(locs_q, scale_q) | ||
dist_p = dist.Normal(locs_p, scale_p) | ||
return dist.kl.kl_divergence(dist_q, dist_p).sum(dim=-1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.distributions as dist | ||
from .fc_networks import (FCEncoder, FCDecoder, FCClassifier, FCCondPrior) | ||
|
||
|
||
class CCVAE(nn.Module): | ||
""" | ||
CCVAE | ||
see https://arxiv.org/abs/2006.10102 for details | ||
""" | ||
def __init__(self, z_dim, num_classes, | ||
use_cuda, mode, p_dropout, beta): | ||
super(CCVAE, self).__init__() | ||
self.z_dim = z_dim | ||
self.z_classify = num_classes | ||
self.z_style = z_dim - num_classes | ||
self.num_classes = num_classes | ||
self.ones = torch.ones(1, self.z_style) | ||
self.zeros = torch.zeros(1, self.z_style) | ||
self.mode = mode | ||
self.beta = beta | ||
|
||
if mode == 'FC': | ||
self.encoder = FCEncoder(z_dim, p_dropout) | ||
self.decoder = FCDecoder(z_dim, p_dropout) | ||
self.classifier = FCClassifier(self.num_classes) | ||
self.cond_prior = FCCondPrior(self.num_classes) | ||
else: | ||
raise NotImplementedError("not implemented") | ||
|
||
if use_cuda: | ||
self.ones = self.ones.cuda() | ||
self.zeros = self.zeros.cuda() | ||
self.cuda() | ||
|
||
def elbo(self, x, y): | ||
""" | ||
Computes ELBO. | ||
""" | ||
bs = x.shape[0] | ||
# inference | ||
post_params = self.encoder(x) | ||
z = dist.Normal(*post_params).rsample() | ||
zc, zs = z.split([self.z_classify, self.z_style], -1) | ||
qyzc = dist.Bernoulli(probs=self.classifier(zc)) | ||
log_qyzc = qyzc.log_prob(y).sum(dim=-1) | ||
|
||
# compute kl | ||
locs_p_zc, scales_p_zc = self.cond_prior(y) | ||
if self.mode == 'EEG': | ||
prior_params = (torch.cat([locs_p_zc, self.zeros.expand(bs, 61, -1)], dim=-1), | ||
torch.cat([scales_p_zc, self.ones.expand(bs, 61, -1)], dim=-1)) | ||
else: | ||
prior_params = (torch.cat([locs_p_zc, self.zeros.expand(bs, -1)], dim=-1), | ||
torch.cat([scales_p_zc, self.ones.expand(bs, -1)], dim=-1)) | ||
kl = self.beta*compute_kl(*post_params, *prior_params) | ||
|
||
# compute log probs for x and y | ||
recon = self.decoder(z) | ||
log_qyx = self.classifier_loss(x, y) | ||
log_pxz = self.img_log_likelihood(recon, x) | ||
|
||
# compute gradients only wrt to params of qyz, no propogating to qzx | ||
# see https://arxiv.org/abs/2006.10102 Appendix C.3.1 | ||
log_qyzc_ = dist.Bernoulli(probs=self.classifier(zc.detach())).log_prob(y).sum(dim=-1) | ||
w = torch.exp(log_qyzc_ - log_qyx) | ||
if self.mode == 'FC': | ||
elbo = (w * (log_pxz - kl - log_qyzc) + log_qyx).mean() | ||
else: | ||
elbo = (w * ((log_pxz - kl).mean(-1) - log_qyzc) + log_qyx).mean() | ||
return -elbo | ||
|
||
def classifier_loss(self, x, y, k=100): | ||
""" | ||
Computes the classifier loss. | ||
""" | ||
post_params = self.encoder(x) | ||
zc, _ = dist.Normal(*post_params).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1) | ||
if self.mode == 'FC': | ||
probs = self.classifier(zc.view(-1, self.z_classify)) | ||
else: | ||
probs = self.classifier(zc.view(-1, 61, self.z_classify)) | ||
d = dist.Bernoulli(probs=probs) | ||
y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes) | ||
lqy_z = d.log_prob(y).view(k, x.shape[0], self.num_classes).sum(dim=-1) | ||
lqy_x = torch.logsumexp(lqy_z, dim=0) - np.log(k) | ||
return lqy_x | ||
|
||
def reconstruct_img(self, x): | ||
""" | ||
Computes μ of p(x|z). | ||
""" | ||
return self.decoder(dist.Normal(*self.encoder(x)).sample())[0] | ||
|
||
def classifier_acc(self, x, y=None, k=1): | ||
""" | ||
Computes accuracy of classification. | ||
""" | ||
post_params = self.encoder(x) | ||
zc, _ = dist.Normal(*post_params).rsample(torch.tensor([k])).split([self.z_classify, self.z_style], -1) | ||
if self.mode == 'FC': | ||
probs = self.classifier(zc.view(-1, self.z_classify)) | ||
else: | ||
probs = self.classifier(zc.view(-1, 61, self.z_classify)) | ||
y = y.expand(k, -1, -1).contiguous().view(-1, self.num_classes) | ||
preds = torch.round(probs) | ||
acc = (preds.eq(y)).float().mean(0) | ||
return acc | ||
|
||
def img_log_likelihood(self, recon, xs): | ||
""" | ||
Computes log p(x|z). | ||
""" | ||
if self.mode == 'FC': | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1,-2)) | ||
else: | ||
return dist.Normal(*recon).log_prob(xs).sum(dim=(-1)) | ||
|
||
def compute_kl(locs_q, scale_q, locs_p=None, scale_p=None): | ||
""" | ||
Computes the KL(q||p) | ||
""" | ||
if locs_p is None: | ||
locs_p = torch.zeros_like(locs_q) | ||
if scale_p is None: | ||
scale_p = torch.ones_like(scale_q) | ||
|
||
dist_q = dist.Normal(locs_q, scale_q) | ||
dist_p = dist.Normal(locs_p, scale_p) | ||
return dist.kl.kl_divergence(dist_q, dist_p).sum(dim=-1) |
Oops, something went wrong.