Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions FrEIA/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .invertible import Invertible
23 changes: 23 additions & 0 deletions FrEIA/core/invertible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

from abc import ABC
import torch.nn as nn

from typing import Any

from typing import TypeVar

T = TypeVar("T")


class Invertible(ABC, nn.Module):
def forward(self, *args: T, **kwargs: T) -> Any:
raise NotImplementedError

def inverse(self, *args, **kwargs):
raise NotImplementedError

def __call__(self, *args, rev = False, **kwargs):
if not rev:
return self.forward(*args, **kwargs)

return self.inverse(*args, **kwargs)
Empty file added FrEIA/flows/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions FrEIA/flows/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

from freia.core import Invertible

class Flow(Invertible):
def __init__(self, transform, distribution):
self.transform = transform
self.distribution = distribution

def forward(self, x):
z, logdet = self.transform.forward(x)

logp = self.distribution.log_prob(z)

nll = -(logp + logdet)

return z, nll

def sample_transform(self, size, temperature):
z = self.distribution.sample(size, temperature)

x, _ = self.transform.inverse(z)

return x


class RecurrentFlow(Flow):
def forward(self, x):
z = x
logdet = None
for t in range(...):
z, logdet = self.transform.forward(z, t)
2 changes: 2 additions & 0 deletions FrEIA/splits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .even import EvenSplit
18 changes: 18 additions & 0 deletions FrEIA/splits/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@

from FrEIA.core import Invertible

from typing import Tuple

import torch


class Split(Invertible):
def __init__(self, dim: int = 1):
super().__init__()
self.dim = dim

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError

def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
14 changes: 14 additions & 0 deletions FrEIA/splits/even.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

from .base import Split

from typing import Tuple

import torch


class EvenSplit(Split):
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.split(x, 2, dim=1)

def inverse(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return torch.cat((x1, x2), dim=1)
3 changes: 3 additions & 0 deletions FrEIA/splits/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@


# class RandomSplit()
2 changes: 2 additions & 0 deletions FrEIA/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .base import Transform
25 changes: 25 additions & 0 deletions FrEIA/transforms/affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

from .base import Transform

import torch

from .coupling import CouplingTransform


class AffineTransform(CouplingTransform):

def __init__(self):
parameter_counts = {...}
super().__init__(parameter_counts=parameter_counts)

def transform_parameters(self, **parameters):
parameters["a"] = torch.exp(parameters["a"])

def _forward(self, x: torch.Tensor, **parameters) -> torch.Tensor:
parameters = self.get_parameters()
a, b = parameters["a"], parameters["b"]
return a * x + b, torch.log(a)

def _inverse(self, z: torch.Tensor, **parameters) -> torch.Tensor:
a, b = parameters["a"], parameters["b"]
return (z - b) / a, -torch.log(a)
46 changes: 46 additions & 0 deletions FrEIA/transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

from freia.core import Invertible

import torch


WithJacobian = tuple[torch.Tensor, torch.Tensor]



class Transform(Invertible):
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError

def inverse(self, z: torch.Tensor) -> torch.Tensor:
raise NotImplementedError




@Parameterized(scale=1, shift=1)
class AffineTransform(Transform):
def forward(self, x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
return scale * x + shift

def inverse(self, z: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
return (z - shift) / scale


# class SplineTransform(Transform):
# def forward(self, x: torch.Tensor, edges: torch.Tensor):
# assert edges.shape == (..., self.bins)
# pass


class Parameterized:
def __init__(self, **parameter_counts):
self.parameter_counts = parameter_counts

def __call__(self, cls):

cls.forward = forward
cls.inverse = inverse



73 changes: 73 additions & 0 deletions FrEIA/transforms/coupling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

from .base import Transform
from freia.splits import EvenSplit

import torch
import torch.nn as nn


class Spline(Transform):
def __init__(self, affine, inner_spline):
...

def forward(self, x: torch.Tensor, *, condition: torch.Tensor, **kwargs) -> WithJacobian:
x[out] = affine(x[out])
x[out] = inner_spline(x[out])


class Spline(CouplingTransform):
def _forward(self):
x[in] = self._spline(...)
x[out] = self._affine(...)



class CouplingTransform(Transform):
def __init__(self, transform1, transform2, subnet_constructor, split=EvenSplit(dim=1)):
self.split = split
self.subnet1 = subnet_constructor(...)
self.subnet2 = subnet_constructor(...)

def split_parameters(self, parameters: torch.Tensor) -> dict:
pc = self.parameter_counts
parameters = torch.split(parameters, list(pc.values()), dim=1)

return dict(zip(pc.keys(), parameters))

def transform_parameters(self, parameters: dict[torch.Tensor]) -> None:
pass

def get_parameters(self, *args, **kwargs) -> dict:
raise NotImplementedError

def get_parameters(self, u: torch.Tensor, subnet: nn.Module) -> dict:

parameters = subnet(u)
parameters = self.split_parameters(parameters)
should_be_none = self.transform_parameters(**parameters)
if should_be_none is not None:
warnings.warn(...)

return parameters


def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> torch.Tensor:
x1, x2 = self.split.forward(x)



parameters = self.get_parameters(u=x2, subnet=self.subnet1)
z1, logdet1 = self.transform1.forward(x1, **parameters)
parameters = self.get_parameters(u=z1, subnet=self.subnet2)
z2, logdet2 = self.transform2(x2, **parameters)

z = self.split.inverse(z1, z2)
logdet = logdet1 + logdet2

return z, logdet





my_single_coupling = CouplingTransform(transform1=AffineTransform(...), transform2=None)
12 changes: 12 additions & 0 deletions FrEIA/transforms/identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

from .base import Transform

import torch


class IdentityTransform(Transform):
def forward(self, x: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
return x, 0

def inverse(self, z: torch.Tensor, **parameters: torch.Tensor) -> WithJacobian:
return z, 0
59 changes: 59 additions & 0 deletions FrEIA/transforms/ode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

from .base import Transform

import torch

from scipy.ode import solve_ode



class Parameterized(nn.Module):
def __init__(self, *, subnet_constructor, parameter_counts):
super().__init__()
self.subnet = ...
self.parameter_counts = ...
self.transform = transform_cls

def __call__(self, *args, **kwargs):
self.transform = transform_cls(*args, **kwargs)

return self

def forward(self):
parameters = self.subnet(...)
return self.transform(x, parameters)


@Parameterized
class ODETransform(Transform):
def __init__(self, integration_steps: int = 10):
super().__init__()
self.integration_steps = integration_steps

def forward(self, x: torch.Tensor, **parameters) -> tuple[torch.Tensor, torch.Tensor]:
return euler(x, v, dt)

# ode integration
dt = 1 / self.integration_steps
for _ in range(self.integration_steps):
parameters = self.get_parameters()
v = parameters["v"]
x = euler(x, v, dt)

return x

ODETransform = Parameterized(ODETransform)





ode = ODETransform()




def euler(x, v, dt):
return x + v * dt


Loading