diff --git a/FrEIA/distributions/normal.py b/FrEIA/distributions/normal.py index a0d075d..7e03b89 100644 --- a/FrEIA/distributions/normal.py +++ b/FrEIA/distributions/normal.py @@ -3,8 +3,12 @@ class StandardNormalDistribution(Independent): - def __init__(self, *event_shape: int, device=None, dtype=None): + def __init__(self, *event_shape: int, device=None, dtype=None, validate_args=True): loc = torch.tensor(0., device=device, dtype=dtype).repeat(event_shape) scale = torch.tensor(1., device=device, dtype=dtype).repeat(event_shape) - super().__init__(Normal(loc, scale), len(event_shape)) + super().__init__( + Normal(loc, scale, validate_args=validate_args), + len(event_shape), + validate_args=validate_args + ) diff --git a/FrEIA/modules/__init__.py b/FrEIA/modules/__init__.py index b1f46d7..32b1614 100644 --- a/FrEIA/modules/__init__.py +++ b/FrEIA/modules/__init__.py @@ -16,6 +16,7 @@ * GINCouplingBlock * AffineCouplingOneSided * ConditionalAffineTransform +* RationalQuadraticSpline Reshaping: @@ -43,6 +44,7 @@ * LearnedElementwiseScaling * OrthogonalTransform * HouseholderPerm +* ElementwiseRationalQuadraticSpline Fixed (non-learned) transforms: @@ -106,4 +108,5 @@ 'GaussianMixtureModel', 'LinearSpline', 'RationalQuadraticSpline', + 'ElementwiseRationalQuadraticSpline', ] diff --git a/FrEIA/modules/graph_topology.py b/FrEIA/modules/graph_topology.py index f6256a7..bec6c9e 100644 --- a/FrEIA/modules/graph_topology.py +++ b/FrEIA/modules/graph_topology.py @@ -61,13 +61,13 @@ def __init__(self, else: if isinstance(section_sizes, int): assert section_sizes < l_dim, "'section_sizes' too large" - else: - assert isinstance(section_sizes, (list, tuple)), \ - "'section_sizes' must be either int or list/tuple of int" - assert sum(section_sizes) <= l_dim, "'section_sizes' too large" - if sum(section_sizes) < l_dim: - warnings.warn("'section_sizes' too small, adding additional section") - section_sizes = list(section_sizes).append(l_dim - sum(section_sizes)) + section_sizes = (section_sizes,) + assert isinstance(section_sizes, (list, tuple)), \ + "'section_sizes' must be either int or list/tuple of int" + assert sum(section_sizes) <= l_dim, "'section_sizes' too large" + if sum(section_sizes) < l_dim: + warnings.warn("'section_sizes' too small, adding additional section") + section_sizes = list(section_sizes) + [l_dim - sum(section_sizes)] self.split_size_or_sections = section_sizes def forward(self, x, rev=False, jac=True): diff --git a/FrEIA/modules/splines/binned.py b/FrEIA/modules/splines/binned.py index c9c623e..759d693 100644 --- a/FrEIA/modules/splines/binned.py +++ b/FrEIA/modules/splines/binned.py @@ -7,19 +7,63 @@ from itertools import chain from FrEIA.modules.coupling_layers import _BaseCouplingBlock +from FrEIA.modules.base import InvertibleModule from FrEIA import utils - class BinnedSpline(_BaseCouplingBlock): + def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, + split_len: Union[float, int] = 0.5, **kwargs) -> None: + if dims_c is None: + dims_c = [] + + super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, split_len=split_len) + + + self.spline_base = BinnedSplineBase(dims_in, dims_c, **kwargs) + + num_params = sum(self.spline_base.parameter_counts.values()) + self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params) + self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params) + + def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """ + The full coupling consists of: + 1. Querying the parameter tensor from the subnetwork + 2. Splitting this tensor into the semantic parameters + 3. Constraining the parameters + 4. Performing the actual spline for each bin, given the parameters + """ + parameters = self.subnet1(u2) + parameters = self.spline_base.split_parameters(parameters, self.split_len1) + parameters = self.constrain_parameters(parameters) + + return self.spline_base.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev) + + def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + parameters = self.subnet2(u1) + parameters = self.spline_base.split_parameters(parameters, self.split_len2) + parameters = self.constrain_parameters(parameters) + + return self.spline_base.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev) + + def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self.spline_base.constrain_parameters(parameters) + +class BinnedSplineBase(InvertibleModule): """ Base Class for Splines Implements input-binning, where bin knots are jointly predicted along with spline parameters by a non-invertible coupling subnetwork """ - def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, split_len: Union[float, int] = 0.5, - bins: int = 10, parameter_counts: Dict[str, int] = None, min_bin_sizes: Tuple[float] = (0.1, 0.1), - default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None: + def __init__(self, dims_in, dims_c=None, bins: int = 10, parameter_counts: Dict[str, int] = None, + min_bin_sizes: Tuple[float] = (0.1, 0.1), default_domain: Tuple[float] = (-3.0, 3.0, -3.0, 3.0)) -> None: """ Args: bins: number of bins to use @@ -35,7 +79,7 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp if parameter_counts is None: parameter_counts = {} - super().__init__(dims_in, dims_c, clamp=0.0, clamp_activation=lambda u: u, split_len=split_len) + super().__init__(dims_in, dims_c) assert bins >= 1, "need at least one bin" assert all(s >= 0 for s in min_bin_sizes), "minimum bin size cannot be negative" @@ -61,40 +105,9 @@ def __init__(self, dims_in, dims_c=None, subnet_constructor: callable = None, sp # merge parameter counts with child classes self.parameter_counts = {**default_parameter_counts, **parameter_counts} - num_params = sum(self.parameter_counts.values()) - self.subnet1 = subnet_constructor(self.split_len2 + self.condition_length, self.split_len1 * num_params) - self.subnet2 = subnet_constructor(self.split_len1 + self.condition_length, self.split_len2 * num_params) - - def _spline1(self, x1: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def _spline2(self, x2: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def _coupling1(self, x1: torch.Tensor, u2: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - """ - The full coupling consists of: - 1. Querying the parameter tensor from the subnetwork - 2. Splitting this tensor into the semantic parameters - 3. Constraining the parameters - 4. Performing the actual spline for each bin, given the parameters - """ - parameters = self.subnet1(u2) - parameters = self.split_parameters(parameters, self.split_len1) - parameters = self.constrain_parameters(parameters) - - return self.binned_spline(x=x1, parameters=parameters, spline=self._spline1, rev=rev) - - def _coupling2(self, x2: torch.Tensor, u1: torch.Tensor, rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: - parameters = self.subnet2(u1) - parameters = self.split_parameters(parameters, self.split_len2) - parameters = self.constrain_parameters(parameters) - - return self.binned_spline(x=x2, parameters=parameters, spline=self._spline2, rev=rev) - def split_parameters(self, parameters: torch.Tensor, split_len: int) -> Dict[str, torch.Tensor]: """ - Split network output into semantic parameters, as given by self.parameter_counts + Split parameter tensor into semantic parameters, as given by self.parameter_counts """ parameters = parameters.movedim(1, -1) parameters = parameters.reshape(*parameters.shape[:-1], split_len, -1) diff --git a/FrEIA/modules/splines/rational_quadratic.py b/FrEIA/modules/splines/rational_quadratic.py index c296375..6656fb8 100644 --- a/FrEIA/modules/splines/rational_quadratic.py +++ b/FrEIA/modules/splines/rational_quadratic.py @@ -1,17 +1,18 @@ -from typing import Dict, List, Tuple +from typing import Dict, Tuple, Callable, Iterable, List import torch +from torch import nn import torch.nn.functional as F import numpy as np -from .binned import BinnedSpline +from .binned import BinnedSplineBase, BinnedSpline class RationalQuadraticSpline(BinnedSpline): def __init__(self, *args, bins: int = 10, **kwargs): # parameter constraints count # 1. the derivative at the edge of each inner bin positive #bins - 1 - super().__init__(*args, **kwargs, bins=bins, parameter_counts={"deltas": bins - 1}) + super().__init__(*args, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs) def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: parameters = super().constrain_parameters(parameters) @@ -44,6 +45,65 @@ def _spline2(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bo return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev) +class ElementwiseRationalQuadraticSpline(BinnedSplineBase): + def __init__(self, dims_in, dims_c=[], subnet_constructor: Callable = None, + bins: int = 10, **kwargs) -> None: + super().__init__(dims_in, dims_c, bins=bins, parameter_counts={"deltas": bins - 1}, **kwargs) + + self.channels = dims_in[0][0] + self.condition_length = sum([dims_c[i][0] for i in range(len(dims_c))]) + self.conditional = len(dims_c) > 0 + + num_params = sum(self.parameter_counts.values()) + + + if self.conditional: + self.subnet = subnet_constructor(self.condition_length, self.channels * num_params) + else: + self.spline_parameters = nn.Parameter(torch.zeros(self.channels * num_params, *dims_in[0][1:])) + + + def constrain_parameters(self, parameters: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + parameters = super().constrain_parameters(parameters) + # we additionally want positive derivatives to preserve monotonicity + # the derivative must also match the tails at the spline boundaries + deltas = parameters["deltas"] + + # shifted softplus such that network output 0 -> delta = scale + shift = np.log(np.e - 1) + deltas = F.softplus(deltas + shift) + + # boundary condition: derivative is equal to affine scale at spline boundaries + scale = torch.sum(parameters["heights"], dim=-1, keepdim=True) / torch.sum(parameters["widths"], dim=-1, keepdim=True) + scale = scale.expand(*scale.shape[:-1], 2) + + deltas = torch.cat((deltas, scale), dim=-1).roll(1, dims=-1) + + parameters["deltas"] = deltas + + return parameters + + def output_dims(self, input_dims: List[Tuple[int]]) -> List[Tuple[int]]: + return input_dims + + def forward(self, x_or_z: Iterable[torch.Tensor], c: Iterable[torch.Tensor] = None, + rev: bool = False, jac: bool = True) \ + -> Tuple[Tuple[torch.Tensor], torch.Tensor]: + if self.conditional: + parameters = self.subnet(torch.cat(c, dim=1).float()) + else: + parameters = self.spline_parameters.unsqueeze(0).repeat_interleave(x_or_z[0].shape[0], dim=0) + parameters = self.split_parameters(parameters, self.channels) + parameters = self.constrain_parameters(parameters) + + y, jac = self.binned_spline(x=x_or_z[0], parameters=parameters, spline=self.spline, rev=rev) + return (y,), jac + + def spline(self, x: torch.Tensor, parameters: Dict[str, torch.Tensor], rev: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + left, right, bottom, top = parameters["left"], parameters["right"], parameters["bottom"], parameters["top"] + deltas_left, deltas_right = parameters["deltas_left"], parameters["deltas_right"] + return rational_quadratic_spline(x, left, right, bottom, top, deltas_left, deltas_right, rev=rev) + def rational_quadratic_spline(x: torch.Tensor, left: torch.Tensor, right: torch.Tensor, @@ -116,3 +176,5 @@ def rational_quadratic_spline(x: torch.Tensor, log_jac = torch.log(numerator) - torch.log(denominator) return out, log_jac + + diff --git a/docs/_build/html/FrEIA.framework.html b/docs/_build/html/FrEIA.framework.html index 880670d..dc31bdd 100644 --- a/docs/_build/html/FrEIA.framework.html +++ b/docs/_build/html/FrEIA.framework.html @@ -412,6 +412,20 @@

FrEIA.framework package

Approximate log Jacobian determinant via finite differences.

+
+
+plot(path: str, filename: str) None[source]#
+

Generates a plot of the GraphINN and stores it as pdf and dot file

+
+
Parameters:
+
    +
  • path – Directory to store the plots in. Must exist previous to plotting

  • +
  • filename – Name of the newly generated plots

  • +
+
+
+
+
@@ -637,6 +651,7 @@

FrEIA.framework packageGraphINN.get_module_by_name()
  • GraphINN.get_node_by_name()
  • GraphINN.log_jacobian_numerical()
  • +
  • GraphINN.plot()
  • InputNode

    Reshaping:

      @@ -388,6 +389,7 @@
    • LearnedElementwiseScaling

    • OrthogonalTransform

    • HouseholderPerm

    • +
    • ElementwiseRationalQuadraticSpline

    Fixed (non-learned) transforms:

  • InputNode
      @@ -667,7 +668,7 @@

      Indices and tables diff --git a/docs/_build/html/objects.inv b/docs/_build/html/objects.inv index ad2efa4..b160f8b 100644 Binary files a/docs/_build/html/objects.inv and b/docs/_build/html/objects.inv differ diff --git a/docs/_build/html/py-modindex.html b/docs/_build/html/py-modindex.html index 81aefbb..46d1ebd 100644 --- a/docs/_build/html/py-modindex.html +++ b/docs/_build/html/py-modindex.html @@ -397,7 +397,7 @@

      Python Module Index

      diff --git a/docs/_build/html/search.html b/docs/_build/html/search.html index 6d0dd3d..b78b4dd 100644 --- a/docs/_build/html/search.html +++ b/docs/_build/html/search.html @@ -396,7 +396,7 @@

      Search

      diff --git a/docs/_build/html/searchindex.js b/docs/_build/html/searchindex.js index fa9726e..ae36041 100644 --- a/docs/_build/html/searchindex.js +++ b/docs/_build/html/searchindex.js @@ -1 +1 @@ -Search.setIndex({"docnames": ["FrEIA.framework", "FrEIA.modules", "index", "tutorial/basic_concepts", "tutorial/examples", "tutorial/examples/bayes_flow", "tutorial/examples/convolutional", "tutorial/examples/fully_connected", "tutorial/examples/inv_unet", "tutorial/examples/training_loop_cinn", "tutorial/examples/training_loop_inn", "tutorial/graph_inns", "tutorial/invertible_operations", "tutorial/quickstart", "tutorial/sequential_inns", "tutorial/tips_tricks_faq", "tutorial/tutorial"], "filenames": ["FrEIA.framework.rst", "FrEIA.modules.rst", "index.rst", "tutorial/basic_concepts.rst", "tutorial/examples.rst", "tutorial/examples/bayes_flow.rst", "tutorial/examples/convolutional.rst", "tutorial/examples/fully_connected.rst", "tutorial/examples/inv_unet.rst", "tutorial/examples/training_loop_cinn.rst", "tutorial/examples/training_loop_inn.rst", "tutorial/graph_inns.rst", "tutorial/invertible_operations.rst", "tutorial/quickstart.rst", "tutorial/sequential_inns.rst", "tutorial/tips_tricks_faq.rst", "tutorial/tutorial.rst"], "titles": ["FrEIA.framework package", "FrEIA.modules package", "Welcome to FrEIA\u2019s documentation!", "Basic concepts", "Examples", "Bayes-flow", "Convolutional INN with invertible downsampling", "Small fully-connected INNs", "Invertible U-Net", "Training: MNIST conditional normalizing flow", "Training: CelebA normalizing flow", "Computation graph API", "Invertible Operations", "Quickstart guide", "Sequential API", "Tips & Tricks, FAQ", "Tutorial"], "terms": {"The": [0, 1, 3, 7, 11, 12, 14], "contain": [0, 1, 6, 12, 14, 15], "logic": 0, "us": [0, 1, 3, 6, 7, 11, 12, 13, 14, 15], "build": [0, 11, 12], "graph": [0, 2, 3, 6, 7, 12, 14, 16], "infer": [0, 1, 3], "order": [0, 1, 3, 11, 15], "node": [0, 1, 2, 3, 6, 11, 12], "have": [0, 1, 3, 6, 11, 12, 14, 15], "execut": 0, "forward": [0, 1, 2, 3, 11, 12, 14, 15], "backward": [0, 1, 12, 13], "direct": [0, 1, 3, 11, 13], "class": [0, 1, 11, 12, 14], "conditionnod": [0, 2, 11, 12], "dim": [0, 1, 6], "int": [0, 1], "name": [0, 1, 6, 11, 12], "none": [0, 1, 11], "sourc": [0, 1], "base": [0, 1, 2, 12], "special": [0, 12], "type": [0, 1], "repres": [0, 1, 11], "contit": 0, "input": [0, 1, 3, 6, 7, 11, 12, 14, 15], "intern": [0, 1], "network": [0, 1, 2, 7, 11, 13, 15], "insid": [0, 1, 13], "coupl": [0, 2, 7, 11, 13, 14, 15, 16], "layer": [0, 1, 15], "__init__": [0, 1, 2, 12], "build_modul": [0, 2], "condition_shap": 0, "input_shap": 0, "tupl": [0, 1, 11, 12, 14], "list": [0, 1, 11, 12, 14], "instanti": 0, "determin": [0, 1, 12, 13, 14], "output": [0, 1, 3, 6, 11, 12, 14], "dimens": [0, 1, 2, 4, 6, 11, 12, 14], "call": [0, 1, 14], "invertiblemodul": [0, 1, 2, 12], "output_dim": [0, 1, 2, 12], "graphinn": [0, 1, 2, 6, 11, 12, 14], "node_list": 0, "force_tuple_output": 0, "fals": [0, 1, 12], "verbos": 0, "thi": [0, 1, 3, 6, 11, 12, 13, 14, 15], "invert": [0, 2, 3, 4, 11, 13, 14, 16], "net": [0, 2, 4, 12], "itself": [0, 1, 11, 12], "It": [0, 1, 7, 11], "i": [0, 1, 3, 6, 7, 11, 12, 13, 14, 15], "subclass": [0, 1, 11], "support": 0, "same": [0, 1], "method": [0, 1, 3, 14], "ha": [0, 1, 7, 11, 15], "an": [0, 1, 3, 11, 12, 13, 14], "addit": [0, 1, 12], "option": [0, 1, 12], "rev": [0, 1, 11, 12, 13, 14], "which": [0, 1, 6, 7, 11, 12, 13, 14], "can": [0, 1, 3, 6, 7, 11, 12, 14, 15], "comput": [0, 1, 2, 3, 6, 7, 14, 16], "revers": [0, 1, 12, 13], "pass": [0, 1, 11, 12, 13], "jac": [0, 1, 12, 14], "addition": 0, "log": [0, 1, 12, 13, 14], "invers": [0, 1, 3, 11, 14], "jacobian": [0, 1, 12, 13, 14], "paramet": [0, 1, 7, 12, 13, 15], "dims_in": [0, 1, 11, 12, 13, 14], "specifi": [0, 1, 11, 14], "shape": [0, 1, 3, 7, 12], "oper": [0, 1, 2, 3, 7, 11, 13, 14, 16], "shape_x_0": [0, 1], "shape_x_1": [0, 1], "dims_c": [0, 1, 12], "condit": [0, 1, 2, 3, 4, 11, 12, 14], "get_module_by_nam": [0, 2], "return": [0, 1, 7, 11, 12, 13, 14], "first": [0, 1, 11, 12, 14, 15], "provid": [0, 11, 12, 13, 14], "get_node_by_nam": [0, 2], "log_jacobian_numer": [0, 2], "x": [0, 1, 12, 13, 14], "c": [0, 1, 11, 12, 14], "h": 0, "0": [0, 1, 6, 7, 12, 13, 14], "0001": 0, "approxim": [0, 2], "via": [0, 1], "finit": 0, "differ": [0, 1, 7], "inputnod": [0, 2, 6, 11, 12], "data": [0, 1, 11, 12, 13, 14], "whole": [0, 1, 11, 15], "when": [0, 1, 12], "run": [0, 12], "iter": [0, 1, 15], "module_typ": [0, 11], "module_arg": [0, 11], "dict": [0, 1], "object": [0, 1], "one": [0, 1, 11, 12, 14, 15], "transform": [0, 2, 6, 12, 13, 15], "arbitrari": [0, 3], "number": [0, 1, 3, 7, 11, 12, 15], "user": [0, 1], "underli": 0, "parse_input": [0, 2], "convert": 0, "canon": 0, "format": 0, "three": [0, 12], "form": [0, 1, 12], "singl": [0, 1, 6, 7, 11], "taken": [0, 1], "idx": 0, "each": [0, 1, 3, 11, 12, 14], "all": [0, 1, 6, 7, 11, 12, 14], "ar": [0, 1, 3, 6, 11, 12, 14, 15], "last": [0, 15], "outputnod": [0, 2, 6, 11, 12], "in_nod": 0, "reversiblegraphnet": [0, 2], "ind_in": 0, "ind_out": 0, "true": [0, 1, 3, 7, 11, 12, 13, 14], "reversiblesequenti": [0, 2, 13], "sequenceinn": [0, 2, 7, 12, 13, 14], "simpler": [0, 12], "than": [0, 1, 12], "onli": [0, 1, 7, 11, 12, 14, 15], "sequenti": [0, 2, 7, 11, 12, 13, 16], "seri": [0, 1, 7], "split": [0, 1, 2, 6, 11, 14], "merg": [0, 1, 11, 14], "branch": 0, "off": [0, 1, 6], "append": [0, 2, 6, 7, 12, 13, 14], "add": [0, 1, 15], "new": [0, 3], "block": [0, 2, 11, 13, 14, 15, 16], "more": [0, 1, 3, 11, 12, 13], "simpl": [0, 2, 4, 12, 13, 14], "wai": [0, 1, 12], "approach": 0, "For": [0, 1, 3, 6, 11, 13, 14, 15], "exampl": [0, 1, 2, 3, 7, 11, 12, 13, 14], "inn": [0, 1, 2, 3, 4, 11, 12, 13, 14, 15], "channel": [0, 1, 15], "dims_h": 0, "dims_w": 0, "rang": [0, 6, 7, 12, 13, 14], "n_block": 0, "allinoneblock": [0, 1, 2, 7, 12, 13, 14], "clamp": [0, 1, 6, 12], "2": [0, 1, 2, 4, 6, 11, 12, 13, 15], "permute_soft": [0, 1, 7, 13], "haardownsampl": [0, 1, 2], "so": [0, 1, 3, 6, 11, 12, 14, 15], "module_class": 0, "cond": [0, 7, 11, 12, 14], "cond_shap": [0, 7, 14], "kwarg": [0, 1], "from": [0, 1, 11, 13, 14], "index": [0, 1, 2, 11], "need": [0, 1, 3, 12], "tensor": [0, 1, 11, 12, 13], "further": [0, 6], "keyword": [0, 14], "argument": [0, 1, 11, 12, 14], "constructor": [0, 1, 11], "see": [0, 1, 7, 11, 12, 13, 14], "torch": [1, 7, 11, 12, 13, 14, 15], "nn": [1, 7, 11, 12, 13, 14, 15], "thing": 1, "compar": 1, "staticmethod": 1, "otuput_dim": 1, "nicecouplingblock": [1, 2], "rnvpcouplingblock": [1, 2], "glowcouplingblock": [1, 2, 6, 11, 12], "gincouplingblock": [1, 2], "affinecouplingonesid": [1, 2, 11], "conditionalaffinetransform": [1, 2], "irevnetdownsampl": [1, 2, 6], "irevnetupsampl": [1, 2], "haarupsampl": [1, 2], "flatten": [1, 2, 6, 14], "concat": [1, 2, 6, 11], "actnorm": [1, 2, 7, 11, 12], "iresnetlay": [1, 2], "invautoact": [1, 2], "invautoactfix": 1, "invautoacttwosid": [1, 2], "invautoconv2d": [1, 2], "invautofc": [1, 2], "learnedelementwisesc": [1, 2], "orthogonaltransform": [1, 2], "householderperm": [1, 2], "permuterandom": [1, 2, 6, 11], "fixedlineartransform": [1, 2], "fixed1x1conv": [1, 2], "invertiblesigmoid": [1, 2], "given": [1, 11, 12, 14], "instanc": [1, 11, 14], "some": [1, 3, 6, 15], "shall": 1, "its": [1, 11], "recov": 1, "appli": [1, 7], "mode": 1, "confus": 1, "pytorch": [1, 2, 3, 14], "gradient": [1, 15], "randn": [1, 11, 12, 13], "batch_siz": 1, "dim_count": 1, "condition_dim": 1, "z": [1, 12, 13, 14], "x_rev": [1, 12], "jac_rev": 1, "det": [1, 14], "j": 1, "left": 1, "frac": 1, "partial": 1, "f": [1, 6], "right": 1, "1": [1, 3, 6, 7, 11, 12, 13], "Then": 1, "allclos": 1, "x_or_z": 1, "bool": 1, "perform": [1, 3, 7, 11, 15], "default": 1, "through": [1, 6, 11, 14], "note": [1, 7, 11, 12], "implement": [1, 3, 11, 12], "must": [1, 11, 12], "valid": 1, "punish": 1, "latter": 1, "recommend": [1, 11], "trivial": [1, 12], "follow": [1, 3, 6, 7, 11, 14], "convent": 1, "consist": [1, 14], "evalu": 1, "let": 1, "": [1, 7, 12], "make": [1, 6, 12], "precis": 1, "function": [1, 12], "ani": [1, 3, 11, 14], "arrai": 1, "like": [1, 11], "associ": 1, "log_jacobian": [1, 2], "arg": 1, "deprec": 1, "doe": [1, 3, 12], "noth": [1, 12], "except": [1, 11], "rais": 1, "warn": 1, "input_dim": [1, 12, 14], "dure": [1, 3], "construct": [1, 2, 11, 12, 14], "A": [1, 11, 12, 14], "entri": [1, 14], "even": [1, 3, 6, 12], "give": [1, 11, 14], "exclud": 1, "batch": [1, 11], "receiv": [1, 14], "32x32": [1, 11], "pixel": 1, "rgb": [1, 11], "imag": [1, 6, 11, 14], "would": [1, 11, 12, 14], "3": [1, 6, 7, 11, 12, 14, 15], "32": [1, 6, 11, 14], "structur": [1, 2, 7, 13], "half": 1, "valu": 1, "should": [1, 7, 11, 12, 14, 15], "16": 1, "up": [1, 12], "implementor": 1, "ensur": 1, "total": 1, "element": [1, 12], "subnet_constructor": [1, 6, 7, 11, 12, 13, 14], "callabl": 1, "affine_clamp": 1, "float": [1, 12], "gin_block": 1, "global_affine_init": 1, "global_affine_typ": 1, "str": 1, "softplu": 1, "learned_householder_permut": 1, "reverse_permut": 1, "combin": [1, 7], "most": [1, 7, 12], "common": 1, "normal": [1, 2, 4, 13], "flow": [1, 2, 4], "similar": 1, "model": [1, 13], "affin": [1, 7, 11, 12, 13], "permut": [1, 7, 11, 12, 15], "global": [1, 12], "also": [1, 3, 11, 12], "gin": 1, "household": 1, "pre": 1, "includ": [1, 12, 13], "soft": 1, "mechan": 1, "real": 1, "nvp": 1, "y": 1, "v": 1, "r": 1, "psi": 1, "s_": 1, "mathrm": 1, "odot": 1, "big": 1, "t_": 1, "e": [1, 3, 6, 11, 12, 15], "below": [1, 11, 12, 13, 14], "reflect": 1, "matrix": 1, "togeth": [1, 7, 12, 14], "x_1": 1, "x_2": 1, "along": [1, 3, 11, 12], "two": [1, 12], "halv": 1, "u": [1, 2, 4], "u_1": 1, "u_2": 1, "exp": [1, 12], "alpha": 1, "tanh": 1, "t": [1, 12], "becaus": [1, 6, 12], "prevent": [1, 15], "explod": 1, "exponenti": [1, 12], "hyperparamet": [1, 12], "adjust": 1, "channels_in": 1, "channels_out": 1, "predict": [1, 12], "coeffici": [1, 12], "multipl": [1, 3, 6, 12, 14], "befor": 1, "abov": [1, 11, 14], "turn": 1, "sorrenson": 1, "et": 1, "al": 1, "2019": 1, "volum": 1, "preserv": 1, "initi": [1, 2, 3, 11, 12, 15], "scale": [1, 2, 12], "sigmoid": 1, "defin": [1, 2, 3, 7, 11, 13, 16], "activ": 1, "beta": 1, "whether": [1, 3], "sampl": [1, 13], "n": [1, 3], "hard": [1, 15], "instead": [1, 11, 12, 15], "veri": [1, 13, 15], "slow": 1, "work": [1, 3, 6, 12, 14, 15], "512": [1, 7, 13], "larg": 1, "dubiou": 1, "actual": 1, "help": 1, "introduc": 1, "putzki": 1, "split_len": 1, "5": [1, 11, 13, 14], "nice": [1, 12], "dinh": 1, "2015": 1, "design": [1, 11], "2d": [1, 7], "3d": 1, "4d": 1, "residu": 1, "subnetwork": [1, 12, 14, 15], "ad": [1, 14], "docstr": 1, "factori": 1, "signatur": 1, "dims_out": [1, 11, 12, 13, 14], "result": [1, 3, 11], "take": [1, 11, 12], "tutori": [1, 13], "clamp_activ": 1, "atan": [1, 12], "realnvp": [1, 3], "2017": 1, "minor": 1, "checkerboard": 1, "prepend": 1, "i_revnet_downsampl": 1, "both": [1, 3], "four": 1, "compon": [1, 12], "amplif": 1, "attenu": 1, "string": 1, "recogn": 1, "behav": 1, "origin": [1, 11, 14], "paper": [1, 11], "custom": [1, 2, 13, 16], "map": 1, "inf": 1, "glow": [1, 11, 12], "part": [1, 3, 6], "1x1": [1, 15], "convolut": [1, 2, 4, 7, 15], "etc": [1, 3, 11, 14], "onc": 1, "jointli": 1, "s_i": 1, "t_i": 1, "separ": 1, "reduc": 1, "cost": 1, "speed": 1, "constrain": 1, "achiev": 1, "subtract": 1, "mean": [1, 13, 14, 15], "while": 1, "still": 1, "power": 1, "within": [1, 3], "slightli": [1, 12], "publish": 1, "final": 1, "sum": [1, 12, 13], "zero": [1, 15], "There": 1, "wa": 1, "found": [1, 11], "between": [1, 12, 14, 15], "practic": 1, "guarante": 1, "might": 1, "stabl": [1, 12], "certain": 1, "case": [1, 12, 14], "In": [1, 12, 14], "where": 1, "random": [1, 11, 12], "orthogon": [1, 15], "after": [1, 6], "everi": [1, 12, 15], "restrict": 1, "simplifi": 1, "One": 1, "spade": 1, "park": 1, "legacy_backend": 1, "spatial": 1, "downsampl": [1, 2, 4], "revnet": 1, "group": 1, "neighbor": 1, "reorder": 1, "time": [1, 12], "pattern": 1, "jacobsen": 1, "2018": 1, "If": [1, 11, 14, 15], "concaten": [1, 3, 6], "adapt": 1, "github": [1, 11, 12], "com": 1, "jhjacobsen": 1, "blob": 1, "master": 1, "model_util": 1, "py": 1, "usual": [1, 12], "slower": 1, "gpu": 1, "stride": 1, "kernel": 1, "patch": 1, "a1": 1, "b1": 1, "a2": 1, "b2": 1, "c1": 1, "c2": 1, "a3": 1, "b3": 1, "order_by_wavelet": 1, "gener": [1, 3, 7, 11, 12, 14], "complet": [1, 15], "irrelev": [1, 11], "unless": 1, "certaint": 1, "subset": 1, "suppos": 1, "extract": 1, "detail": [1, 6, 12, 13], "transpos": 1, "expect": 1, "rebal": 1, "haar": 1, "wavelet": 1, "4": [1, 6, 15], "width": 1, "height": 1, "averag": 1, "vertic": 1, "horizont": 1, "diagon": 1, "v1": 1, "h1": 1, "d1": 1, "those": 1, "v2": 1, "h2": 1, "d2": 1, "set": [1, 11], "g": [1, 3, 6, 11, 15], "allow": [1, 3], "quarter": 1, "isol": 1, "exist": [1, 3], "how": [1, 11], "multipli": [1, 12], "factor": 1, "accordingli": 1, "stabil": [1, 15], "mai": [1, 15], "increas": 1, "been": 1, "concatent": 1, "higher": [1, 6], "frequenc": 1, "d": [1, 12], "target_dim": 1, "target": 1, "12": [1, 6, 7, 15], "necessarili": [1, 12], "sensibl": 1, "meaning": 1, "sequenc": [1, 14], "section_s": [1, 6], "n_section": 1, "incom": 1, "correspond": 1, "init": 1, "attribut": 1, "describ": 1, "check": [1, 11], "size": [1, 3, 6, 12], "dimension": [1, 12], "compat": 1, "handl": 1, "automat": 1, "setup": 1, "preced": [1, 11], "over": [1, 11], "section": [1, 11], "doesn": 1, "creat": 1, "slack": 1, "equal": 1, "close": 1, "numpi": 1, "array_split": 1, "count": 1, "ident": [1, 11, 12, 15], "sens": 1, "init_data": 1, "techniqu": 1, "kingma": 1, "http": [1, 11, 12], "arxiv": 1, "org": 1, "ab": [1, 11, 14], "1807": 1, "03039": 1, "tradit": 1, "standard": [1, 13, 15], "deviat": 1, "thei": [1, 3, 11, 12, 15], "treat": 1, "learnabl": 1, "interspers": 1, "throughout": 1, "intermedi": [1, 3], "train": [1, 2, 12, 13, 15], "start": [1, 12, 15], "avoid": 1, "just": [1, 6, 11, 14], "wise": 1, "bia": 1, "load_state_dict": [1, 2], "state_dict": [1, 3, 14], "strict": 1, "copi": [1, 12], "buffer": 1, "descend": 1, "kei": 1, "exactli": 1, "match": 1, "persist": 1, "strictli": 1, "enforc": 1, "missing_kei": 1, "miss": 1, "unexpected_kei": 1, "unexpect": 1, "namedtupl": 1, "field": 1, "properti": [1, 11], "slope_init": 1, "nonlinear": 1, "analog": 1, "leaki": 1, "relu": [1, 7, 11, 12, 13, 14], "slope": 1, "symmetr": 1, "posit": [1, 12], "neg": [1, 13, 14], "side": 1, "geq": 1, "impli": 1, "oslash": 1, "intput": 1, "dimenison": 1, "account": 1, "init_po": 1, "init_neg": 1, "space": 1, "stai": 1, "alpha_": 1, "init_scal": 1, "unlik": 1, "realli": 1, "individu": [1, 15], "To": [1, 3, 11, 13], "correction_interv": 1, "256": [1, 7, 11], "term": [1, 12], "free": 1, "weight": [1, 3, 13, 15], "project": 1, "back": [1, 11], "stiefel": 1, "manifold": 1, "matric": 1, "regular": 1, "interv": 1, "With": 1, "rx": 1, "b": 1, "cdot": 1, "pi": [1, 12], "mani": [1, 6, 11, 14, 15], "step": [1, 13], "perfectli": [1, 15], "n_reflect": 1, "fast": 1, "product": 1, "mathiesen": 1, "2020": 1, "invertibleworkshop": 1, "io": [1, 11, 12], "accepted_pap": 1, "pdf": 1, "10": [1, 7, 11, 12, 14], "1d": [1, 11], "without": [1, 3, 12, 14], "vector": [1, 6, 11], "conatin": 1, "backpropag": [1, 13], "subsequ": 1, "independ": 1, "due": 1, "reason": 1, "randomli": 1, "kept": 1, "seed": [1, 6, 11], "multi": 1, "dimenion": [1, 6, 15], "rng": 1, "do": [1, 12, 15], "rese": 1, "m": 1, "linear": [1, 7, 11, 12, 13, 14], "tesor": 1, "mx": 1, "offset": 1, "length": 1, "squar": 1, "effect": [1, 12], "muplitpl": 1, "across": 1, "trainabl": 1, "fulli": [1, 2, 4, 6, 12, 15], "connect": [1, 2, 3, 4, 6, 12, 15], "autoencod": 1, "1802": 1, "06869": 1, "tranpos": 1, "reconstruct": 1, "loss": [1, 13], "converg": 1, "howev": 1, "becuas": [1, 12], "invauto": 1, "asymptot": 1, "limit": [1, 12], "ouput": 1, "integ": 1, "kernel_s": 1, "pad": [1, 7], "variant": 1, "convlut": 1, "choos": 1, "retain": 1, "therefor": [1, 11, 12], "respons": 1, "depend": [1, 3, 11, 12], "internal_s": 1, "n_internal_lay": 1, "jacobian_iter": 1, "20": [1, 11], "hutchinson_sampl": 1, "fixed_point_iter": 1, "50": 1, "lipschitz_iter": 1, "lipschitz_batchs": 1, "spectral_norm_max": 1, "8": [1, 7, 13, 14], "resnet": 1, "architectur": [1, 2, 3, 6, 12, 14], "propos": 1, "1811": 1, "00995": 1, "lipschitz_correct": [1, 2], "gaussianmixturemodel": [1, 2], "gaussian": [1, 7], "mixtur": [1, 7], "covari": 1, "parameter": 1, "suppli": [1, 14], "come": 1, "extern": 1, "feed": [1, 12, 14, 15], "gmm": 1, "normalize_weight": [1, 2], "w": 1, "indic": 1, "pick_mixture_compon": [1, 2], "latent": 1, "variabl": [1, 13], "chosen": [1, 12], "k": [1, 6, 7, 13, 14], "point": 1, "code": [1, 3, 12, 13], "simultan": 1, "mathemat": 1, "deriv": [1, 11], "technic": 1, "report": 1, "densiti": 1, "full": [1, 2, 6, 13], "static": 1, "nll_loss": [1, 2], "likelihood": [1, 13], "n_compon": 1, "n_dim": [1, 13], "nll_upper_bound": [1, 2], "numer": 1, "upper": 1, "bound": 1, "softmax": 1, "unnorm": 1, "probabl": 1, "decis": 1, "framework": [2, 3, 7, 11, 13, 14], "easili": [2, 7], "modul": [2, 7, 11, 12, 13, 14], "quickstart": 2, "guid": 2, "basic": [2, 11, 12, 13, 16], "concept": [2, 11, 16], "api": [2, 7, 16], "tip": [2, 16], "trick": [2, 16], "faq": [2, 16], "small": [2, 4], "mnist": [2, 4, 14], "bay": [2, 4], "loop": [2, 11, 13], "celeba": [2, 4], "content": [2, 6], "abstract": 2, "templat": [2, 12], "reshap": [2, 15], "topologi": 2, "other": [2, 3, 11, 12, 15], "learn": [2, 7, 12, 13], "fix": [2, 3, 12], "non": 2, "semi": 2, "search": 2, "page": [2, 14], "why": 3, "freia": [3, 7, 11, 12, 13, 14], "100": [3, 11, 13], "line": 3, "That": 3, "correct": 3, "long": 3, "loos": 3, "end": 3, "entir": 3, "we": [3, 6, 7, 11, 12, 13, 14], "consisit": 3, "consid": 3, "want": [3, 11], "complic": [3, 12], "skip": [3, 6], "effici": [3, 7], "prototyp": 3, "experiment": 3, "task": 3, "As": 3, "mind": [3, 15], "requir": [3, 12], "keep": [3, 15], "track": 3, "edg": 3, "store": 3, "them": [3, 6, 12], "until": 3, "tool": 3, "cuda": [3, 14], "dataparallel": 3, "worri": 3, "correctli": 3, "interfac": 3, "main": 3, "address": 3, "cifar10": 6, "encod": 6, "local": 6, "rest": 6, "semant": 6, "import": [6, 7, 13, 14, 15], "moder": 6, "becom": 6, "infeas": 6, "depth": 6, "enocd": 6, "nois": [6, 13, 15], "earli": [6, 15], "machineri": 6, "ff": [6, 7, 11, 12, 13, 14], "ndim_x": 6, "resolut": 6, "fm": [6, 7, 11, 12, 13, 14], "subnet_conv": [6, 7], "conv_high_res_": 6, "permute_high_res_": 6, "lower": 6, "subnet": [6, 11, 13, 15], "subnet_conv_1x1": [6, 7], "els": [6, 12, 15], "conv_low_res_": 6, "permute_low_res_": 6, "split_nod": 6, "subnet_fc": [6, 7, 13, 14], "fully_connected_": 6, "permute_": 6, "get": [6, 11, 13, 15], "out0": [6, 11, 12], "out1": [6, 11, 12], "concat1d": 6, "conv_inn": 6, "These": [7, 12], "declar": 7, "def": [7, 11, 12, 13, 14], "c_in": 7, "c_out": 7, "conv2d": 7, "abl": 7, "distribut": [7, 13], "visual": 7, "sinc": 7, "chain": [7, 13], "cinn": [7, 11, 14], "quit": 7, "well": [7, 13, 15], "particularli": 7, "respect": 7, "again": [7, 14], "collect": [7, 13], "28": 7, "specif": 11, "in1": [11, 12], "omit": 11, "principl": 11, "appear": 11, "error": 11, "messag": 11, "out": [11, 12, 14], "equival": 11, "you": [11, 15], "cover": 11, "later": 11, "particular": [11, 15], "what": 11, "look": 11, "perm": 11, "Or": 11, "merge2": 11, "split2": 11, "merge1": 11, "whose": 11, "sever": 11, "featur": [11, 12], "conveni": 11, "directli": [11, 12], "sinlg": 11, "far": [11, 14], "shown": 11, "closer": 11, "discuss": [11, 12], "document": [11, 12], "under": 11, "vll": [11, 12], "hd": [11, 12], "html": [11, 12], "dictionari": 11, "accept": 11, "could": [11, 12], "111": 11, "empti": 11, "rule": 11, "in2": 11, "42": 11, "split1": 11, "concat1": 11, "concat2": 11, "output1": 11, "output2": 11, "example_inn": 11, "dummi": 11, "x1": [11, 12], "x2": [11, 12], "z1": [11, 12], "z2": [11, 12], "log_jac_det": [11, 12, 13], "x1_inv": 11, "x2_inv": 11, "log_jac_det_inv": 11, "assert": [11, 14], "max": [11, 14], "1e": [11, 14, 15], "commonli": 12, "submodul": 12, "here": [12, 14], "_build": 12, "rnvp": 12, "merit": 12, "smaller": 12, "known": [12, 14], "perhap": 12, "hand": 12, "anew": 12, "modifi": 12, "rather": 12, "lot": 12, "sound": 12, "fc_constr": 12, "128": [12, 14, 15], "lead": 12, "much": 12, "enabl": 12, "larger": 12, "rate": 12, "therebi": 12, "good": 12, "place": 12, "variou": 12, "save": 12, "effort": 12, "written": 12, "extens": 12, "refer": [12, 14], "inform": [12, 13], "illustr": 12, "definit": 12, "either": 12, "second": 12, "swap": 12, "otherwis": 12, "calcul": [12, 13], "fixedrandomelementwisemultipli": 12, "self": 12, "super": 12, "random_factor": 12, "randint": 12, "conditionalswap": 12, "don": 12, "x1_new": 12, "x2_new": 12, "usag": [12, 13], "batchsiz": [12, 13], "log_jac_det_rev": 12, "input_1": 12, "input_2": 12, "mult_1": 12, "cond_swap": 12, "conditional_swap": 12, "mult_2": 12, "output_1": 12, "output_2": 12, "x1_rev": 12, "x2_rev": 12, "_": [12, 13], "jump": 13, "straight": 13, "moon": 13, "dataset": 13, "sklearn": 13, "make_moon": 13, "optim": 13, "adam": [13, 15], "lr": 13, "001": 13, "1000": 13, "zero_grad": 13, "label": [13, 14], "n_sampl": 13, "05": 13, "prior": 13, "updat": 13, "explicitli": 14, "write": 14, "784": 14, "cifar": 14, "simpli": 14, "obtain": 14, "x_inv": 14, "jac_inv": 14, "child": 14, "presuppos": 14, "next": 14, "now": 14, "imagin": 14, "hot": 14, "cond_dim": 14, "tell": 14, "one_hot_label": 14, "stochast": 15, "descent": 15, "clip": 15, "experienc": 15, "instabl": 15, "util": 15, "clip_grad_norm_": 15, "slight": 15, "spars": 15, "quantiz": 15, "correl": 15, "xavier": 15, "unstabl": 15, "your": 15, "deep": 15, "30": 15, "nan": 15, "forget": 15, "shallow": 15, "wide": 15, "neuron": 15, "64": 15, "conv": 15, "think": 15, "huge": 15, "being": 15, "said": 15, "roughli": 15, "too": 15, "break": 15, "oppos": 15, "kind": 15, "seem": 15, "qualiti": 15, "constitut": 15, "third": 15}, "objects": {"": [[2, 0, 0, "-", "FrEIA"]], "FrEIA": [[0, 0, 0, "-", "framework"], [1, 0, 0, "-", "modules"]], "FrEIA.framework": [[0, 1, 1, "", "ConditionNode"], [0, 1, 1, "", "GraphINN"], [0, 1, 1, "", "InputNode"], [0, 1, 1, "", "Node"], [0, 1, 1, "", "OutputNode"], [0, 1, 1, "", "ReversibleGraphNet"], [0, 1, 1, "", "ReversibleSequential"], [0, 1, 1, "", "SequenceINN"]], "FrEIA.framework.ConditionNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.GraphINN": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "get_module_by_name"], [0, 2, 1, "", "get_node_by_name"], [0, 2, 1, "", "log_jacobian_numerical"]], "FrEIA.framework.InputNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.Node": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"], [0, 2, 1, "", "parse_inputs"]], "FrEIA.framework.OutputNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.ReversibleGraphNet": [[0, 2, 1, "", "__init__"]], "FrEIA.framework.ReversibleSequential": [[0, 2, 1, "", "__init__"]], "FrEIA.framework.SequenceINN": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "append"]], "FrEIA.modules": [[1, 1, 1, "", "ActNorm"], [1, 1, 1, "", "AffineCouplingOneSided"], [1, 1, 1, "", "AllInOneBlock"], [1, 1, 1, "", "Concat"], [1, 1, 1, "", "ConditionalAffineTransform"], [1, 1, 1, "", "Fixed1x1Conv"], [1, 1, 1, "", "FixedLinearTransform"], [1, 1, 1, "", "Flatten"], [1, 1, 1, "", "GINCouplingBlock"], [1, 1, 1, "", "GLOWCouplingBlock"], [1, 1, 1, "", "GaussianMixtureModel"], [1, 1, 1, "", "HaarDownsampling"], [1, 1, 1, "", "HaarUpsampling"], [1, 1, 1, "", "HouseholderPerm"], [1, 1, 1, "", "IResNetLayer"], [1, 1, 1, "", "IRevNetDownsampling"], [1, 1, 1, "", "IRevNetUpsampling"], [1, 1, 1, "", "InvAutoAct"], [1, 1, 1, "", "InvAutoActTwoSided"], [1, 1, 1, "", "InvAutoConv2D"], [1, 1, 1, "", "InvAutoFC"], [1, 1, 1, "", "InvertibleModule"], [1, 1, 1, "", "InvertibleSigmoid"], [1, 1, 1, "", "LearnedElementwiseScaling"], [1, 1, 1, "", "NICECouplingBlock"], [1, 1, 1, "", "OrthogonalTransform"], [1, 1, 1, "", "PermuteRandom"], [1, 1, 1, "", "RNVPCouplingBlock"], [1, 1, 1, "", "Reshape"], [1, 1, 1, "", "Split"]], "FrEIA.modules.ActNorm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "initialize"], [1, 2, 1, "", "load_state_dict"], [1, 3, 1, "", "scale"]], "FrEIA.modules.AffineCouplingOneSided": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.AllInOneBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Concat": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.ConditionalAffineTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Fixed1x1Conv": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.FixedLinearTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Flatten": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GINCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GLOWCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GaussianMixtureModel": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "nll_loss"], [1, 2, 1, "", "nll_upper_bound"], [1, 2, 1, "", "normalize_weights"], [1, 2, 1, "", "pick_mixture_component"]], "FrEIA.modules.HaarDownsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.HaarUpsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.HouseholderPerm": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.IResNetLayer": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "lipschitz_correction"]], "FrEIA.modules.IRevNetDownsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.IRevNetUpsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoAct": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoActTwoSided": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoConv2D": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoFC": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvertibleModule": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "log_jacobian"], [1, 2, 1, "", "output_dims"]], "FrEIA.modules.InvertibleSigmoid": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.LearnedElementwiseScaling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.NICECouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.OrthogonalTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.PermuteRandom": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.RNVPCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Reshape": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Split": [[1, 2, 1, "", "__init__"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:property"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "property", "Python property"]}, "titleterms": {"freia": [0, 1, 2], "framework": 0, "packag": [0, 1, 2], "modul": [0, 1], "content": 0, "abstract": 1, "templat": 1, "coupl": [1, 12], "block": [1, 12], "reshap": 1, "graph": [1, 11], "topologi": 1, "other": 1, "learn": 1, "transform": 1, "fix": 1, "non": 1, "approxim": 1, "semi": 1, "invert": [1, 6, 8, 12], "welcom": 2, "": 2, "document": 2, "tutori": [2, 16], "indic": 2, "tabl": 2, "basic": 3, "concept": 3, "exampl": 4, "network": 4, "architectur": 4, "full": 4, "train": [4, 9, 10], "loop": 4, "bay": 5, "flow": [5, 9, 10], "convolut": 6, "inn": [6, 7], "downsampl": 6, "small": 7, "fulli": 7, "connect": 7, "simpl": 7, "2": 7, "dimens": 7, "condit": [7, 9], "mnist": [7, 9], "u": 8, "net": 8, "normal": [9, 10], "celeba": 10, "comput": 11, "api": [11, 14], "oper": 12, "defin": 12, "custom": 12, "quickstart": 13, "guid": 13, "sequenti": 14, "tip": 15, "trick": 15, "faq": 15}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"FrEIA.framework package": [[0, "freia-framework-package"]], "Module contents": [[0, "module-FrEIA.framework"]], "FrEIA.modules package": [[1, "module-FrEIA.modules"]], "Abstract template": [[1, "abstract-template"]], "Coupling blocks": [[1, "coupling-blocks"], [12, "coupling-blocks"]], "Reshaping": [[1, "reshaping"]], "Graph topology": [[1, "graph-topology"]], "Other learned transforms": [[1, "other-learned-transforms"]], "Fixed (non-learned) transforms": [[1, "fixed-non-learned-transforms"]], "Approximately- or semi-invertible transforms": [[1, "approximately-or-semi-invertible-transforms"]], "Welcome to FrEIA\u2019s documentation!": [[2, "module-FrEIA"]], "Tutorial": [[2, "tutorial"], [16, "tutorial"]], "Package Documentation": [[2, "package-documentation"]], "Indices and tables": [[2, "indices-and-tables"]], "Basic concepts": [[3, "basic-concepts"]], "Examples": [[4, "examples"]], "Network architectures": [[4, "network-architectures"]], "Full training loops": [[4, "full-training-loops"]], "Bayes-flow": [[5, "bayes-flow"]], "Convolutional INN with invertible downsampling": [[6, "convolutional-inn-with-invertible-downsampling"]], "Small fully-connected INNs": [[7, "small-fully-connected-inns"]], "Simple INN in 2 dimensions": [[7, "simple-inn-in-2-dimensions"]], "Conditional INN for MNIST": [[7, "conditional-inn-for-mnist"]], "Invertible U-Net": [[8, "invertible-u-net"]], "Training: MNIST conditional normalizing flow": [[9, "training-mnist-conditional-normalizing-flow"]], "Training: CelebA normalizing flow": [[10, "training-celeba-normalizing-flow"]], "Computation graph API": [[11, "computation-graph-api"]], "Invertible Operations": [[12, "invertible-operations"]], "Defining custom invertible operations": [[12, "defining-custom-invertible-operations"]], "Quickstart guide": [[13, "quickstart-guide"]], "Sequential API": [[14, "sequential-api"]], "Tips & Tricks, FAQ": [[15, "tips-tricks-faq"]]}, "indexentries": {"conditionnode (class in freia.framework)": [[0, "FrEIA.framework.ConditionNode"]], "freia.framework": [[0, "module-FrEIA.framework"]], "graphinn (class in freia.framework)": [[0, "FrEIA.framework.GraphINN"]], "inputnode (class in freia.framework)": [[0, "FrEIA.framework.InputNode"]], "node (class in freia.framework)": [[0, "FrEIA.framework.Node"]], "outputnode (class in freia.framework)": [[0, "FrEIA.framework.OutputNode"]], "reversiblegraphnet (class in freia.framework)": [[0, "FrEIA.framework.ReversibleGraphNet"]], "reversiblesequential (class in freia.framework)": [[0, "FrEIA.framework.ReversibleSequential"]], "sequenceinn (class in freia.framework)": [[0, "FrEIA.framework.SequenceINN"]], "__init__() (freia.framework.conditionnode method)": [[0, "FrEIA.framework.ConditionNode.__init__"]], "__init__() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.__init__"]], "__init__() (freia.framework.inputnode method)": [[0, "FrEIA.framework.InputNode.__init__"]], "__init__() (freia.framework.node method)": [[0, "FrEIA.framework.Node.__init__"]], "__init__() (freia.framework.outputnode method)": [[0, "FrEIA.framework.OutputNode.__init__"]], "__init__() (freia.framework.reversiblegraphnet method)": [[0, "FrEIA.framework.ReversibleGraphNet.__init__"]], "__init__() (freia.framework.reversiblesequential method)": [[0, "FrEIA.framework.ReversibleSequential.__init__"]], "__init__() (freia.framework.sequenceinn method)": [[0, "FrEIA.framework.SequenceINN.__init__"]], "append() (freia.framework.sequenceinn method)": [[0, "FrEIA.framework.SequenceINN.append"]], "build_module() (freia.framework.conditionnode method)": [[0, "FrEIA.framework.ConditionNode.build_module"]], "build_module() (freia.framework.inputnode method)": [[0, "FrEIA.framework.InputNode.build_module"]], "build_module() (freia.framework.node method)": [[0, "FrEIA.framework.Node.build_module"]], "build_module() (freia.framework.outputnode method)": [[0, "FrEIA.framework.OutputNode.build_module"]], "get_module_by_name() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.get_module_by_name"]], "get_node_by_name() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.get_node_by_name"]], "log_jacobian_numerical() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.log_jacobian_numerical"]], "module": [[0, "module-FrEIA.framework"], [1, "module-FrEIA.modules"], [2, "module-FrEIA"]], "parse_inputs() (freia.framework.node method)": [[0, "FrEIA.framework.Node.parse_inputs"]], "actnorm (class in freia.modules)": [[1, "FrEIA.modules.ActNorm"]], "affinecouplingonesided (class in freia.modules)": [[1, "FrEIA.modules.AffineCouplingOneSided"]], "allinoneblock (class in freia.modules)": [[1, "FrEIA.modules.AllInOneBlock"]], "concat (class in freia.modules)": [[1, "FrEIA.modules.Concat"]], "conditionalaffinetransform (class in freia.modules)": [[1, "FrEIA.modules.ConditionalAffineTransform"]], "fixed1x1conv (class in freia.modules)": [[1, "FrEIA.modules.Fixed1x1Conv"]], "fixedlineartransform (class in freia.modules)": [[1, "FrEIA.modules.FixedLinearTransform"]], "flatten (class in freia.modules)": [[1, "FrEIA.modules.Flatten"]], "freia.modules": [[1, "module-FrEIA.modules"]], "gincouplingblock (class in freia.modules)": [[1, "FrEIA.modules.GINCouplingBlock"]], "glowcouplingblock (class in freia.modules)": [[1, "FrEIA.modules.GLOWCouplingBlock"]], "gaussianmixturemodel (class in freia.modules)": [[1, "FrEIA.modules.GaussianMixtureModel"]], "haardownsampling (class in freia.modules)": [[1, "FrEIA.modules.HaarDownsampling"]], "haarupsampling (class in freia.modules)": [[1, "FrEIA.modules.HaarUpsampling"]], "householderperm (class in freia.modules)": [[1, "FrEIA.modules.HouseholderPerm"]], "iresnetlayer (class in freia.modules)": [[1, "FrEIA.modules.IResNetLayer"]], "irevnetdownsampling (class in freia.modules)": [[1, "FrEIA.modules.IRevNetDownsampling"]], "irevnetupsampling (class in freia.modules)": [[1, "FrEIA.modules.IRevNetUpsampling"]], "invautoact (class in freia.modules)": [[1, "FrEIA.modules.InvAutoAct"]], "invautoacttwosided (class in freia.modules)": [[1, "FrEIA.modules.InvAutoActTwoSided"]], "invautoconv2d (class in freia.modules)": [[1, "FrEIA.modules.InvAutoConv2D"]], "invautofc (class in freia.modules)": [[1, "FrEIA.modules.InvAutoFC"]], "invertiblemodule (class in freia.modules)": [[1, "FrEIA.modules.InvertibleModule"]], "invertiblesigmoid (class in freia.modules)": [[1, "FrEIA.modules.InvertibleSigmoid"]], "learnedelementwisescaling (class in freia.modules)": [[1, "FrEIA.modules.LearnedElementwiseScaling"]], "nicecouplingblock (class in freia.modules)": [[1, "FrEIA.modules.NICECouplingBlock"]], "orthogonaltransform (class in freia.modules)": [[1, "FrEIA.modules.OrthogonalTransform"]], "permuterandom (class in freia.modules)": [[1, "FrEIA.modules.PermuteRandom"]], "rnvpcouplingblock (class in freia.modules)": [[1, "FrEIA.modules.RNVPCouplingBlock"]], "reshape (class in freia.modules)": [[1, "FrEIA.modules.Reshape"]], "split (class in freia.modules)": [[1, "FrEIA.modules.Split"]], "__init__() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.__init__"]], "__init__() (freia.modules.affinecouplingonesided method)": [[1, "FrEIA.modules.AffineCouplingOneSided.__init__"]], "__init__() (freia.modules.allinoneblock method)": [[1, "FrEIA.modules.AllInOneBlock.__init__"]], "__init__() (freia.modules.concat method)": [[1, "FrEIA.modules.Concat.__init__"]], "__init__() (freia.modules.conditionalaffinetransform method)": [[1, "FrEIA.modules.ConditionalAffineTransform.__init__"]], "__init__() (freia.modules.fixed1x1conv method)": [[1, "FrEIA.modules.Fixed1x1Conv.__init__"]], "__init__() (freia.modules.fixedlineartransform method)": [[1, "FrEIA.modules.FixedLinearTransform.__init__"]], "__init__() (freia.modules.flatten method)": [[1, "FrEIA.modules.Flatten.__init__"]], "__init__() (freia.modules.gincouplingblock method)": [[1, "FrEIA.modules.GINCouplingBlock.__init__"]], "__init__() (freia.modules.glowcouplingblock method)": [[1, "FrEIA.modules.GLOWCouplingBlock.__init__"]], "__init__() (freia.modules.gaussianmixturemodel method)": [[1, "FrEIA.modules.GaussianMixtureModel.__init__"]], "__init__() (freia.modules.haardownsampling method)": [[1, "FrEIA.modules.HaarDownsampling.__init__"]], "__init__() (freia.modules.haarupsampling method)": [[1, "FrEIA.modules.HaarUpsampling.__init__"]], "__init__() (freia.modules.householderperm method)": [[1, "FrEIA.modules.HouseholderPerm.__init__"]], "__init__() (freia.modules.iresnetlayer method)": [[1, "FrEIA.modules.IResNetLayer.__init__"]], "__init__() (freia.modules.irevnetdownsampling method)": [[1, "FrEIA.modules.IRevNetDownsampling.__init__"]], "__init__() (freia.modules.irevnetupsampling method)": [[1, "FrEIA.modules.IRevNetUpsampling.__init__"]], "__init__() (freia.modules.invautoact method)": [[1, "FrEIA.modules.InvAutoAct.__init__"]], "__init__() (freia.modules.invautoacttwosided method)": [[1, "FrEIA.modules.InvAutoActTwoSided.__init__"]], "__init__() (freia.modules.invautoconv2d method)": [[1, "FrEIA.modules.InvAutoConv2D.__init__"]], "__init__() (freia.modules.invautofc method)": [[1, "FrEIA.modules.InvAutoFC.__init__"]], "__init__() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.__init__"]], "__init__() (freia.modules.invertiblesigmoid method)": [[1, "FrEIA.modules.InvertibleSigmoid.__init__"]], "__init__() (freia.modules.learnedelementwisescaling method)": [[1, "FrEIA.modules.LearnedElementwiseScaling.__init__"]], "__init__() (freia.modules.nicecouplingblock method)": [[1, "FrEIA.modules.NICECouplingBlock.__init__"]], "__init__() (freia.modules.orthogonaltransform method)": [[1, "FrEIA.modules.OrthogonalTransform.__init__"]], "__init__() (freia.modules.permuterandom method)": [[1, "FrEIA.modules.PermuteRandom.__init__"]], "__init__() (freia.modules.rnvpcouplingblock method)": [[1, "FrEIA.modules.RNVPCouplingBlock.__init__"]], "__init__() (freia.modules.reshape method)": [[1, "FrEIA.modules.Reshape.__init__"]], "__init__() (freia.modules.split method)": [[1, "FrEIA.modules.Split.__init__"]], "forward() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.forward"]], "initialize() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.initialize"]], "lipschitz_correction() (freia.modules.iresnetlayer method)": [[1, "FrEIA.modules.IResNetLayer.lipschitz_correction"]], "load_state_dict() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.load_state_dict"]], "log_jacobian() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.log_jacobian"]], "nll_loss() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.nll_loss"]], "nll_upper_bound() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.nll_upper_bound"]], "normalize_weights() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.normalize_weights"]], "output_dims() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.output_dims"]], "pick_mixture_component() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.pick_mixture_component"]], "scale (freia.modules.actnorm property)": [[1, "FrEIA.modules.ActNorm.scale"]], "freia": [[2, "module-FrEIA"]]}}) \ No newline at end of file +Search.setIndex({"docnames": ["FrEIA.framework", "FrEIA.modules", "index", "tutorial/basic_concepts", "tutorial/examples", "tutorial/examples/bayes_flow", "tutorial/examples/convolutional", "tutorial/examples/fully_connected", "tutorial/examples/inv_unet", "tutorial/examples/training_loop_cinn", "tutorial/examples/training_loop_inn", "tutorial/graph_inns", "tutorial/invertible_operations", "tutorial/quickstart", "tutorial/sequential_inns", "tutorial/tips_tricks_faq", "tutorial/tutorial"], "filenames": ["FrEIA.framework.rst", "FrEIA.modules.rst", "index.rst", "tutorial/basic_concepts.rst", "tutorial/examples.rst", "tutorial/examples/bayes_flow.rst", "tutorial/examples/convolutional.rst", "tutorial/examples/fully_connected.rst", "tutorial/examples/inv_unet.rst", "tutorial/examples/training_loop_cinn.rst", "tutorial/examples/training_loop_inn.rst", "tutorial/graph_inns.rst", "tutorial/invertible_operations.rst", "tutorial/quickstart.rst", "tutorial/sequential_inns.rst", "tutorial/tips_tricks_faq.rst", "tutorial/tutorial.rst"], "titles": ["FrEIA.framework package", "FrEIA.modules package", "Welcome to FrEIA\u2019s documentation!", "Basic concepts", "Examples", "Bayes-flow", "Convolutional INN with invertible downsampling", "Small fully-connected INNs", "Invertible U-Net", "Training: MNIST conditional normalizing flow", "Training: CelebA normalizing flow", "Computation graph API", "Invertible Operations", "Quickstart guide", "Sequential API", "Tips & Tricks, FAQ", "Tutorial"], "terms": {"The": [0, 1, 3, 7, 11, 12, 14], "contain": [0, 1, 6, 12, 14, 15], "logic": 0, "us": [0, 1, 3, 6, 7, 11, 12, 13, 14, 15], "build": [0, 11, 12], "graph": [0, 2, 3, 6, 7, 12, 14, 16], "infer": [0, 1, 3], "order": [0, 1, 3, 11, 15], "node": [0, 1, 2, 3, 6, 11, 12], "have": [0, 1, 3, 6, 11, 12, 14, 15], "execut": 0, "forward": [0, 1, 2, 3, 11, 12, 14, 15], "backward": [0, 1, 12, 13], "direct": [0, 1, 3, 11, 13], "class": [0, 1, 11, 12, 14], "conditionnod": [0, 2, 11, 12], "dim": [0, 1, 6], "int": [0, 1], "name": [0, 1, 6, 11, 12], "none": [0, 1, 11], "sourc": [0, 1], "base": [0, 1, 2, 12], "special": [0, 12], "type": [0, 1], "repres": [0, 1, 11], "contit": 0, "input": [0, 1, 3, 6, 7, 11, 12, 14, 15], "intern": [0, 1], "network": [0, 1, 2, 7, 11, 13, 15], "insid": [0, 1, 13], "coupl": [0, 2, 7, 11, 13, 14, 15, 16], "layer": [0, 1, 15], "__init__": [0, 1, 2, 12], "build_modul": [0, 2], "condition_shap": 0, "input_shap": 0, "tupl": [0, 1, 11, 12, 14], "list": [0, 1, 11, 12, 14], "instanti": 0, "determin": [0, 1, 12, 13, 14], "output": [0, 1, 3, 6, 11, 12, 14], "dimens": [0, 1, 2, 4, 6, 11, 12, 14], "call": [0, 1, 14], "invertiblemodul": [0, 1, 2, 12], "output_dim": [0, 1, 2, 12], "graphinn": [0, 1, 2, 6, 11, 12, 14], "node_list": 0, "force_tuple_output": 0, "fals": [0, 1, 12], "verbos": 0, "thi": [0, 1, 3, 6, 11, 12, 13, 14, 15], "invert": [0, 2, 3, 4, 11, 13, 14, 16], "net": [0, 2, 4, 12], "itself": [0, 1, 11, 12], "It": [0, 1, 7, 11], "i": [0, 1, 3, 6, 7, 11, 12, 13, 14, 15], "subclass": [0, 1, 11], "support": 0, "same": [0, 1], "method": [0, 1, 3, 14], "ha": [0, 1, 7, 11, 15], "an": [0, 1, 3, 11, 12, 13, 14], "addit": [0, 1, 12], "option": [0, 1, 12], "rev": [0, 1, 11, 12, 13, 14], "which": [0, 1, 6, 7, 11, 12, 13, 14], "can": [0, 1, 3, 6, 7, 11, 12, 14, 15], "comput": [0, 1, 2, 3, 6, 7, 14, 16], "revers": [0, 1, 12, 13], "pass": [0, 1, 11, 12, 13], "jac": [0, 1, 12, 14], "addition": 0, "log": [0, 1, 12, 13, 14], "invers": [0, 1, 3, 11, 14], "jacobian": [0, 1, 12, 13, 14], "paramet": [0, 1, 7, 12, 13, 15], "dims_in": [0, 1, 11, 12, 13, 14], "specifi": [0, 1, 11, 14], "shape": [0, 1, 3, 7, 12], "oper": [0, 1, 2, 3, 7, 11, 13, 14, 16], "shape_x_0": [0, 1], "shape_x_1": [0, 1], "dims_c": [0, 1, 12], "condit": [0, 1, 2, 3, 4, 11, 12, 14], "get_module_by_nam": [0, 2], "return": [0, 1, 7, 11, 12, 13, 14], "first": [0, 1, 11, 12, 14, 15], "provid": [0, 11, 12, 13, 14], "get_node_by_nam": [0, 2], "log_jacobian_numer": [0, 2], "x": [0, 1, 12, 13, 14], "c": [0, 1, 11, 12, 14], "h": 0, "0": [0, 1, 6, 7, 12, 13, 14], "0001": 0, "approxim": [0, 2], "via": [0, 1], "finit": 0, "differ": [0, 1, 7], "inputnod": [0, 2, 6, 11, 12], "data": [0, 1, 11, 12, 13, 14], "whole": [0, 1, 11, 15], "when": [0, 1, 12], "run": [0, 12], "iter": [0, 1, 15], "module_typ": [0, 11], "module_arg": [0, 11], "dict": [0, 1], "object": [0, 1], "one": [0, 1, 11, 12, 14, 15], "transform": [0, 2, 6, 12, 13, 15], "arbitrari": [0, 3], "number": [0, 1, 3, 7, 11, 12, 15], "user": [0, 1], "underli": 0, "parse_input": [0, 2], "convert": 0, "canon": 0, "format": 0, "three": [0, 12], "form": [0, 1, 12], "singl": [0, 1, 6, 7, 11], "taken": [0, 1], "idx": 0, "each": [0, 1, 3, 11, 12, 14], "all": [0, 1, 6, 7, 11, 12, 14], "ar": [0, 1, 3, 6, 11, 12, 14, 15], "last": [0, 15], "outputnod": [0, 2, 6, 11, 12], "in_nod": 0, "reversiblegraphnet": [0, 2], "ind_in": 0, "ind_out": 0, "true": [0, 1, 3, 7, 11, 12, 13, 14], "reversiblesequenti": [0, 2, 13], "sequenceinn": [0, 2, 7, 12, 13, 14], "simpler": [0, 12], "than": [0, 1, 12], "onli": [0, 1, 7, 11, 12, 14, 15], "sequenti": [0, 2, 7, 11, 12, 13, 16], "seri": [0, 1, 7], "split": [0, 1, 2, 6, 11, 14], "merg": [0, 1, 11, 14], "branch": 0, "off": [0, 1, 6], "append": [0, 2, 6, 7, 12, 13, 14], "add": [0, 1, 15], "new": [0, 3], "block": [0, 2, 11, 13, 14, 15, 16], "more": [0, 1, 3, 11, 12, 13], "simpl": [0, 2, 4, 12, 13, 14], "wai": [0, 1, 12], "approach": 0, "For": [0, 1, 3, 6, 11, 13, 14, 15], "exampl": [0, 1, 2, 3, 7, 11, 12, 13, 14], "inn": [0, 1, 2, 3, 4, 11, 12, 13, 14, 15], "channel": [0, 1, 15], "dims_h": 0, "dims_w": 0, "rang": [0, 6, 7, 12, 13, 14], "n_block": 0, "allinoneblock": [0, 1, 2, 7, 12, 13, 14], "clamp": [0, 1, 6, 12], "2": [0, 1, 2, 4, 6, 11, 12, 13, 15], "permute_soft": [0, 1, 7, 13], "haardownsampl": [0, 1, 2], "so": [0, 1, 3, 6, 11, 12, 14, 15], "module_class": 0, "cond": [0, 7, 11, 12, 14], "cond_shap": [0, 7, 14], "kwarg": [0, 1], "from": [0, 1, 11, 13, 14], "index": [0, 1, 2, 11], "need": [0, 1, 3, 12], "tensor": [0, 1, 11, 12, 13], "further": [0, 6], "keyword": [0, 14], "argument": [0, 1, 11, 12, 14], "constructor": [0, 1, 11], "see": [0, 1, 7, 11, 12, 13, 14], "torch": [1, 7, 11, 12, 13, 14, 15], "nn": [1, 7, 11, 12, 13, 14, 15], "thing": 1, "compar": 1, "staticmethod": 1, "otuput_dim": 1, "nicecouplingblock": [1, 2], "rnvpcouplingblock": [1, 2], "glowcouplingblock": [1, 2, 6, 11, 12], "gincouplingblock": [1, 2], "affinecouplingonesid": [1, 2, 11], "conditionalaffinetransform": [1, 2], "irevnetdownsampl": [1, 2, 6], "irevnetupsampl": [1, 2], "haarupsampl": [1, 2], "flatten": [1, 2, 6, 14], "concat": [1, 2, 6, 11], "actnorm": [1, 2, 7, 11, 12], "iresnetlay": [1, 2], "invautoact": [1, 2], "invautoactfix": 1, "invautoacttwosid": [1, 2], "invautoconv2d": [1, 2], "invautofc": [1, 2], "learnedelementwisesc": [1, 2], "orthogonaltransform": [1, 2], "householderperm": [1, 2], "permuterandom": [1, 2, 6, 11], "fixedlineartransform": [1, 2], "fixed1x1conv": [1, 2], "invertiblesigmoid": [1, 2], "given": [1, 11, 12, 14], "instanc": [1, 11, 14], "some": [1, 3, 6, 15], "shall": 1, "its": [1, 11], "recov": 1, "appli": [1, 7], "mode": 1, "confus": 1, "pytorch": [1, 2, 3, 14], "gradient": [1, 15], "randn": [1, 11, 12, 13], "batch_siz": 1, "dim_count": 1, "condition_dim": 1, "z": [1, 12, 13, 14], "x_rev": [1, 12], "jac_rev": 1, "det": [1, 14], "j": 1, "left": 1, "frac": 1, "partial": 1, "f": [1, 6], "right": 1, "1": [1, 3, 6, 7, 11, 12, 13], "Then": 1, "allclos": 1, "x_or_z": 1, "bool": 1, "perform": [1, 3, 7, 11, 15], "default": 1, "through": [1, 6, 11, 14], "note": [1, 7, 11, 12], "implement": [1, 3, 11, 12], "must": [0, 1, 11, 12], "valid": 1, "punish": 1, "latter": 1, "recommend": [1, 11], "trivial": [1, 12], "follow": [1, 3, 6, 7, 11, 14], "convent": 1, "consist": [1, 14], "evalu": 1, "let": 1, "": [1, 7, 12], "make": [1, 6, 12], "precis": 1, "function": [1, 12], "ani": [1, 3, 11, 14], "arrai": 1, "like": [1, 11], "associ": 1, "log_jacobian": [1, 2], "arg": 1, "deprec": 1, "doe": [1, 3, 12], "noth": [1, 12], "except": [1, 11], "rais": 1, "warn": 1, "input_dim": [1, 12, 14], "dure": [1, 3], "construct": [1, 2, 11, 12, 14], "A": [1, 11, 12, 14], "entri": [1, 14], "even": [1, 3, 6, 12], "give": [1, 11, 14], "exclud": 1, "batch": [1, 11], "receiv": [1, 14], "32x32": [1, 11], "pixel": 1, "rgb": [1, 11], "imag": [1, 6, 11, 14], "would": [1, 11, 12, 14], "3": [1, 6, 7, 11, 12, 14, 15], "32": [1, 6, 11, 14], "structur": [1, 2, 7, 13], "half": 1, "valu": 1, "should": [1, 7, 11, 12, 14, 15], "16": 1, "up": [1, 12], "implementor": 1, "ensur": 1, "total": 1, "element": [1, 12], "subnet_constructor": [1, 6, 7, 11, 12, 13, 14], "callabl": 1, "affine_clamp": 1, "float": [1, 12], "gin_block": 1, "global_affine_init": 1, "global_affine_typ": 1, "str": [0, 1], "softplu": 1, "learned_householder_permut": 1, "reverse_permut": 1, "combin": [1, 7], "most": [1, 7, 12], "common": 1, "normal": [1, 2, 4, 13], "flow": [1, 2, 4], "similar": 1, "model": [1, 13], "affin": [1, 7, 11, 12, 13], "permut": [1, 7, 11, 12, 15], "global": [1, 12], "also": [1, 3, 11, 12], "gin": 1, "household": 1, "pre": 1, "includ": [1, 12, 13], "soft": 1, "mechan": 1, "real": 1, "nvp": 1, "y": 1, "v": 1, "r": 1, "psi": 1, "s_": 1, "mathrm": 1, "odot": 1, "big": 1, "t_": 1, "e": [1, 3, 6, 11, 12, 15], "below": [1, 11, 12, 13, 14], "reflect": 1, "matrix": 1, "togeth": [1, 7, 12, 14], "x_1": 1, "x_2": 1, "along": [1, 3, 11, 12], "two": [1, 12], "halv": 1, "u": [1, 2, 4], "u_1": 1, "u_2": 1, "exp": [1, 12], "alpha": 1, "tanh": 1, "t": [1, 12], "becaus": [1, 6, 12], "prevent": [1, 15], "explod": 1, "exponenti": [1, 12], "hyperparamet": [1, 12], "adjust": 1, "channels_in": 1, "channels_out": 1, "predict": [1, 12], "coeffici": [1, 12], "multipl": [1, 3, 6, 12, 14], "befor": 1, "abov": [1, 11, 14], "turn": 1, "sorrenson": 1, "et": 1, "al": 1, "2019": 1, "volum": 1, "preserv": 1, "initi": [1, 2, 3, 11, 12, 15], "scale": [1, 2, 12], "sigmoid": 1, "defin": [1, 2, 3, 7, 11, 13, 16], "activ": 1, "beta": 1, "whether": [1, 3], "sampl": [1, 13], "n": [1, 3], "hard": [1, 15], "instead": [1, 11, 12, 15], "veri": [1, 13, 15], "slow": 1, "work": [1, 3, 6, 12, 14, 15], "512": [1, 7, 13], "larg": 1, "dubiou": 1, "actual": 1, "help": 1, "introduc": 1, "putzki": 1, "split_len": 1, "5": [1, 11, 13, 14], "nice": [1, 12], "dinh": 1, "2015": 1, "design": [1, 11], "2d": [1, 7], "3d": 1, "4d": 1, "residu": 1, "subnetwork": [1, 12, 14, 15], "ad": [1, 14], "docstr": 1, "factori": 1, "signatur": 1, "dims_out": [1, 11, 12, 13, 14], "result": [1, 3, 11], "take": [1, 11, 12], "tutori": [1, 13], "clamp_activ": 1, "atan": [1, 12], "realnvp": [1, 3], "2017": 1, "minor": 1, "checkerboard": 1, "prepend": 1, "i_revnet_downsampl": 1, "both": [1, 3], "four": 1, "compon": [1, 12], "amplif": 1, "attenu": 1, "string": 1, "recogn": 1, "behav": 1, "origin": [1, 11, 14], "paper": [1, 11], "custom": [1, 2, 13, 16], "map": 1, "inf": 1, "glow": [1, 11, 12], "part": [1, 3, 6], "1x1": [1, 15], "convolut": [1, 2, 4, 7, 15], "etc": [1, 3, 11, 14], "onc": 1, "jointli": 1, "s_i": 1, "t_i": 1, "separ": 1, "reduc": 1, "cost": 1, "speed": 1, "constrain": 1, "achiev": 1, "subtract": 1, "mean": [1, 13, 14, 15], "while": 1, "still": 1, "power": 1, "within": [1, 3], "slightli": [1, 12], "publish": 1, "final": 1, "sum": [1, 12, 13], "zero": [1, 15], "There": 1, "wa": 1, "found": [1, 11], "between": [1, 12, 14, 15], "practic": 1, "guarante": 1, "might": 1, "stabl": [1, 12], "certain": 1, "case": [1, 12, 14], "In": [1, 12, 14], "where": 1, "random": [1, 11, 12], "orthogon": [1, 15], "after": [1, 6], "everi": [1, 12, 15], "restrict": 1, "simplifi": 1, "One": 1, "spade": 1, "park": 1, "legacy_backend": 1, "spatial": 1, "downsampl": [1, 2, 4], "revnet": 1, "group": 1, "neighbor": 1, "reorder": 1, "time": [1, 12], "pattern": 1, "jacobsen": 1, "2018": 1, "If": [1, 11, 14, 15], "concaten": [1, 3, 6], "adapt": 1, "github": [1, 11, 12], "com": 1, "jhjacobsen": 1, "blob": 1, "master": 1, "model_util": 1, "py": 1, "usual": [1, 12], "slower": 1, "gpu": 1, "stride": 1, "kernel": 1, "patch": 1, "a1": 1, "b1": 1, "a2": 1, "b2": 1, "c1": 1, "c2": 1, "a3": 1, "b3": 1, "order_by_wavelet": 1, "gener": [0, 1, 3, 7, 11, 12, 14], "complet": [1, 15], "irrelev": [1, 11], "unless": 1, "certaint": 1, "subset": 1, "suppos": 1, "extract": 1, "detail": [1, 6, 12, 13], "transpos": 1, "expect": 1, "rebal": 1, "haar": 1, "wavelet": 1, "4": [1, 6, 15], "width": 1, "height": 1, "averag": 1, "vertic": 1, "horizont": 1, "diagon": 1, "v1": 1, "h1": 1, "d1": 1, "those": 1, "v2": 1, "h2": 1, "d2": 1, "set": [1, 11], "g": [1, 3, 6, 11, 15], "allow": [1, 3], "quarter": 1, "isol": 1, "exist": [0, 1, 3], "how": [1, 11], "multipli": [1, 12], "factor": 1, "accordingli": 1, "stabil": [1, 15], "mai": [1, 15], "increas": 1, "been": 1, "concatent": 1, "higher": [1, 6], "frequenc": 1, "d": [1, 12], "target_dim": 1, "target": 1, "12": [1, 6, 7, 15], "necessarili": [1, 12], "sensibl": 1, "meaning": 1, "sequenc": [1, 14], "section_s": [1, 6], "n_section": 1, "incom": 1, "correspond": 1, "init": 1, "attribut": 1, "describ": 1, "check": [1, 11], "size": [1, 3, 6, 12], "dimension": [1, 12], "compat": 1, "handl": 1, "automat": 1, "setup": 1, "preced": [1, 11], "over": [1, 11], "section": [1, 11], "doesn": 1, "creat": 1, "slack": 1, "equal": 1, "close": 1, "numpi": 1, "array_split": 1, "count": 1, "ident": [1, 11, 12, 15], "sens": 1, "init_data": 1, "techniqu": 1, "kingma": 1, "http": [1, 11, 12], "arxiv": 1, "org": 1, "ab": [1, 11, 14], "1807": 1, "03039": 1, "tradit": 1, "standard": [1, 13, 15], "deviat": 1, "thei": [1, 3, 11, 12, 15], "treat": 1, "learnabl": 1, "interspers": 1, "throughout": 1, "intermedi": [1, 3], "train": [1, 2, 12, 13, 15], "start": [1, 12, 15], "avoid": 1, "just": [1, 6, 11, 14], "wise": 1, "bia": 1, "load_state_dict": [1, 2], "state_dict": [1, 3, 14], "strict": 1, "copi": [1, 12], "buffer": 1, "descend": 1, "kei": 1, "exactli": 1, "match": 1, "persist": 1, "strictli": 1, "enforc": 1, "missing_kei": 1, "miss": 1, "unexpected_kei": 1, "unexpect": 1, "namedtupl": 1, "field": 1, "properti": [1, 11], "slope_init": 1, "nonlinear": 1, "analog": 1, "leaki": 1, "relu": [1, 7, 11, 12, 13, 14], "slope": 1, "symmetr": 1, "posit": [1, 12], "neg": [1, 13, 14], "side": 1, "geq": 1, "impli": 1, "oslash": 1, "intput": 1, "dimenison": 1, "account": 1, "init_po": 1, "init_neg": 1, "space": 1, "stai": 1, "alpha_": 1, "init_scal": 1, "unlik": 1, "realli": 1, "individu": [1, 15], "To": [1, 3, 11, 13], "correction_interv": 1, "256": [1, 7, 11], "term": [1, 12], "free": 1, "weight": [1, 3, 13, 15], "project": 1, "back": [1, 11], "stiefel": 1, "manifold": 1, "matric": 1, "regular": 1, "interv": 1, "With": 1, "rx": 1, "b": 1, "cdot": 1, "pi": [1, 12], "mani": [1, 6, 11, 14, 15], "step": [1, 13], "perfectli": [1, 15], "n_reflect": 1, "fast": 1, "product": 1, "mathiesen": 1, "2020": 1, "invertibleworkshop": 1, "io": [1, 11, 12], "accepted_pap": 1, "pdf": [0, 1], "10": [1, 7, 11, 12, 14], "1d": [1, 11], "without": [1, 3, 12, 14], "vector": [1, 6, 11], "conatin": 1, "backpropag": [1, 13], "subsequ": 1, "independ": 1, "due": 1, "reason": 1, "randomli": 1, "kept": 1, "seed": [1, 6, 11], "multi": 1, "dimenion": [1, 6, 15], "rng": 1, "do": [1, 12, 15], "rese": 1, "m": 1, "linear": [1, 7, 11, 12, 13, 14], "tesor": 1, "mx": 1, "offset": 1, "length": 1, "squar": 1, "effect": [1, 12], "muplitpl": 1, "across": 1, "trainabl": 1, "fulli": [1, 2, 4, 6, 12, 15], "connect": [1, 2, 3, 4, 6, 12, 15], "autoencod": 1, "1802": 1, "06869": 1, "tranpos": 1, "reconstruct": 1, "loss": [1, 13], "converg": 1, "howev": 1, "becuas": [1, 12], "invauto": 1, "asymptot": 1, "limit": [1, 12], "ouput": 1, "integ": 1, "kernel_s": 1, "pad": [1, 7], "variant": 1, "convlut": 1, "choos": 1, "retain": 1, "therefor": [1, 11, 12], "respons": 1, "depend": [1, 3, 11, 12], "internal_s": 1, "n_internal_lay": 1, "jacobian_iter": 1, "20": [1, 11], "hutchinson_sampl": 1, "fixed_point_iter": 1, "50": 1, "lipschitz_iter": 1, "lipschitz_batchs": 1, "spectral_norm_max": 1, "8": [1, 7, 13, 14], "resnet": 1, "architectur": [1, 2, 3, 6, 12, 14], "propos": 1, "1811": 1, "00995": 1, "lipschitz_correct": [1, 2], "gaussianmixturemodel": [1, 2], "gaussian": [1, 7], "mixtur": [1, 7], "covari": 1, "parameter": 1, "suppli": [1, 14], "come": 1, "extern": 1, "feed": [1, 12, 14, 15], "gmm": 1, "normalize_weight": [1, 2], "w": 1, "indic": 1, "pick_mixture_compon": [1, 2], "latent": 1, "variabl": [1, 13], "chosen": [1, 12], "k": [1, 6, 7, 13, 14], "point": 1, "code": [1, 3, 12, 13], "simultan": 1, "mathemat": 1, "deriv": [1, 11], "technic": 1, "report": 1, "densiti": 1, "full": [1, 2, 6, 13], "static": 1, "nll_loss": [1, 2], "likelihood": [1, 13], "n_compon": 1, "n_dim": [1, 13], "nll_upper_bound": [1, 2], "numer": 1, "upper": 1, "bound": 1, "softmax": 1, "unnorm": 1, "probabl": 1, "decis": 1, "framework": [2, 3, 7, 11, 13, 14], "easili": [2, 7], "modul": [2, 7, 11, 12, 13, 14], "quickstart": 2, "guid": 2, "basic": [2, 11, 12, 13, 16], "concept": [2, 11, 16], "api": [2, 7, 16], "tip": [2, 16], "trick": [2, 16], "faq": [2, 16], "small": [2, 4], "mnist": [2, 4, 14], "bay": [2, 4], "loop": [2, 11, 13], "celeba": [2, 4], "content": [2, 6], "abstract": 2, "templat": [2, 12], "reshap": [2, 15], "topologi": 2, "other": [2, 3, 11, 12, 15], "learn": [2, 7, 12, 13], "fix": [2, 3, 12], "non": 2, "semi": 2, "search": 2, "page": [2, 14], "why": 3, "freia": [3, 7, 11, 12, 13, 14], "100": [3, 11, 13], "line": 3, "That": 3, "correct": 3, "long": 3, "loos": 3, "end": 3, "entir": 3, "we": [3, 6, 7, 11, 12, 13, 14], "consisit": 3, "consid": 3, "want": [3, 11], "complic": [3, 12], "skip": [3, 6], "effici": [3, 7], "prototyp": 3, "experiment": 3, "task": 3, "As": 3, "mind": [3, 15], "requir": [3, 12], "keep": [3, 15], "track": 3, "edg": 3, "store": [0, 3], "them": [3, 6, 12], "until": 3, "tool": 3, "cuda": [3, 14], "dataparallel": 3, "worri": 3, "correctli": 3, "interfac": 3, "main": 3, "address": 3, "cifar10": 6, "encod": 6, "local": 6, "rest": 6, "semant": 6, "import": [6, 7, 13, 14, 15], "moder": 6, "becom": 6, "infeas": 6, "depth": 6, "enocd": 6, "nois": [6, 13, 15], "earli": [6, 15], "machineri": 6, "ff": [6, 7, 11, 12, 13, 14], "ndim_x": 6, "resolut": 6, "fm": [6, 7, 11, 12, 13, 14], "subnet_conv": [6, 7], "conv_high_res_": 6, "permute_high_res_": 6, "lower": 6, "subnet": [6, 11, 13, 15], "subnet_conv_1x1": [6, 7], "els": [6, 12, 15], "conv_low_res_": 6, "permute_low_res_": 6, "split_nod": 6, "subnet_fc": [6, 7, 13, 14], "fully_connected_": 6, "permute_": 6, "get": [6, 11, 13, 15], "out0": [6, 11, 12], "out1": [6, 11, 12], "concat1d": 6, "conv_inn": 6, "These": [7, 12], "declar": 7, "def": [7, 11, 12, 13, 14], "c_in": 7, "c_out": 7, "conv2d": 7, "abl": 7, "distribut": [7, 13], "visual": 7, "sinc": 7, "chain": [7, 13], "cinn": [7, 11, 14], "quit": 7, "well": [7, 13, 15], "particularli": 7, "respect": 7, "again": [7, 14], "collect": [7, 13], "28": 7, "specif": 11, "in1": [11, 12], "omit": 11, "principl": 11, "appear": 11, "error": 11, "messag": 11, "out": [11, 12, 14], "equival": 11, "you": [11, 15], "cover": 11, "later": 11, "particular": [11, 15], "what": 11, "look": 11, "perm": 11, "Or": 11, "merge2": 11, "split2": 11, "merge1": 11, "whose": 11, "sever": 11, "featur": [11, 12], "conveni": 11, "directli": [11, 12], "sinlg": 11, "far": [11, 14], "shown": 11, "closer": 11, "discuss": [11, 12], "document": [11, 12], "under": 11, "vll": [11, 12], "hd": [11, 12], "html": [11, 12], "dictionari": 11, "accept": 11, "could": [11, 12], "111": 11, "empti": 11, "rule": 11, "in2": 11, "42": 11, "split1": 11, "concat1": 11, "concat2": 11, "output1": 11, "output2": 11, "example_inn": 11, "dummi": 11, "x1": [11, 12], "x2": [11, 12], "z1": [11, 12], "z2": [11, 12], "log_jac_det": [11, 12, 13], "x1_inv": 11, "x2_inv": 11, "log_jac_det_inv": 11, "assert": [11, 14], "max": [11, 14], "1e": [11, 14, 15], "commonli": 12, "submodul": 12, "here": [12, 14], "_build": 12, "rnvp": 12, "merit": 12, "smaller": 12, "known": [12, 14], "perhap": 12, "hand": 12, "anew": 12, "modifi": 12, "rather": 12, "lot": 12, "sound": 12, "fc_constr": 12, "128": [12, 14, 15], "lead": 12, "much": 12, "enabl": 12, "larger": 12, "rate": 12, "therebi": 12, "good": 12, "place": 12, "variou": 12, "save": 12, "effort": 12, "written": 12, "extens": 12, "refer": [12, 14], "inform": [12, 13], "illustr": 12, "definit": 12, "either": 12, "second": 12, "swap": 12, "otherwis": 12, "calcul": [12, 13], "fixedrandomelementwisemultipli": 12, "self": 12, "super": 12, "random_factor": 12, "randint": 12, "conditionalswap": 12, "don": 12, "x1_new": 12, "x2_new": 12, "usag": [12, 13], "batchsiz": [12, 13], "log_jac_det_rev": 12, "input_1": 12, "input_2": 12, "mult_1": 12, "cond_swap": 12, "conditional_swap": 12, "mult_2": 12, "output_1": 12, "output_2": 12, "x1_rev": 12, "x2_rev": 12, "_": [12, 13], "jump": 13, "straight": 13, "moon": 13, "dataset": 13, "sklearn": 13, "make_moon": 13, "optim": 13, "adam": [13, 15], "lr": 13, "001": 13, "1000": 13, "zero_grad": 13, "label": [13, 14], "n_sampl": 13, "05": 13, "prior": 13, "updat": 13, "explicitli": 14, "write": 14, "784": 14, "cifar": 14, "simpli": 14, "obtain": 14, "x_inv": 14, "jac_inv": 14, "child": 14, "presuppos": 14, "next": 14, "now": 14, "imagin": 14, "hot": 14, "cond_dim": 14, "tell": 14, "one_hot_label": 14, "stochast": 15, "descent": 15, "clip": 15, "experienc": 15, "instabl": 15, "util": 15, "clip_grad_norm_": 15, "slight": 15, "spars": 15, "quantiz": 15, "correl": 15, "xavier": 15, "unstabl": 15, "your": 15, "deep": 15, "30": 15, "nan": 15, "forget": 15, "shallow": 15, "wide": 15, "neuron": 15, "64": 15, "conv": 15, "think": 15, "huge": 15, "being": 15, "said": 15, "roughli": 15, "too": 15, "break": 15, "oppos": 15, "kind": 15, "seem": 15, "qualiti": 15, "constitut": 15, "third": 15, "rationalquadraticsplin": 1, "elementwiserationalquadraticsplin": 1, "plot": [0, 2], "path": 0, "filenam": 0, "dot": 0, "file": 0, "directori": 0, "previou": 0, "newli": 0}, "objects": {"": [[2, 0, 0, "-", "FrEIA"]], "FrEIA": [[0, 0, 0, "-", "framework"], [1, 0, 0, "-", "modules"]], "FrEIA.framework": [[0, 1, 1, "", "ConditionNode"], [0, 1, 1, "", "GraphINN"], [0, 1, 1, "", "InputNode"], [0, 1, 1, "", "Node"], [0, 1, 1, "", "OutputNode"], [0, 1, 1, "", "ReversibleGraphNet"], [0, 1, 1, "", "ReversibleSequential"], [0, 1, 1, "", "SequenceINN"]], "FrEIA.framework.ConditionNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.GraphINN": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "get_module_by_name"], [0, 2, 1, "", "get_node_by_name"], [0, 2, 1, "", "log_jacobian_numerical"], [0, 2, 1, "", "plot"]], "FrEIA.framework.InputNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.Node": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"], [0, 2, 1, "", "parse_inputs"]], "FrEIA.framework.OutputNode": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "build_module"]], "FrEIA.framework.ReversibleGraphNet": [[0, 2, 1, "", "__init__"]], "FrEIA.framework.ReversibleSequential": [[0, 2, 1, "", "__init__"]], "FrEIA.framework.SequenceINN": [[0, 2, 1, "", "__init__"], [0, 2, 1, "", "append"]], "FrEIA.modules": [[1, 1, 1, "", "ActNorm"], [1, 1, 1, "", "AffineCouplingOneSided"], [1, 1, 1, "", "AllInOneBlock"], [1, 1, 1, "", "Concat"], [1, 1, 1, "", "ConditionalAffineTransform"], [1, 1, 1, "", "Fixed1x1Conv"], [1, 1, 1, "", "FixedLinearTransform"], [1, 1, 1, "", "Flatten"], [1, 1, 1, "", "GINCouplingBlock"], [1, 1, 1, "", "GLOWCouplingBlock"], [1, 1, 1, "", "GaussianMixtureModel"], [1, 1, 1, "", "HaarDownsampling"], [1, 1, 1, "", "HaarUpsampling"], [1, 1, 1, "", "HouseholderPerm"], [1, 1, 1, "", "IResNetLayer"], [1, 1, 1, "", "IRevNetDownsampling"], [1, 1, 1, "", "IRevNetUpsampling"], [1, 1, 1, "", "InvAutoAct"], [1, 1, 1, "", "InvAutoActTwoSided"], [1, 1, 1, "", "InvAutoConv2D"], [1, 1, 1, "", "InvAutoFC"], [1, 1, 1, "", "InvertibleModule"], [1, 1, 1, "", "InvertibleSigmoid"], [1, 1, 1, "", "LearnedElementwiseScaling"], [1, 1, 1, "", "NICECouplingBlock"], [1, 1, 1, "", "OrthogonalTransform"], [1, 1, 1, "", "PermuteRandom"], [1, 1, 1, "", "RNVPCouplingBlock"], [1, 1, 1, "", "Reshape"], [1, 1, 1, "", "Split"]], "FrEIA.modules.ActNorm": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "initialize"], [1, 2, 1, "", "load_state_dict"], [1, 3, 1, "", "scale"]], "FrEIA.modules.AffineCouplingOneSided": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.AllInOneBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Concat": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.ConditionalAffineTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Fixed1x1Conv": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.FixedLinearTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Flatten": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GINCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GLOWCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.GaussianMixtureModel": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "nll_loss"], [1, 2, 1, "", "nll_upper_bound"], [1, 2, 1, "", "normalize_weights"], [1, 2, 1, "", "pick_mixture_component"]], "FrEIA.modules.HaarDownsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.HaarUpsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.HouseholderPerm": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.IResNetLayer": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "lipschitz_correction"]], "FrEIA.modules.IRevNetDownsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.IRevNetUpsampling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoAct": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoActTwoSided": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoConv2D": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvAutoFC": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.InvertibleModule": [[1, 2, 1, "", "__init__"], [1, 2, 1, "", "forward"], [1, 2, 1, "", "log_jacobian"], [1, 2, 1, "", "output_dims"]], "FrEIA.modules.InvertibleSigmoid": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.LearnedElementwiseScaling": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.NICECouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.OrthogonalTransform": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.PermuteRandom": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.RNVPCouplingBlock": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Reshape": [[1, 2, 1, "", "__init__"]], "FrEIA.modules.Split": [[1, 2, 1, "", "__init__"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:property"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "property", "Python property"]}, "titleterms": {"freia": [0, 1, 2], "framework": 0, "packag": [0, 1, 2], "modul": [0, 1], "content": 0, "abstract": 1, "templat": 1, "coupl": [1, 12], "block": [1, 12], "reshap": 1, "graph": [1, 11], "topologi": 1, "other": 1, "learn": 1, "transform": 1, "fix": 1, "non": 1, "approxim": 1, "semi": 1, "invert": [1, 6, 8, 12], "welcom": 2, "": 2, "document": 2, "tutori": [2, 16], "indic": 2, "tabl": 2, "basic": 3, "concept": 3, "exampl": 4, "network": 4, "architectur": 4, "full": 4, "train": [4, 9, 10], "loop": 4, "bay": 5, "flow": [5, 9, 10], "convolut": 6, "inn": [6, 7], "downsampl": 6, "small": 7, "fulli": 7, "connect": 7, "simpl": 7, "2": 7, "dimens": 7, "condit": [7, 9], "mnist": [7, 9], "u": 8, "net": 8, "normal": [9, 10], "celeba": 10, "comput": 11, "api": [11, 14], "oper": 12, "defin": 12, "custom": 12, "quickstart": 13, "guid": 13, "sequenti": 14, "tip": 15, "trick": 15, "faq": 15}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"FrEIA.framework package": [[0, "freia-framework-package"]], "Module contents": [[0, "module-FrEIA.framework"]], "FrEIA.modules package": [[1, "module-FrEIA.modules"]], "Abstract template": [[1, "abstract-template"]], "Coupling blocks": [[1, "coupling-blocks"], [12, "coupling-blocks"]], "Reshaping": [[1, "reshaping"]], "Graph topology": [[1, "graph-topology"]], "Other learned transforms": [[1, "other-learned-transforms"]], "Fixed (non-learned) transforms": [[1, "fixed-non-learned-transforms"]], "Approximately- or semi-invertible transforms": [[1, "approximately-or-semi-invertible-transforms"]], "Welcome to FrEIA\u2019s documentation!": [[2, "module-FrEIA"]], "Tutorial": [[2, "tutorial"], [16, "tutorial"]], "Package Documentation": [[2, "package-documentation"]], "Indices and tables": [[2, "indices-and-tables"]], "Basic concepts": [[3, "basic-concepts"]], "Examples": [[4, "examples"]], "Network architectures": [[4, "network-architectures"]], "Full training loops": [[4, "full-training-loops"]], "Bayes-flow": [[5, "bayes-flow"]], "Convolutional INN with invertible downsampling": [[6, "convolutional-inn-with-invertible-downsampling"]], "Small fully-connected INNs": [[7, "small-fully-connected-inns"]], "Simple INN in 2 dimensions": [[7, "simple-inn-in-2-dimensions"]], "Conditional INN for MNIST": [[7, "conditional-inn-for-mnist"]], "Invertible U-Net": [[8, "invertible-u-net"]], "Training: MNIST conditional normalizing flow": [[9, "training-mnist-conditional-normalizing-flow"]], "Training: CelebA normalizing flow": [[10, "training-celeba-normalizing-flow"]], "Computation graph API": [[11, "computation-graph-api"]], "Invertible Operations": [[12, "invertible-operations"]], "Defining custom invertible operations": [[12, "defining-custom-invertible-operations"]], "Quickstart guide": [[13, "quickstart-guide"]], "Sequential API": [[14, "sequential-api"]], "Tips & Tricks, FAQ": [[15, "tips-tricks-faq"]]}, "indexentries": {"conditionnode (class in freia.framework)": [[0, "FrEIA.framework.ConditionNode"]], "freia.framework": [[0, "module-FrEIA.framework"]], "graphinn (class in freia.framework)": [[0, "FrEIA.framework.GraphINN"]], "inputnode (class in freia.framework)": [[0, "FrEIA.framework.InputNode"]], "node (class in freia.framework)": [[0, "FrEIA.framework.Node"]], "outputnode (class in freia.framework)": [[0, "FrEIA.framework.OutputNode"]], "reversiblegraphnet (class in freia.framework)": [[0, "FrEIA.framework.ReversibleGraphNet"]], "reversiblesequential (class in freia.framework)": [[0, "FrEIA.framework.ReversibleSequential"]], "sequenceinn (class in freia.framework)": [[0, "FrEIA.framework.SequenceINN"]], "__init__() (freia.framework.conditionnode method)": [[0, "FrEIA.framework.ConditionNode.__init__"]], "__init__() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.__init__"]], "__init__() (freia.framework.inputnode method)": [[0, "FrEIA.framework.InputNode.__init__"]], "__init__() (freia.framework.node method)": [[0, "FrEIA.framework.Node.__init__"]], "__init__() (freia.framework.outputnode method)": [[0, "FrEIA.framework.OutputNode.__init__"]], "__init__() (freia.framework.reversiblegraphnet method)": [[0, "FrEIA.framework.ReversibleGraphNet.__init__"]], "__init__() (freia.framework.reversiblesequential method)": [[0, "FrEIA.framework.ReversibleSequential.__init__"]], "__init__() (freia.framework.sequenceinn method)": [[0, "FrEIA.framework.SequenceINN.__init__"]], "append() (freia.framework.sequenceinn method)": [[0, "FrEIA.framework.SequenceINN.append"]], "build_module() (freia.framework.conditionnode method)": [[0, "FrEIA.framework.ConditionNode.build_module"]], "build_module() (freia.framework.inputnode method)": [[0, "FrEIA.framework.InputNode.build_module"]], "build_module() (freia.framework.node method)": [[0, "FrEIA.framework.Node.build_module"]], "build_module() (freia.framework.outputnode method)": [[0, "FrEIA.framework.OutputNode.build_module"]], "get_module_by_name() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.get_module_by_name"]], "get_node_by_name() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.get_node_by_name"]], "log_jacobian_numerical() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.log_jacobian_numerical"]], "module": [[0, "module-FrEIA.framework"], [1, "module-FrEIA.modules"], [2, "module-FrEIA"]], "parse_inputs() (freia.framework.node method)": [[0, "FrEIA.framework.Node.parse_inputs"]], "plot() (freia.framework.graphinn method)": [[0, "FrEIA.framework.GraphINN.plot"]], "actnorm (class in freia.modules)": [[1, "FrEIA.modules.ActNorm"]], "affinecouplingonesided (class in freia.modules)": [[1, "FrEIA.modules.AffineCouplingOneSided"]], "allinoneblock (class in freia.modules)": [[1, "FrEIA.modules.AllInOneBlock"]], "concat (class in freia.modules)": [[1, "FrEIA.modules.Concat"]], "conditionalaffinetransform (class in freia.modules)": [[1, "FrEIA.modules.ConditionalAffineTransform"]], "fixed1x1conv (class in freia.modules)": [[1, "FrEIA.modules.Fixed1x1Conv"]], "fixedlineartransform (class in freia.modules)": [[1, "FrEIA.modules.FixedLinearTransform"]], "flatten (class in freia.modules)": [[1, "FrEIA.modules.Flatten"]], "freia.modules": [[1, "module-FrEIA.modules"]], "gincouplingblock (class in freia.modules)": [[1, "FrEIA.modules.GINCouplingBlock"]], "glowcouplingblock (class in freia.modules)": [[1, "FrEIA.modules.GLOWCouplingBlock"]], "gaussianmixturemodel (class in freia.modules)": [[1, "FrEIA.modules.GaussianMixtureModel"]], "haardownsampling (class in freia.modules)": [[1, "FrEIA.modules.HaarDownsampling"]], "haarupsampling (class in freia.modules)": [[1, "FrEIA.modules.HaarUpsampling"]], "householderperm (class in freia.modules)": [[1, "FrEIA.modules.HouseholderPerm"]], "iresnetlayer (class in freia.modules)": [[1, "FrEIA.modules.IResNetLayer"]], "irevnetdownsampling (class in freia.modules)": [[1, "FrEIA.modules.IRevNetDownsampling"]], "irevnetupsampling (class in freia.modules)": [[1, "FrEIA.modules.IRevNetUpsampling"]], "invautoact (class in freia.modules)": [[1, "FrEIA.modules.InvAutoAct"]], "invautoacttwosided (class in freia.modules)": [[1, "FrEIA.modules.InvAutoActTwoSided"]], "invautoconv2d (class in freia.modules)": [[1, "FrEIA.modules.InvAutoConv2D"]], "invautofc (class in freia.modules)": [[1, "FrEIA.modules.InvAutoFC"]], "invertiblemodule (class in freia.modules)": [[1, "FrEIA.modules.InvertibleModule"]], "invertiblesigmoid (class in freia.modules)": [[1, "FrEIA.modules.InvertibleSigmoid"]], "learnedelementwisescaling (class in freia.modules)": [[1, "FrEIA.modules.LearnedElementwiseScaling"]], "nicecouplingblock (class in freia.modules)": [[1, "FrEIA.modules.NICECouplingBlock"]], "orthogonaltransform (class in freia.modules)": [[1, "FrEIA.modules.OrthogonalTransform"]], "permuterandom (class in freia.modules)": [[1, "FrEIA.modules.PermuteRandom"]], "rnvpcouplingblock (class in freia.modules)": [[1, "FrEIA.modules.RNVPCouplingBlock"]], "reshape (class in freia.modules)": [[1, "FrEIA.modules.Reshape"]], "split (class in freia.modules)": [[1, "FrEIA.modules.Split"]], "__init__() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.__init__"]], "__init__() (freia.modules.affinecouplingonesided method)": [[1, "FrEIA.modules.AffineCouplingOneSided.__init__"]], "__init__() (freia.modules.allinoneblock method)": [[1, "FrEIA.modules.AllInOneBlock.__init__"]], "__init__() (freia.modules.concat method)": [[1, "FrEIA.modules.Concat.__init__"]], "__init__() (freia.modules.conditionalaffinetransform method)": [[1, "FrEIA.modules.ConditionalAffineTransform.__init__"]], "__init__() (freia.modules.fixed1x1conv method)": [[1, "FrEIA.modules.Fixed1x1Conv.__init__"]], "__init__() (freia.modules.fixedlineartransform method)": [[1, "FrEIA.modules.FixedLinearTransform.__init__"]], "__init__() (freia.modules.flatten method)": [[1, "FrEIA.modules.Flatten.__init__"]], "__init__() (freia.modules.gincouplingblock method)": [[1, "FrEIA.modules.GINCouplingBlock.__init__"]], "__init__() (freia.modules.glowcouplingblock method)": [[1, "FrEIA.modules.GLOWCouplingBlock.__init__"]], "__init__() (freia.modules.gaussianmixturemodel method)": [[1, "FrEIA.modules.GaussianMixtureModel.__init__"]], "__init__() (freia.modules.haardownsampling method)": [[1, "FrEIA.modules.HaarDownsampling.__init__"]], "__init__() (freia.modules.haarupsampling method)": [[1, "FrEIA.modules.HaarUpsampling.__init__"]], "__init__() (freia.modules.householderperm method)": [[1, "FrEIA.modules.HouseholderPerm.__init__"]], "__init__() (freia.modules.iresnetlayer method)": [[1, "FrEIA.modules.IResNetLayer.__init__"]], "__init__() (freia.modules.irevnetdownsampling method)": [[1, "FrEIA.modules.IRevNetDownsampling.__init__"]], "__init__() (freia.modules.irevnetupsampling method)": [[1, "FrEIA.modules.IRevNetUpsampling.__init__"]], "__init__() (freia.modules.invautoact method)": [[1, "FrEIA.modules.InvAutoAct.__init__"]], "__init__() (freia.modules.invautoacttwosided method)": [[1, "FrEIA.modules.InvAutoActTwoSided.__init__"]], "__init__() (freia.modules.invautoconv2d method)": [[1, "FrEIA.modules.InvAutoConv2D.__init__"]], "__init__() (freia.modules.invautofc method)": [[1, "FrEIA.modules.InvAutoFC.__init__"]], "__init__() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.__init__"]], "__init__() (freia.modules.invertiblesigmoid method)": [[1, "FrEIA.modules.InvertibleSigmoid.__init__"]], "__init__() (freia.modules.learnedelementwisescaling method)": [[1, "FrEIA.modules.LearnedElementwiseScaling.__init__"]], "__init__() (freia.modules.nicecouplingblock method)": [[1, "FrEIA.modules.NICECouplingBlock.__init__"]], "__init__() (freia.modules.orthogonaltransform method)": [[1, "FrEIA.modules.OrthogonalTransform.__init__"]], "__init__() (freia.modules.permuterandom method)": [[1, "FrEIA.modules.PermuteRandom.__init__"]], "__init__() (freia.modules.rnvpcouplingblock method)": [[1, "FrEIA.modules.RNVPCouplingBlock.__init__"]], "__init__() (freia.modules.reshape method)": [[1, "FrEIA.modules.Reshape.__init__"]], "__init__() (freia.modules.split method)": [[1, "FrEIA.modules.Split.__init__"]], "forward() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.forward"]], "initialize() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.initialize"]], "lipschitz_correction() (freia.modules.iresnetlayer method)": [[1, "FrEIA.modules.IResNetLayer.lipschitz_correction"]], "load_state_dict() (freia.modules.actnorm method)": [[1, "FrEIA.modules.ActNorm.load_state_dict"]], "log_jacobian() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.log_jacobian"]], "nll_loss() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.nll_loss"]], "nll_upper_bound() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.nll_upper_bound"]], "normalize_weights() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.normalize_weights"]], "output_dims() (freia.modules.invertiblemodule method)": [[1, "FrEIA.modules.InvertibleModule.output_dims"]], "pick_mixture_component() (freia.modules.gaussianmixturemodel static method)": [[1, "FrEIA.modules.GaussianMixtureModel.pick_mixture_component"]], "scale (freia.modules.actnorm property)": [[1, "FrEIA.modules.ActNorm.scale"]], "freia": [[2, "module-FrEIA"]]}}) \ No newline at end of file diff --git a/docs/_build/html/tutorial/basic_concepts.html b/docs/_build/html/tutorial/basic_concepts.html index bb9d922..dcd5ae9 100644 --- a/docs/_build/html/tutorial/basic_concepts.html +++ b/docs/_build/html/tutorial/basic_concepts.html @@ -477,7 +477,7 @@

      Basic concepts diff --git a/docs/_build/html/tutorial/examples.html b/docs/_build/html/tutorial/examples.html index 3af6b78..49eca8b 100644 --- a/docs/_build/html/tutorial/examples.html +++ b/docs/_build/html/tutorial/examples.html @@ -485,7 +485,7 @@

      Full training loops diff --git a/docs/_build/html/tutorial/examples/bayes_flow.html b/docs/_build/html/tutorial/examples/bayes_flow.html index a9e23b3..5888658 100644 --- a/docs/_build/html/tutorial/examples/bayes_flow.html +++ b/docs/_build/html/tutorial/examples/bayes_flow.html @@ -453,7 +453,7 @@

      Bayes-flow diff --git a/docs/_build/html/tutorial/examples/convolutional.html b/docs/_build/html/tutorial/examples/convolutional.html index b3bd17d..248fa7a 100644 --- a/docs/_build/html/tutorial/examples/convolutional.html +++ b/docs/_build/html/tutorial/examples/convolutional.html @@ -520,7 +520,7 @@

      Convolutional INN with invertible downsampling diff --git a/docs/_build/html/tutorial/examples/fully_connected.html b/docs/_build/html/tutorial/examples/fully_connected.html index f552bd9..1d23cd2 100644 --- a/docs/_build/html/tutorial/examples/fully_connected.html +++ b/docs/_build/html/tutorial/examples/fully_connected.html @@ -507,7 +507,7 @@

      Conditional INN for MNIST diff --git a/docs/_build/html/tutorial/examples/inv_unet.html b/docs/_build/html/tutorial/examples/inv_unet.html index 8badcf9..583f322 100644 --- a/docs/_build/html/tutorial/examples/inv_unet.html +++ b/docs/_build/html/tutorial/examples/inv_unet.html @@ -453,7 +453,7 @@

      Invertible U-Net diff --git a/docs/_build/html/tutorial/examples/training_loop_cinn.html b/docs/_build/html/tutorial/examples/training_loop_cinn.html index ac7e88e..4af48d8 100644 --- a/docs/_build/html/tutorial/examples/training_loop_cinn.html +++ b/docs/_build/html/tutorial/examples/training_loop_cinn.html @@ -453,7 +453,7 @@

      Training: MNIST conditional normalizing flow diff --git a/docs/_build/html/tutorial/examples/training_loop_inn.html b/docs/_build/html/tutorial/examples/training_loop_inn.html index f677f9f..6408107 100644 --- a/docs/_build/html/tutorial/examples/training_loop_inn.html +++ b/docs/_build/html/tutorial/examples/training_loop_inn.html @@ -453,7 +453,7 @@

      Training: CelebA normalizing flow diff --git a/docs/_build/html/tutorial/graph_inns.html b/docs/_build/html/tutorial/graph_inns.html index ac0d878..b0598ef 100644 --- a/docs/_build/html/tutorial/graph_inns.html +++ b/docs/_build/html/tutorial/graph_inns.html @@ -565,7 +565,7 @@

      Computation graph API diff --git a/docs/_build/html/tutorial/invertible_operations.html b/docs/_build/html/tutorial/invertible_operations.html index 0e80d92..91209c7 100644 --- a/docs/_build/html/tutorial/invertible_operations.html +++ b/docs/_build/html/tutorial/invertible_operations.html @@ -626,7 +626,7 @@

      Defining custom invertible operations diff --git a/docs/_build/html/tutorial/quickstart.html b/docs/_build/html/tutorial/quickstart.html index 764dbef..f07ed0b 100644 --- a/docs/_build/html/tutorial/quickstart.html +++ b/docs/_build/html/tutorial/quickstart.html @@ -482,7 +482,7 @@

      Quickstart guide diff --git a/docs/_build/html/tutorial/sequential_inns.html b/docs/_build/html/tutorial/sequential_inns.html index 7fc1c90..04ef646 100644 --- a/docs/_build/html/tutorial/sequential_inns.html +++ b/docs/_build/html/tutorial/sequential_inns.html @@ -515,7 +515,7 @@

      Sequential API diff --git a/docs/_build/html/tutorial/tips_tricks_faq.html b/docs/_build/html/tutorial/tips_tricks_faq.html index 6ac1ec4..161c027 100644 --- a/docs/_build/html/tutorial/tips_tricks_faq.html +++ b/docs/_build/html/tutorial/tips_tricks_faq.html @@ -475,7 +475,7 @@

      Tips & Tricks, FAQ diff --git a/docs/_build/html/tutorial/tutorial.html b/docs/_build/html/tutorial/tutorial.html index 1338f61..76d571e 100644 --- a/docs/_build/html/tutorial/tutorial.html +++ b/docs/_build/html/tutorial/tutorial.html @@ -460,7 +460,7 @@

      Tutorial diff --git a/requirements.txt b/requirements.txt index 278ebd5..d0806d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch>=1.0.0 numpy>=1.15.0 scipy>=1.5 +graphviz>=0.20.1 diff --git a/setup.py b/setup.py index 20f850e..528d760 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ ## dependencies packages=find_packages(), - install_requires=['numpy>=1.15.0','scipy>=1.5', 'torch>=1.0.0'], + install_requires=['numpy>=1.15.0','scipy>=1.5', 'torch>=1.0.0', 'graphviz>=0.20.1'], # extras_require={ # 'testruns': ['pytest', 'pytest-benchmark'], # }, diff --git a/tests/test_graph_inn.py b/tests/test_graph_inn.py index 715099c..db5f86e 100644 --- a/tests/test_graph_inn.py +++ b/tests/test_graph_inn.py @@ -1,10 +1,39 @@ import unittest import torch +import torch.nn as nn -from FrEIA.framework import GraphINN, InputNode, Node, OutputNode -from FrEIA.modules import AllInOneBlock +from FrEIA.framework import GraphINN, InputNode, Node, OutputNode, ConditionNode +from FrEIA.modules import AllInOneBlock, Split, Reshape, Flatten, RNVPCouplingBlock, PermuteRandom, HaarDownsampling, Concat +import os + +import graphviz + +# the reason the subnet init is needed, is that with uninitalized +# weights, the numerical jacobian check gives inf, nan, etc, +def subnet_initialization(m): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + m.weight.data *= 0.3 + m.bias.data *= 0.1 + +def F_conv(cin, cout): + '''Simple convolutional subnetwork''' + net = nn.Sequential(nn.Conv2d(cin, 32, 3, padding=1), + nn.ReLU(), + nn.Conv2d(32, cout, 3, padding=1)) + net.apply(subnet_initialization) + return net + + +def F_fully_connected(cin, cout): + '''Simple fully connected subnetwork''' + net = nn.Sequential(nn.Linear(cin, 128), + nn.ReLU(), + nn.Linear(128, cout)) + net.apply(subnet_initialization) + return net class GraphINNTest(unittest.TestCase): def test_existing_module(self): @@ -19,3 +48,133 @@ def test_existing_module(self): batch_size = 16 out, jac = graph_inn(torch.randn(batch_size, dim)) assert out.shape == (batch_size, dim) + +# def has_graphviz_backend(): +# plotdir = 'test_plots_graphviz' +# plot_name = 'graph' +# file_path = os.path.join(plotdir, plot_name) + +# in_node = InputNode(3, 10, 10) +# out_node = OutputNode(in_node) +# graph = GraphINN([in_node, out_node]) +# try: +# os.mkdir(plotdir) +# graph.plot(path=plotdir, filename=plot_name) +# except Exception: +# if os.path.exists(file_path): +# os.remove(file_path) +# if os.path.exists(file_path + ".pdf"): +# os.remove(file_path + ".pdf") +# os.rmdir(plotdir) +# return False + +# file_path = os.path.join(plotdir, plot_name) +# os.rmdir(plotdir) +# return True + +class PlotGraphINNTest(unittest.TestCase): + plotdir = os.path.join(os.getcwd(),"graphINN_test_plots") + plot_name = "graph" + file_path = os.path.join(plotdir, plot_name) + has_graphviz_backend = True + + def cleanup_files(self): + if os.path.exists(self.file_path): + os.remove(self.file_path) + if os.path.exists(self.file_path + ".pdf"): + os.remove(self.file_path + ".pdf") + + @classmethod + def setUpClass(self): + os.mkdir(self.plotdir) + + in_node = InputNode(3, 10, 10) + out_node = OutputNode(in_node) + graph = GraphINN([in_node, out_node]) + try: + graph.plot(path=self.plotdir, filename=self.plot_name) + except Exception: + self.cleanup_files(self) + self.has_graphviz_backend = False + self.cleanup_files(self) + + @classmethod + def tearDownClass(self) -> None: + os.rmdir(self.plotdir) + + def setUp(self): + if not self.has_graphviz_backend: + self.skipTest('Skipped testing graph plots since graphviz backend is not installed.') + + def tearDown(self) -> None: + self.cleanup_files() + + def test_input_output_graph(self): + in_node = InputNode(3, 10, 10) + out_node = OutputNode(in_node) + graph = GraphINN([in_node, out_node]) + graph.plot(path=self.plotdir, filename=self.plot_name) + + + self.assertTrue(os.path.exists(self.file_path)) + self.assertTrue(os.path.exists(self.file_path + ".pdf")) + + def test_raises_non_existing_path(self): + in_node = InputNode(3, 10, 10) + out_node = OutputNode(in_node) + graph = GraphINN([in_node, out_node]) + + self.assertRaises(Exception, graph.plot, "not_existing_path", self.plot_name) + + def test_one_layer_graph(self): + nodes = [] + dim = 3 + nodes.append(InputNode(dim)) + nodes.append(Node(nodes[-1], AllInOneBlock(nodes[-1].output_dims, subnet_constructor=torch.nn.Linear))) + nodes.append(OutputNode(nodes[-1])) + graph = GraphINN(nodes) + graph.plot(path=self.plotdir, filename=self.plot_name) + + self.assertTrue(os.path.exists(self.file_path)) + self.assertTrue(os.path.exists(self.file_path + ".pdf")) + + def test_complex_graph(self): + inp_size = (3, 10, 10) + cond_size = (1, 10, 10) + + inp = InputNode(*inp_size, name='input') + cond = ConditionNode(*cond_size, name='cond') + split = Node(inp, Split, {'section_sizes': [1,2], 'dim': 0}, name='split1') + + flatten1 = Node(split.out0, Flatten, {}, name='flatten1') + perm = Node(flatten1, PermuteRandom, {'seed': 123}, name='perm') + unflatten1 = Node(perm, Reshape, {'output_dims': (1, 10, 10)}, name='unflatten1') + + conv = Node(split.out1, + RNVPCouplingBlock, + {'subnet_constructor': F_conv, 'clamp': 1.0}, + conditions=cond, + name='conv') + + flatten2 = Node(conv, Flatten, {}, name='flatten2') + + linear = Node(flatten2, + RNVPCouplingBlock, + {'subnet_constructor': F_fully_connected, 'clamp': 1.0}, + name='linear') + + unflatten2 = Node(linear, Reshape, {'output_dims': (2, 10, 10)}, name='unflatten2') + concat = Node([unflatten1.out0, unflatten2.out0], Concat, {'dim': 0}, name='concat') + haar = Node(concat, HaarDownsampling, {}, name='haar') + out = OutputNode(haar, name='output') + + graph = GraphINN([inp, cond, split, flatten1, perm, unflatten1, conv, flatten2, linear, unflatten2, concat, haar, out]) + graph.plot(path=self.plotdir, filename=self.plot_name) + + + self.assertTrue(os.path.exists(self.file_path)) + self.assertTrue(os.path.exists(self.file_path + ".pdf")) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_splines.py b/tests/test_splines.py index 2c9a691..d440a75 100644 --- a/tests/test_splines.py +++ b/tests/test_splines.py @@ -82,6 +82,8 @@ class TestUnconditionalCoupling: fm.GINCouplingBlock, fm.AffineCouplingOneSided, fm.RationalQuadraticSpline, + fm.ElementwiseRationalQuadraticSpline, + fm.LinearSpline, ] scenarios = [ ((32, 3), "dense", [16, 32, 16], dict()), @@ -105,6 +107,50 @@ def test_forward_backward(self): reconstruction, backward_logdet = inn(latent, rev=True) assert reconstruction.shape == sample_data.shape assert backward_logdet.dim() == 1 + assert torch.allclose(sample_data, reconstruction, atol=1e-5, + rtol=1e-3), f"MSE: {F.mse_loss(sample_data, reconstruction)}" + assert torch.allclose(forward_logdet, -backward_logdet, atol=1e-5, + rtol=1e-3), f"MSE: {F.mse_loss(forward_logdet, -backward_logdet)}" + +class TestConditionalCoupling: + + couplings = [ + fm.AllInOneBlock, + fm.NICECouplingBlock, + fm.RNVPCouplingBlock, + fm.GLOWCouplingBlock, + fm.GINCouplingBlock, + fm.AffineCouplingOneSided, + fm.RationalQuadraticSpline, + fm.ElementwiseRationalQuadraticSpline, + fm.LinearSpline, + ] + scenarios = [ + ((32, 3), (32, 4), "dense", [16, 32, 16], dict()), + ((8, 3, 8, 8), (8, 4, 8, 8), "conv", [8, 8, 8], dict(kernel_size=1)), + ((8, 3, 8, 8), (8, 4, 8, 8), "conv", [4, 6, 4], dict(kernel_size=3)), + ] + + def test_forward_backward(self): + for coupling_type in self.couplings: + for batch_shape, condition_shape, network_kind, network_widths, network_kwargs in self.scenarios: + subnet_constructor = SubnetFactory(kind=network_kind, widths=network_widths, **network_kwargs) + inn = ff.SequenceINN(*batch_shape[1:]) + inn.append(coupling_type, + subnet_constructor=subnet_constructor, + cond_shape=condition_shape[1:], + cond=0) + + sample_data = torch.randn(*batch_shape) + sample_cond = torch.randn(*condition_shape) + + latent, forward_logdet = inn(sample_data, (sample_cond,)) + assert latent.shape == sample_data.shape + assert forward_logdet.dim() == 1 + + reconstruction, backward_logdet = inn(latent, (sample_cond,), rev=True) + assert reconstruction.shape == sample_data.shape + assert backward_logdet.dim() == 1 assert torch.allclose(sample_data, reconstruction, atol=1e-5, rtol=1e-3), f"MSE: {F.mse_loss(sample_data, reconstruction)}"