diff --git a/FrEIA/modules/fixed_transforms.py b/FrEIA/modules/fixed_transforms.py
index 13b9a9d..ed26063 100644
--- a/FrEIA/modules/fixed_transforms.py
+++ b/FrEIA/modules/fixed_transforms.py
@@ -1,5 +1,5 @@
 from . import InvertibleModule
-
+from utils import sum_except_batch
 from typing import Union, Iterable, Tuple
 
 import numpy as np
@@ -175,8 +175,10 @@ def forward(self, x_or_z: Iterable[torch.Tensor], c: Iterable[torch.Tensor] = No
         # the following is the diagonal Jacobian as sigmoid is an element-wise op
         logJ = torch.log(1 / ((1 + torch.exp(_input)) * (1 + torch.exp(-_input))))
         # determinant of a log diagonal Jacobian is simply the sum of its diagonals
-        detLogJ = logJ.sum(1)
+        detLogJ = sum_except_batch(logJ)
+
         if not rev:
             return ((result, ), detLogJ)
         else:
             return ((result, ), -detLogJ)
+
\ No newline at end of file
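
Note: `sum_except_batch` is imported from `utils`, but its definition is not part of this diff. A minimal sketch of such a helper, assuming it reduces a tensor over every dimension except the leading batch one (the usual per-sample log-det-Jacobian convention), might look like:

```python
import torch

def sum_except_batch(x: torch.Tensor) -> torch.Tensor:
    """Sum over all dimensions except the first (batch) dimension.

    For an image-shaped tensor of shape (B, C, H, W), this returns a
    tensor of shape (B,), i.e. one log-det-Jacobian value per sample.
    """
    # Flatten all non-batch dimensions, then sum them in one reduction.
    return x.reshape(x.shape[0], -1).sum(dim=1)
```

Assuming that is the helper's behavior, the change generalizes the old `logJ.sum(1)`, which is only correct for flat `(B, D)` inputs: on a `(B, C, H, W)` tensor, summing over dim 1 alone yields a `(B, H, W)` result instead of a single scalar per sample.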