diff --git a/FrEIA/distributions/normal.py b/FrEIA/distributions/normal.py index a0d075d..7e03b89 100644 --- a/FrEIA/distributions/normal.py +++ b/FrEIA/distributions/normal.py @@ -3,8 +3,12 @@ class StandardNormalDistribution(Independent): - def __init__(self, *event_shape: int, device=None, dtype=None): + def __init__(self, *event_shape: int, device=None, dtype=None, validate_args=True): loc = torch.tensor(0., device=device, dtype=dtype).repeat(event_shape) scale = torch.tensor(1., device=device, dtype=dtype).repeat(event_shape) - super().__init__(Normal(loc, scale), len(event_shape)) + super().__init__( + Normal(loc, scale, validate_args=validate_args), + len(event_shape), + validate_args=validate_args + ) diff --git a/FrEIA/modules/__init__.py b/FrEIA/modules/__init__.py index b1f46d7..32b1614 100644 --- a/FrEIA/modules/__init__.py +++ b/FrEIA/modules/__init__.py @@ -16,6 +16,7 @@ * GINCouplingBlock * AffineCouplingOneSided * ConditionalAffineTransform +* RationalQuadraticSpline Reshaping: @@ -43,6 +44,7 @@ * LearnedElementwiseScaling * OrthogonalTransform * HouseholderPerm +* ElementwiseRationalQuadraticSpline Fixed (non-learned) transforms: @@ -106,4 +108,5 @@ 'GaussianMixtureModel', 'LinearSpline', 'RationalQuadraticSpline', + 'ElementwiseRationalQuadraticSpline', ] diff --git a/FrEIA/modules/graph_topology.py b/FrEIA/modules/graph_topology.py index f6256a7..bec6c9e 100644 --- a/FrEIA/modules/graph_topology.py +++ b/FrEIA/modules/graph_topology.py @@ -61,13 +61,13 @@ def __init__(self, else: if isinstance(section_sizes, int): assert section_sizes < l_dim, "'section_sizes' too large" - else: - assert isinstance(section_sizes, (list, tuple)), \ - "'section_sizes' must be either int or list/tuple of int" - assert sum(section_sizes) <= l_dim, "'section_sizes' too large" - if sum(section_sizes) < l_dim: - warnings.warn("'section_sizes' too small, adding additional section") - section_sizes = list(section_sizes).append(l_dim - sum(section_sizes)) + section_sizes = 
(section_sizes,) + assert isinstance(section_sizes, (list, tuple)), \ + "'section_sizes' must be either int or list/tuple of int" + assert sum(section_sizes) <= l_dim, "'section_sizes' too large" + if sum(section_sizes) < l_dim: + warnings.warn("'section_sizes' too small, adding additional section") + section_sizes = list(section_sizes) + [l_dim - sum(section_sizes)] self.split_size_or_sections = section_sizes def forward(self, x, rev=False, jac=True): diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py index c9c623e..759d693 100644 --- a/FrEIA/modules/splines/binned.py +++ b/FrEIA/modules/splines/binned.py @@ -7,19 +7,63 @@ from itertools import chain from FrEIA.modules.coupling_layers import _BaseCouplingBlock +from FrEIA.modules.base import InvertibleModule from FrEIA import utils - class BinnedSpline(_BaseCouplingBlock): + def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, + split_len: Union[float, int] = 0.5, **kwargs) -> None: + if dims_c is None: + dims_c = [] + + super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, split_len=split_len) + + + self.spline_base = BinnedSplineBase(dims_in, dims_c, **kwargs) + + num_params = sum(self.spline_base.parameter_counts.values()) + self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params) + self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params) + + def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """ + The full coupling consists of: + 1. 
Querying the parameter tensor from the subnetwork + 2. Splitting this tensor into the semantic parameters + 3. Constraining the parameters + 4. Performing the actual spline for each bin, given the parameters + """ + parameters = self.subnet1(u2) + parameters = self.spline_base.split_parameters(parameters, self.split_len1) + parameters = self.constrain_parameters(parameters) + + return self.spline_base.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev) + + def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + parameters = self.subnet2(u1) + parameters = self.spline_base.split_parameters(parameters, self.split_len2) + parameters = self.constrain_parameters(parameters) + + return self.spline_base.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev) + + def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self.spline_base.constrain_parameters(parameters) + +class BinnedSplineBase(InvertibleModule): """ Base Class for Splines Implements input-binning, where bin knots are jointly predicted along with spline parameters by a non-invertible coupling subnetwork """ - def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, split_len: Union[float, int] = 0.5, - bins: int = 10, parameter_counts: Dict[str, int] = None, min_bin_sizes: Tuple[float] = (0.1, 0.1), - default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None: + def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None, + min_bin_sizes: Tuple[float] = (0.1, 0.1), default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None: """ Args: bins: number of bins to use @@ -35,7 +79,7 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp if parameter_counts is None: parameter_counts = {} - super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, 
split_len=split_len) + super().__init__(dims_in, dims_c) assert bins >= 1, "need at least one bin" assert all(s >= 0 for s in min_bin_sizes), "minimum bin size cannot be negative" @@ -61,40 +105,9 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp # merge parameter counts with child classes self.parameter_counts = {**default_parameter_counts, **parameter_counts} - num_params = sum(self.parameter_counts.values()) - self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params) - self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params) - - def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - """ - The full coupling consists of: - 1. Querying the parameter tensor from the subnetwork - 2. Splitting this tensor into the semantic parameters - 3. Constraining the parameters - 4. 
Performing the actual spline for each bin, given the parameters - """ - parameters = self.subnet1(u2) - parameters = self.split_parameters(parameters, self.split_len1) - parameters = self.constrain_parameters(parameters) - - return self.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev) - - def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - parameters = self.subnet2(u1) - parameters = self.split_parameters(parameters, self.split_len2) - parameters = self.constrain_parameters(parameters) - - return self.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev) - def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str, torch.Tensor]: """ - Split network output into semantic parameters, as given by self.parameter_counts + Split parameter tensor into semantic parameters, as given by self.parameter_counts """ parameters = parameters.movedim(1, -1) parameters = parameters.reshape(*parameters.shape[:-1], split_len, -1) diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py index c296375..6656fb8 100644 --- a/FrEIA/modules/splines/rational_quadratic.py +++ b/FrEIA/modules/splines/rational_quadratic.py @@ -1,17 +1,18 @@ -from typing import Dict, List, Tuple +from typing import Dict, Tuple, Callable, Iterable, List import torch +from torch import nn import torch.nn.functional as F import numpy as np -from .binned import BinnedSpline +from .binned import BinnedSplineBase, BinnedSpline class RationalQuadraticSpline(BinnedSpline): def __init__(self, *args, bins: int = 10, **kwargs): # parameter constraints count # 1. 
the derivative at the edge of each inner bin positive #bins - 1 - super().__init__(*args, **kwargs, bins=bins, parameter_counts={"deltas": bins - 1}) + super().__init__(*args, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs) def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: parameters = super().constrain_parameters(parameters) @@ -44,6 +45,65 @@ def _spline2(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bo return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev) +class ElementwiseRationalQuadraticSpline(BinnedSplineBase): + def __init__(self, dims_in, dims_c=[], subnet_constructor: Callable = None, + bins: int = 10, **kwargs) -> None: + super().__init__(dims_in, dims_c, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs) + + self.channels = dims_in[0][0] + self.condition_length = sum([dims_c[i][0] for i in range(len(dims_c))]) + self.conditional = len(dims_c) > 0 + + num_params = sum(self.parameter_counts.values()) + + + if self.conditional: + self.subnet = subnet_constructor(self.condition_length, self.channels * num_params) + else: + self.spline_parameters = nn.Parameter(torch.zeros(self.channels * num_params, *dims_in[0][1:])) + + + def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + parameters = super().constrain_parameters(parameters) + # we additionally want positive derivatives to preserve monotonicity + # the derivative must also match the tails at the spline boundaries + deltas = parameters["deltas"] + + # shifted softplus such that network output 0 -> delta = scale + shift = np.log(np.e - 1) + deltas = F.softplus(deltas + shift) + + # boundary condition: derivative is equal to affine scale at spline boundaries + scale = torch.sum(parameters["heights"], dim=-1, keepdim=True) / torch.sum(parameters["widths"], dim=-1, keepdim=True) + scale = scale.expand(*scale.shape[:-1], 2) + + deltas = 
torch.cat((deltas, scale), dim=-1).roll(1, dims=-1) + + parameters["deltas"] = deltas + + return parameters + + def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: + return input_dims + + def forward(self, x_or_z: Iterable[torch.Tensor], c: Iterable[torch.Tensor] = None, + rev: bool = False, jac: bool = True) \ + -> Tuple[Tuple[torch.Tensor], torch.Tensor]: + if self.conditional: + parameters = self.subnet(torch.cat(c, dim=1).float()) + else: + parameters = self.spline_parameters.unsqueeze(0).repeat_interleave(x_or_z[0].shape[0], dim=0) + parameters = self.split_parameters(parameters, self.channels) + parameters = self.constrain_parameters(parameters) + + y, jac = self.binned_spline(x=x_or_z[0], parameters=parameters, spline=self.spline, rev=rev) + return (y,), jac + + def spline(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + left, right, bottom, top = parameters["left"], parameters["right"], parameters["bottom"], parameters["top"] + deltas_left, deltas_right = parameters["deltas_left"], parameters["deltas_right"] + return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev) + def rational_quadratic_spline(x: torch.Tensor, left: torch.Tensor, right: torch.Tensor, @@ -116,3 +176,5 @@ def rational_quadratic_spline(x: torch.Tensor, log_jac = torch.log(numerator) - torch.log(denominator) return out, log_jac + + diff --git a/docs/_build/html/FrEIA.framework.html b/docs/_build/html/FrEIA.framework.html index 880670d..dc31bdd 100644 --- a/docs/_build/html/FrEIA.framework.html +++ b/docs/_build/html/FrEIA.framework.html @@ -412,6 +412,20 @@
Approximate log Jacobian determinant via finite differences.
+Generates a plot of the GraphINN and stores it as a PDF file and a DOT file.
+path – Directory to store the plots in. Must exist prior to plotting.
filename – Name of the newly generated plots
GraphINN.get_module_by_name()
GraphINN.get_node_by_name()
GraphINN.log_jacobian_numerical()
GraphINN.plot()
InputNode
Reshaping:
LearnedElementwiseScaling
OrthogonalTransform
HouseholderPerm
ElementwiseRationalQuadraticSpline
Fixed (non-learned) transforms:
GraphINN.get_module_by_name()
GraphINN.get_node_by_name()
GraphINN.log_jacobian_numerical()
GraphINN.plot()
InputNode