Commit

Dirty dump before flow refactor
mcbal committed Oct 5, 2021
1 parent efb24c1 commit ca4c98c
Showing 11 changed files with 923 additions and 30 deletions.
6 changes: 3 additions & 3 deletions afem/models.py
@@ -57,7 +57,7 @@ def __init__(
if J_add_external:
# self._J_ext = nn.Parameter(torch.empty(dim, dim).uniform_(-1.0 / np.sqrt(dim**2), 1.0 / np.sqrt(dim**2)))
# self._J_ext = nn.Parameter(torch.ones(num_spins, num_spins))
- self.edge_nn = FeedForward(2*dim, dim_out=1, dropout=0.1)
+ self.edge_nn = FeedForward(2*dim, mult=0.5, dim_out=1, dropout=0.1)
# self.to_q = nn.Linear(dim, dim, bias=False)
# self.to_k = nn.Linear(dim, dim, bias=False)
self.J_add_external = J_add_external
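Not part of the diff: the FeedForward used above is defined elsewhere in the repository and is not shown in this commit. The sketch below is a hypothetical stand-in that matches the call signature, assuming the common convention that mult scales the hidden width relative to the input dimension, which would make mult=0.5 give a hidden layer of width dim for the 2*dim-wide input (apparently a concatenation of feature pairs, cf. the next hunk).

    import torch.nn as nn

    # Hypothetical FeedForward consistent with FeedForward(2*dim, mult=0.5, dim_out=1, dropout=0.1):
    # hidden width = mult * input dim, optional output projection, dropout in between.
    class FeedForward(nn.Module):
        def __init__(self, dim, mult=4, dim_out=None, dropout=0.0):
            super().__init__()
            hidden = int(dim * mult)
            self.net = nn.Sequential(
                nn.Linear(dim, hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, dim_out if dim_out is not None else dim),
            )

        def forward(self, x):
            return self.net(x)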
@@ -99,7 +99,7 @@ def J(self, h):
Y2 = Y1.repeat(1, 1, h.shape[1], 1)
# print(X2.shape, X2.shape)
Z = torch.cat([X2, Y2], -1)
- y = Z.view(h.shape[0], h.shape[1], h.shape[1], Z.shape[-1])
+ y = Z.reshape(h.shape[0], h.shape[1], h.shape[1], Z.shape[-1])
# print(y.shape)
# breakpoint()

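The functional change in this hunk is the switch from Tensor.view to Tensor.reshape. The commit does not say why, but a plausible reason is that view requires a memory layout compatible with the requested shape and raises a RuntimeError on non-contiguous inputs, whereas reshape falls back to a copy when needed. A minimal standalone illustration of that difference:

    import torch

    a = torch.randn(2, 3, 4).transpose(1, 2)   # non-contiguous after transpose

    try:
        a.view(2, 12)                           # fails: strides incompatible with a view
    except RuntimeError as err:
        print('view failed:', err)

    b = a.reshape(2, 12)                        # works: reshape copies when it has to
    print(b.shape)                              # torch.Size([2, 12])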
@@ -255,7 +255,7 @@ def log_prob(self, t, h, beta=None):
# print(vals)
# breakpoint()
  return (
- - 0.5 * torch.logdet(V)
+ -0.5 * torch.logdet(V)
  + beta / (4.0 * self.dim) * torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)
  )[:, None]

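For reference, the only change in the hunk above is the whitespace in front of 0.5. The einsum string 'b i f, b i j, b j f -> b' on the unchanged line computes, for every batch element, the quadratic form h_f^T V_inv h_f summed over the feature dimension f. A toy check of that contraction, with made-up shapes:

    import torch

    b, n, f = 2, 5, 3                        # made-up batch, spin and feature sizes
    h = torch.randn(b, n, f)
    V_inv = torch.randn(b, n, n)

    quad = torch.einsum('b i f, b i j, b j f -> b', h, V_inv, h)

    # Same contraction written with explicit matrix products per batch element.
    ref = torch.stack([(h[k].transpose(0, 1) @ V_inv[k] @ h[k]).trace() for k in range(b)])
    print(torch.allclose(quad, ref, atol=1e-5))   # True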
26 changes: 26 additions & 0 deletions afem/plot.py
@@ -0,0 +1,26 @@
# with torch.no_grad():
# import matplotlib.pyplot as plt

# def filter_array(a, threshold=1e2):
# idx = np.where(np.abs(a) > threshold)
# a[idx] = np.nan
# return a
# # Pick a range and resolution for `t`.
# t_range = torch.arange(0.0, 3.0, 0.001)[:, None]
# # Calculate function evaluations for every point on grid and plot.
# out = np.array(self._phi(t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
# out_grad = np.array(self._grad_t_phi(
# t_range, h[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
# f, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
# ax1.set_title(f"(Hopefully) Found root of phi'(t) at t = {t[0][0].detach().numpy()}")
# ax1.plot(t_range.numpy().squeeze(), filter_array(out), 'r-')
# ax1.axvline(x=t[0].detach().numpy())
# ax1.set_ylabel("phi(t)")
# ax2.plot(t_range.numpy().squeeze(), filter_array(out_grad), 'r-')
# ax2.axvline(x=t[0].detach().numpy())
# ax2.axhline(y=0.0)
# ax2.set_xlabel('t')
# ax2.set_ylabel("phi'(t)")
# # plt.show()
# from datetime import datetime
# plt.savefig(f'{datetime.now()}.png')
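The new file above is entirely commented out; it sketches a diagnostic plot of phi(t) and phi'(t) around a root found for phi'(t), using self._phi and data h that only exist inside the model. A self-contained toy version of the same idea, with a made-up phi whose derivative has its root at t = 1:

    import matplotlib.pyplot as plt
    import torch

    def phi(t):                                  # toy stand-in for the model's phi(t)
        return t + 1.0 / t

    t = torch.arange(0.1, 3.0, 0.01, requires_grad=True)
    (grad_phi,) = torch.autograd.grad(phi(t).sum(), t)   # elementwise phi'(t) = 1 - 1/t**2

    fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
    ax1.plot(t.detach(), phi(t).detach(), 'r-')
    ax1.axvline(x=1.0)                           # root of phi'(t)
    ax1.set_ylabel("phi(t)")
    ax2.plot(t.detach(), grad_phi, 'r-')
    ax2.axvline(x=1.0)
    ax2.axhline(y=0.0)
    ax2.set_xlabel('t')
    ax2.set_ylabel("phi'(t)")
    plt.savefig('phi_diagnostic.png')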
58 changes: 58 additions & 0 deletions afem/rootfind copy.py
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.autograd as autograd

from .solvers import newton
from .utils import filter_kwargs, remove_kwargs


class RootFind(nn.Module):
    """Differentiable root-solving using implicit differentiation.
    See https://implicit-layers-tutorial.org/introduction/ and https://github.com/locuslab/deq.
    """
    _default_kwargs = {
        'solver_fwd_max_iter': 30,
        'solver_fwd_tol': 1e-4,
        'solver_bwd_max_iter': 30,
        'solver_bwd_tol': 1e-4,
    }

    def __init__(self, fun, solver=newton, **kwargs):
        super().__init__()
        self.fun = fun
        self.solver = solver
        self.kwargs = self._default_kwargs
        self.kwargs.update(**kwargs)

    def _root_find(self, z0, x, *args, **kwargs):

        # Compute forward pass: find root of function outside autograd tape.
        with torch.no_grad():
            z_root = self.solver(
                lambda z: self.fun(z, x, *args, **remove_kwargs(kwargs, 'solver_')),
                z0,
                **filter_kwargs(kwargs, 'solver_fwd_'),
            )['result']
        new_z_root = z_root

        if self.training:
            # Re-engage autograd tape (no-op in terms of value of z).
            new_z_root = z_root - self.fun(z_root.requires_grad_(), x, *args, **remove_kwargs(kwargs, 'solver_'))

            # Set up backward hook for root-solving in backward pass.
            z_bwd = new_z_root.clone().detach().requires_grad_()
            fun_bwd = self.fun(z_bwd, x, *args, **remove_kwargs(kwargs, 'solver_'))

            def backward_hook(grad):
                return self.solver(
                    lambda y: autograd.grad(fun_bwd, z_bwd, y, retain_graph=True, create_graph=True)[0] + grad,
                    torch.zeros_like(grad), **filter_kwargs(kwargs, 'solver_bwd_')
                )['result']

            new_z_root.register_hook(backward_hook)

        return new_z_root

    def forward(self, z0, x, *args, **kwargs):
        return self._root_find(z0, x, *args, **{**self.kwargs, **kwargs})
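For context, a minimal self-contained sketch of the implicit-differentiation idea this class implements, written directly against torch.autograd rather than the RootFind/newton API (the afem.solvers module is not part of this diff): for a root z*(x) of f(z, x) = 0, the implicit function theorem gives dz*/dx = -(df/dz)^(-1) df/dx, so the gradient can be obtained without backpropagating through the solver iterations.

    import torch

    def f(z, x):                       # f(z, x) = z**2 - x has root z*(x) = sqrt(x)
        return z ** 2 - x

    x = torch.tensor(4.0, requires_grad=True)

    # Forward: find the root outside the autograd tape (a few Newton steps).
    with torch.no_grad():
        z = torch.tensor(1.0)
        for _ in range(20):
            z = z - f(z, x) / (2 * z)  # Newton update with the analytic df/dz = 2z

    # Backward: implicit function theorem instead of differentiating the loop.
    z = z.requires_grad_()
    x_leaf = x.detach().requires_grad_()
    df_dz, df_dx = torch.autograd.grad(f(z, x_leaf), (z, x_leaf))
    dz_dx = -df_dx / df_dz
    print(z.item(), dz_dx.item())      # ~2.0 and ~0.25 = 1 / (2 * sqrt(4))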
40 changes: 40 additions & 0 deletions afem/sgld.py
@@ -0,0 +1,40 @@
import torch
from torch.distributions import Normal
from torch.optim.optimizer import Optimizer, required
import numpy as np


class SGLD(Optimizer):
    """
    Barely modified version of pytorch SGD to implement SGLD
    """

    def __init__(self, params, lr=required, addnoise=True):
        defaults = dict(lr=lr, addnoise=addnoise)
        super(SGLD, self).__init__(params, defaults)

    def step(self, lr=None, add_noise=True):
        """
        Performs a single optimization step.
        """
        loss = None

        for group in self.param_groups:
            if lr:
                group['lr'] = lr
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if group['addnoise']:
                    size = d_p.size()
                    langevin_noise = Normal(
                        torch.zeros(size),
                        torch.ones(size) / np.sqrt(group['lr'])
                    )
                    p.data.add_(-group['lr'],
                                d_p + langevin_noise.sample())
                else:
                    p.data.add_(-group['lr'], d_p)

        return loss
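A usage sketch with a made-up model, assuming a PyTorch version that still accepts the positional Tensor.add_(alpha, tensor) overload used above (it is deprecated in more recent releases in favour of add_(tensor, alpha=...)):

    import torch
    from afem.sgld import SGLD

    model = torch.nn.Linear(10, 1)              # made-up model for illustration
    opt = SGLD(model.parameters(), lr=1e-2)

    x, y = torch.randn(32, 10), torch.randn(32, 1)
    loss = ((model(x) - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()   # SGD step plus Gaussian noise with std 1/sqrt(lr), scaled by -lr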
2 changes: 1 addition & 1 deletion afem/utils.py
@@ -41,6 +41,6 @@ def batch_jacobian(f, x, create_graph=False, swapaxes=True):
https://discuss.pytorch.org/t/jacobian-functional-api-batch-respecting-jacobian/84571
"""
def _f_sum(x):
- return f(x).sum(dim=0)
+ return f(x).sum(dim=0) if callable(f) else f.sum(dim=1)
jac_f = autograd.functional.jacobian(_f_sum, x, create_graph=create_graph)
return jac_f.swapaxes(1, 0) if swapaxes else jac_f
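For reference, the sum-over-batch trick used by batch_jacobian: when f acts independently on each batch element, the Jacobian of f(x).sum(dim=0) with respect to x contains exactly the per-sample Jacobians, so a single autograd.functional.jacobian call replaces a loop over the batch. A small check with a hypothetical elementwise function:

    import torch
    from torch import autograd

    def f(x):                                    # acts row-wise on the batch
        return torch.tanh(x)

    x = torch.randn(4, 3)

    jac = autograd.functional.jacobian(lambda x: f(x).sum(dim=0), x)   # (3, 4, 3)
    jac = jac.swapaxes(1, 0)                                           # (batch, out, in)

    # Compare against per-sample Jacobians computed one at a time.
    ref = torch.stack([autograd.functional.jacobian(f, x[i]) for i in range(4)])
    print(torch.allclose(jac, ref))              # True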