# NOTE(review): SOURCE is a git diff mangled onto single physical lines; this
# body reconstructs the files introduced on this line as properly formatted
# Python. File boundaries are marked below.

# --------------------------------------------------------------------------- #
# FrEIA/core/__init__.py
# --------------------------------------------------------------------------- #
from .invertible import Invertible


# --------------------------------------------------------------------------- #
# FrEIA/core/invertible.py
# --------------------------------------------------------------------------- #
from abc import ABC
from typing import Any

import torch.nn as nn


class Invertible(ABC, nn.Module):
    """Base class for modules with both a forward and an inverse pass.

    Subclasses implement ``forward`` and ``inverse``; callers pick the
    direction through ``__call__(..., rev=...)``.
    """

    # FIX: the original annotated *args/**kwargs with an unconstrained
    # TypeVar T, which conveys nothing; Any is the honest type here.
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Forward transformation. Must be overridden by subclasses."""
        raise NotImplementedError

    def inverse(self, *args: Any, **kwargs: Any) -> Any:
        """Inverse transformation. Must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, *args: Any, rev: bool = False, **kwargs: Any) -> Any:
        """Dispatch to ``forward`` (rev=False) or ``inverse`` (rev=True).

        NOTE(review): overriding ``__call__`` bypasses ``nn.Module.__call__``
        and therefore all forward hooks -- confirm this is intentional.
        """
        if rev:
            return self.inverse(*args, **kwargs)
        return self.forward(*args, **kwargs)


# --------------------------------------------------------------------------- #
# FrEIA/flows/__init__.py  (intentionally empty)
# --------------------------------------------------------------------------- #


# --------------------------------------------------------------------------- #
# FrEIA/flows/base.py
# --------------------------------------------------------------------------- #
# FIX: the diff imported from "freia.core" (wrong case); the sibling modules
# (FrEIA/splits/base.py) use the canonical package name "FrEIA".
from FrEIA.core import Invertible


class Flow(Invertible):
    """A normalizing flow: an invertible transform plus a base distribution."""

    def __init__(self, transform, distribution):
        # FIX: nn.Module subclasses must call super().__init__() before any
        # attribute (especially submodule) assignment.
        super().__init__()
        self.transform = transform
        self.distribution = distribution

    def forward(self, x):
        """Map data ``x`` to latent ``z``.

        Returns:
            (z, nll): latent codes and the per-sample negative log-likelihood
            ``-(log p(z) + log|det J|)``.
        """
        z, logdet = self.transform.forward(x)
        logp = self.distribution.log_prob(z)
        nll = -(logp + logdet)
        return z, nll

    def sample_transform(self, size, temperature):
        """Draw latent samples and map them back through the inverse transform."""
        z = self.distribution.sample(size, temperature)
        x, _ = self.transform.inverse(z)
        return x


class RecurrentFlow(Flow):
    """Flow whose transform is applied repeatedly over time steps.

    TODO(review): the original body was an unfinished sketch -- it iterated
    ``range(...)`` (a TypeError at runtime) and returned nothing:

        z = x
        logdet = None
        for t in range(...):
            z, logdet = self.transform.forward(z, t)

    Failing loudly until the number of steps is defined.
    """

    def forward(self, x):
        raise NotImplementedError("RecurrentFlow.forward is an unfinished sketch")


# --------------------------------------------------------------------------- #
# FrEIA/splits/__init__.py
# --------------------------------------------------------------------------- #
from .even import EvenSplit
# NOTE(review): reconstruction of the remaining files from the mangled diff,
# properly formatted, with the concrete defects fixed and unfinished sketches
# clearly marked. File boundaries are marked below.

# --------------------------------------------------------------------------- #
# FrEIA/splits/base.py
# --------------------------------------------------------------------------- #
from typing import Tuple

import torch

from FrEIA.core import Invertible


class Split(Invertible):
    """Interface for splitting a tensor into two parts (and merging back)."""

    def __init__(self, dim: int = 1):
        super().__init__()
        self.dim = dim  # axis along which to split / merge

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


# --------------------------------------------------------------------------- #
# FrEIA/splits/even.py
# --------------------------------------------------------------------------- #
from typing import Tuple

import torch

from .base import Split


class EvenSplit(Split):
    """Split a tensor into two equal halves along ``self.dim``."""

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # FIX: torch.split(x, 2, dim=1) yields chunks of SIZE 2 (arbitrarily
        # many of them), not two halves, and ignored self.dim. torch.chunk
        # produces exactly two pieces along the configured axis.
        x1, x2 = torch.chunk(x, 2, dim=self.dim)
        return x1, x2

    def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # FIX: concatenate along the same axis the split used (was hard-coded dim=1).
        return torch.cat((x1, x2), dim=self.dim)


# --------------------------------------------------------------------------- #
# FrEIA/splits/random.py
# --------------------------------------------------------------------------- #
# TODO: class RandomSplit -- not implemented yet.


# --------------------------------------------------------------------------- #
# FrEIA/transforms/__init__.py
# --------------------------------------------------------------------------- #
from .base import Transform


# --------------------------------------------------------------------------- #
# FrEIA/transforms/base.py
# --------------------------------------------------------------------------- #
import torch

# FIX: was "from freia.core import ..." (wrong package case).
from FrEIA.core import Invertible

# A transform's result: the transformed tensor and its log|det Jacobian|.
WithJacobian = tuple[torch.Tensor, torch.Tensor]


class Transform(Invertible):
    """Interface for invertible tensor transforms."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def inverse(self, z: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class Parameterized:
    """Class decorator declaring how many parameters a transform expects.

    FIX: moved above its first use -- the decorator was applied to
    AffineTransform before being defined (NameError at import).
    TODO(review): the original __call__ assigned undefined names
    (``cls.forward = forward``) and returned None, which would have replaced
    the decorated class with None. Until the wrapping logic exists, record the
    counts and return the class unchanged.
    """

    def __init__(self, **parameter_counts):
        self.parameter_counts = parameter_counts

    def __call__(self, cls):
        cls.parameter_counts = self.parameter_counts
        return cls


@Parameterized(scale=1, shift=1)
class AffineTransform(Transform):
    """Elementwise affine map z = scale * x + shift."""

    def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        return scale * x + shift

    def inverse(self, z: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        return (z - shift) / scale


# class SplineTransform(Transform):
#     def forward(self, x: torch.Tensor, edges: torch.Tensor):
#         assert edges.shape == (..., self.bins)
#         pass


# --------------------------------------------------------------------------- #
# FrEIA/transforms/coupling.py
# --------------------------------------------------------------------------- #
import warnings  # FIX: warnings.warn was used but never imported

import torch
import torch.nn as nn

from .base import Transform
# FIX: was "from freia.splits import ..." (wrong package case).
from FrEIA.splits import EvenSplit


class CouplingTransform(Transform):
    """Coupling layer: each half of the input is transformed with parameters
    predicted by a subnet from the *other* half, keeping the layer invertible.

    FIX: moved above its subclasses -- the original file defined two sketch
    classes (both named ``Spline``!) that referenced CouplingTransform before
    this definition, raising NameError at import; those sketches are preserved
    as a TODO comment at the end of this file.
    """

    def __init__(self, transform1, transform2, subnet_constructor, split=None):
        # FIX: super().__init__() was missing (nn.Module requirement).
        super().__init__()
        # FIX: split=EvenSplit(dim=1) as a *default argument* is a single
        # module instance shared by every CouplingTransform; build per-instance.
        self.split = EvenSplit(dim=1) if split is None else split
        # FIX: transform1/transform2 were accepted but never stored, although
        # forward() reads self.transform1/self.transform2 (AttributeError).
        self.transform1 = transform1
        self.transform2 = transform2
        # TODO(review): subnet in/out sizes were "..." placeholders in the diff;
        # they presumably derive from the split sizes and parameter_counts.
        self.subnet1 = subnet_constructor(...)
        self.subnet2 = subnet_constructor(...)

    def split_parameters(self, parameters: torch.Tensor) -> dict:
        """Slice the raw subnet output into named per-parameter tensors."""
        pc = self.parameter_counts  # set by the concrete subclass
        chunks = torch.split(parameters, list(pc.values()), dim=1)
        return dict(zip(pc.keys(), chunks))

    def transform_parameters(self, parameters: dict) -> None:
        """Constrain raw parameters *in place* (e.g. exponentiate scales).

        Must return None; subclasses mutate the given dict.
        """
        pass

    # FIX: there were two get_parameters definitions; the first (a bare
    # NotImplementedError stub) was dead code shadowed by this one.
    def get_parameters(self, u: torch.Tensor, subnet: nn.Module) -> dict:
        """Predict, split and constrain the transform parameters from ``u``."""
        parameters = subnet(u)
        parameters = self.split_parameters(parameters)
        # FIX: was transform_parameters(**parameters) -- that hands the hook a
        # *copy* of the dict, so in-place constraints (e.g. exp of the scale)
        # were silently lost. Pass the dict itself.
        should_be_none = self.transform_parameters(parameters)
        if should_be_none is not None:
            warnings.warn("transform_parameters should mutate in place and return None")
        return parameters

    def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> torch.Tensor:
        """Two-way coupling forward pass; returns (z, total log|det J|)."""
        x1, x2 = self.split.forward(x)

        p = self.get_parameters(u=x2, subnet=self.subnet1)
        z1, logdet1 = self.transform1.forward(x1, **p)
        p = self.get_parameters(u=z1, subnet=self.subnet2)
        z2, logdet2 = self.transform2(x2, **p)

        z = self.split.inverse(z1, z2)
        logdet = logdet1 + logdet2
        return z, logdet


# TODO(review): unfinished sketches removed from executable code -- they used
# undefined names (``out``, ``affine``, ``inner_spline``, ``x[in]`` is not even
# valid syntax in spirit) and duplicated the class name ``Spline``:
#   class Spline(Transform): forward applies affine then inner_spline to x[out]
#   class Spline(CouplingTransform): _forward applies _spline / _affine
# Likewise the module-level
#   my_single_coupling = CouplingTransform(transform1=AffineTransform(...), transform2=None)
# referenced AffineTransform (undefined in this module) and crashed at import.


# --------------------------------------------------------------------------- #
# FrEIA/transforms/affine.py
# --------------------------------------------------------------------------- #
import torch

from .coupling import CouplingTransform


class AffineTransform(CouplingTransform):
    """Affine coupling: z = a * x + b with a > 0 (a is exp-constrained)."""

    def __init__(self, *args, **kwargs):
        # FIX: was ``parameter_counts = {...}`` -- literally a *set* containing
        # Ellipsis -- passed as a kwarg CouplingTransform does not accept.
        # One scale and one shift per split half; TODO(review) confirm counts.
        self.parameter_counts = {"a": 1, "b": 1}
        super().__init__(*args, **kwargs)

    def transform_parameters(self, parameters: dict) -> None:
        # FIX: previously received **parameters (a copy), so this exp() never
        # reached the caller; now mutates the shared dict in place.
        parameters["a"] = torch.exp(parameters["a"])

    def _forward(self, x: torch.Tensor, **parameters) -> torch.Tensor:
        # FIX: dropped a stray ``parameters = self.get_parameters()`` call --
        # wrong arity (get_parameters needs u and subnet) and it clobbered the
        # already-constrained parameters passed in.
        a, b = parameters["a"], parameters["b"]
        return a * x + b, torch.log(a)

    def _inverse(self, z: torch.Tensor, **parameters) -> torch.Tensor:
        a, b = parameters["a"], parameters["b"]
        return (z - b) / a, -torch.log(a)


# --------------------------------------------------------------------------- #
# FrEIA/transforms/identity.py
# --------------------------------------------------------------------------- #
import torch

# FIX: WithJacobian was used in annotations but never imported (NameError at
# function-definition time).
from .base import Transform, WithJacobian


class IdentityTransform(Transform):
    """No-op transform; the log-determinant of the identity is 0."""

    def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
        return x, 0

    def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
        return z, 0


# --------------------------------------------------------------------------- #
# FrEIA/transforms/ode.py
# --------------------------------------------------------------------------- #
import torch

from .base import Transform

# FIX: removed ``from scipy.ode import solve_ode`` -- scipy has no such module
# (the nearest real API is scipy.integrate.solve_ivp) and it was never used.


def euler(x, v, dt):
    """One explicit Euler step: x <- x + v * dt. (Moved above its users.)"""
    return x + v * dt


class ODETransform(Transform):
    """Integrate dx/dt = v with fixed-step explicit Euler."""

    def __init__(self, integration_steps: int = 10):
        super().__init__()
        self.integration_steps = integration_steps  # number of Euler steps

    def forward(self, x: torch.Tensor, **parameters) -> torch.Tensor:
        # FIX: an early ``return euler(x, v, dt)`` with undefined v/dt made the
        # whole loop unreachable, and self.get_parameters() does not exist on
        # Transform. The velocity now comes from the passed parameters.
        dt = 1.0 / self.integration_steps
        for _ in range(self.integration_steps):
            # TODO(review): presumably v should be re-predicted by a subnet at
            # every step; with constant parameters this is a single linear step.
            v = parameters["v"]
            x = euler(x, v, dt)
        return x


# TODO(review): the original also sketched a Parameterized(nn.Module) wrapper
# plus module-level instantiations
#   ODETransform = Parameterized(ODETransform)
#   ode = ODETransform()
# all built on undefined names (nn, transform_cls, x) and raising at import
# time; omitted until the wrapper design settles.


# --------------------------------------------------------------------------- #
# playground.py
# --------------------------------------------------------------------------- #
from functools import wraps
from typing import Callable, Dict, Union

import torch

# FIX: import moved to the top of the file (was mid-module).
from FrEIA.splits import EvenSplit


class Transform:
    """Minimal stand-in that just traces construction and calls."""

    def __init__(self):
        print(f"{self.__class__.__name__} __init__")

    def __call__(self, *args, **kwargs):
        print(f"{self.__class__.__name__} __call__")


class Parameter:
    """Declares how many values a transform parameter needs and how raw
    subnet outputs are constrained into valid parameter values."""

    def __init__(self, count: Union[int, Callable[[Transform], int]]):
        # count may be an int, or a callable receiving the transform for
        # counts that depend on its configuration (e.g. number of spline bins).
        self.count = count

    def initialize(self, transform: Transform):
        """Resolve a callable count against the concrete transform."""
        # FIX: idiomatic callable() instead of isinstance(..., Callable).
        if callable(self.count):
            self.count = self.count(transform)

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class Real(Parameter):
    """Unconstrained real-valued parameter."""

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        return unconstrained


class Positive(Parameter):
    """Strictly positive parameter via exp."""

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        return torch.exp(unconstrained)


class Increasing(Parameter):
    """Monotonically increasing parameter vector (e.g. spline knot positions)."""

    def constrain(self, unconstrained: torch.Tensor) -> torch.Tensor:
        # FIX: was unconstrained[:, 0] -- a 1-D (N,) tensor added to the
        # (N, C-1) cumsum, which broadcasts wrongly or errors; [:, :1] keeps
        # the column dimension. TODO(review): output has count-1 columns
        # (the first edge is consumed as the offset) -- confirm intended.
        return unconstrained[:, :1] + torch.cumsum(torch.exp(unconstrained[:, 1:]), dim=1)


class Coupling(Transform):
    """Wires a split, an inner transform, and a subnet that predicts the
    transform's parameters from the half that is passed through unchanged."""

    def __init__(self, split, transform, subnet, **parameters: Parameter):
        super().__init__()
        self.split = split
        self.transform = transform
        # TODO: 2 subnets? or just singular coupling?
        self.subnet = subnet
        # TODO(review): the diff sketched a zero-initialisation check here
        # built entirely on undefined names (test_subnet, zero, warnings):
        #   try: subnet[-1].weight/bias .zero_()
        #   except: probe a dummy output; warn if not zero
        # Omitted until those pieces exist.
        self._parameters = parameters

    @property
    def parameter_names(self):
        return self._parameters.keys()

    @property
    def parameter_counts(self):
        return [p.count for p in self._parameters.values()]

    def get_parameters(self, condition: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Predict, slice and constrain all parameters from ``condition``."""
        raw = self.subnet(condition)
        chunks = torch.split(raw, self.parameter_counts, dim=1)
        # FIX: zipped self._parameters.keys() (strings!) with the chunks and
        # then called .constrain on the string; the Parameter objects live in
        # .values().
        constrained = [p.constrain(u) for (p, u) in zip(self._parameters.values(), chunks)]
        return dict(zip(self.parameter_names, constrained))

    # FIX: placeholders took no arguments but forward() calls them with
    # (x1, **parameters) -- instant TypeError; accept-and-ignore until defined.
    def transform_forward(self, *args, **kwargs):
        pass  # TODO: delegate to the inner transform's forward

    def transform_inverse(self, *args, **kwargs):
        pass  # TODO: delegate to the inner transform's inverse

    def forward(self, x: torch.Tensor, rev: bool = False, jac: bool = True) -> torch.Tensor:
        # NOTE(review): rev/jac are currently unused -- kept for interface
        # compatibility with the sketched design.
        x1, x2 = self.split.forward(x)
        parameters = self.get_parameters(x2)
        z1 = self.transform_forward(x1, **parameters)
        parameters = self.get_parameters(z1)
        z2 = self.transform.forward(x2, **parameters)

        z = self.split.inverse(z1, z2)
        return z

    def inverse(self, z: torch.Tensor) -> torch.Tensor:
        # Mirrors forward(): parameters for each half come from the half that
        # is already recovered.
        z1, z2 = self.split.forward(z)
        parameters = self.get_parameters(z1)
        x2 = self.transform.inverse(z2, **parameters)
        parameters = self.get_parameters(x2)
        x1 = self.transform.inverse(z1, **parameters)

        x = self.split.inverse(x1, x2)
        return x


def parameterize(**parameters):
    """Class decorator: constructing the decorated Transform yields a Coupling
    whose subnet is built by the mandatory ``subnet_constructor`` kwarg."""

    def wrap(cls):
        @wraps(cls)
        def construct(*args, split=None, subnet_constructor, **kwargs):
            # FIX: split=EvenSplit() as a default argument is one instance
            # shared by every decorated construction; build per call.
            split = EvenSplit() if split is None else split
            transform = cls(*args, **kwargs)
            for p in parameters.values():
                # resolve dynamic (callable) parameter counts
                p.initialize(transform)

            # TODO(review): dims_in was an "..." placeholder (depends on the
            # split/data); dims_out is at least the total parameter count.
            dims_in = ...
            dims_out = sum(p.count for p in parameters.values())

            subnet = subnet_constructor(dims_in, dims_out)

            return Coupling(split=split, transform=transform, subnet=subnet, **parameters)

        return construct

    return wrap


# TODO(review): disabled experiment -- as written this crashed immediately:
# Coupling.__init__ requires split/transform/subnet, so
# super().__init__(scale=..., shift=...) raised TypeError, and the class had
# no .blub(). The commented decorator suggests the intended form was
# @parameterize(scale=Positive(1), shift=Real(1)) on a plain Transform.
#
# class AffineTransform(Coupling):
#     def __init__(self):
#         super().__init__(scale=Positive(1), shift=Real(1))
#
#     def _forward(self, x, scale, shift, rev):
#         return scale * x + shift
#
#     def _inverse(self, x, scale, shift):
#         return (x - shift) / scale


@parameterize(
    x_edges=Increasing(lambda t: t.bins),
    y_edges=Increasing(lambda t: t.bins),
    deltas=Increasing(lambda t: t.bins - 1),
)
class RQSpline(Transform):
    """Rational-quadratic spline sketch with ``bins`` knots."""

    def __init__(self, bins: int):
        # FIX: was super().__init__(x_edges=Increasing(bins)) -- Transform's
        # __init__ takes no kwargs (TypeError); the decorator already declares
        # the parameters.
        super().__init__()
        self.bins = bins

    def forward(self, x: torch.Tensor, x_edges: torch.Tensor, y_edges: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
        # placeholder: returns zeros of the input shape
        return torch.zeros(*x.shape)

    def blub(self):
        pass

    def inverse(self):
        pass


def subnet_constructor(dims_in, dims_out):
    """Toy subnet factory: ignores its input and emits zeros of width dims_out."""
    print("subnet_constructor")

    def subnet(x):
        return torch.zeros(x.shape[0], dims_out)

    return subnet


# FIX: experiment script guarded so importing playground has no side effects.
if __name__ == "__main__":
    # NOTE(review): the AffineTransform experiment lines
    #   t = AffineTransform(subnet_constructor=subnet_constructor)
    #   t.blub(); print(type(t.transform.blub())); t(x)
    # crashed (see the disabled class above) and are omitted.
    x = None

    t = RQSpline(bins=8, subnet_constructor=subnet_constructor)
    t(x)

    print(t.parameter_counts)