More dump
mcbal committed Oct 3, 2021
1 parent 7a60d73 commit efb24c1
Showing 2 changed files with 41 additions and 17 deletions.
42 changes: 25 additions & 17 deletions afem/models.py
@@ -7,6 +7,7 @@
 import numpy as np
 from einops import repeat
 
+from .modules import FeedForward
 from .rootfind import RootFind
 from .utils import batch_eye, batch_eye_like, batch_jacobian, default, exists
 
@@ -55,7 +56,8 @@ def __init__(
         ))
         if J_add_external:
             # self._J_ext = nn.Parameter(torch.empty(dim, dim).uniform_(-1.0 / np.sqrt(dim**2), 1.0 / np.sqrt(dim**2)))
-            self._J_ext = nn.Parameter(torch.ones(num_spins, num_spins))
+            # self._J_ext = nn.Parameter(torch.ones(num_spins, num_spins))
+            self.edge_nn = FeedForward(2*dim, dim_out=1, dropout=0.1)
             # self.to_q = nn.Linear(dim, dim, bias=False)
             # self.to_k = nn.Linear(dim, dim, bias=False)
         self.J_add_external = J_add_external
@@ -81,31 +83,37 @@ def J(self, h):
         bsz, num_spins, dim, J = h.size(0), h.size(1), h.size(2), self._J
 
         J = repeat(J, 'i j -> b i j', b=bsz)
         h = h.detach()
 
         # print(torch.linalg.norm(J).item(), torch.linalg.norm(h[0]).item(), 'inside J')
         # breakpoint()
         if self.J_add_external:
             # q, k = self.to_q(h), self.to_k(h)
             # inside = torch.einsum('b i f, b j f -> b i j', q, k) / np.sqrt(dim**2*num_spins)
             # ext = torch.tanh(inside)
             # print(torch.cdist(h, h).shape, self._J_ext.shape)
-            print(self._J_ext**2)
-
-            # breakpoint()
-            ext = (-0.5 / (self._J_ext**2).unsqueeze(0) * torch.cdist(h/np.sqrt(dim),
-                                                                      h/np.sqrt(dim))).exp() / np.sqrt(num_spins)**2
-            print(
-                torch.linalg.norm(J[0]).item(),
-                torch.linalg.norm(torch.cdist(h/np.sqrt(dim),
-                                              h/np.sqrt(dim))).item(),
-                torch.linalg.norm(ext[0]).item(),
-            )
-            print(ext)
+            X1 = h.unsqueeze(1)
+            Y1 = h.unsqueeze(2)
+            # print(X1.shape, Y1.shape)
+            X2 = X1.repeat(1, h.shape[1], 1, 1)
+            Y2 = Y1.repeat(1, 1, h.shape[1], 1)
+            # print(X2.shape, X2.shape)
+            Z = torch.cat([X2, Y2], -1)
+            y = Z.view(h.shape[0], h.shape[1], h.shape[1], Z.shape[-1])
+            # print(y.shape)
+            # breakpoint()
-            J = J + torch.tanh(ext)
-            print(
-                torch.linalg.norm(J[0]).item(),
-            )
+
+            ext = self.edge_nn(y).squeeze(-1) / np.sqrt(num_spins*dim)
+
+            # print(
+            #     torch.linalg.norm(J[0]).item(),
+            #     torch.linalg.norm(ext[0]).item(),
+            # )
+
+            J = J + ext
+            # print(
+            #     torch.linalg.norm(J[0]).item(),
+            # )
 
         if self.J_symmetric:
             J = 0.5 * (J + J.permute(0, 2, 1))
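Note: this hunk swaps the Gaussian kernel over torch.cdist distances for a learned pairwise edge network. A minimal standalone sketch of the new computation follows; the toy sizes, the random h, and the edge_mlp stand-in for FeedForward(2*dim, dim_out=1) are illustrative assumptions, not the repository's exact code:

import numpy as np
import torch
import torch.nn as nn

bsz, num_spins, dim = 2, 4, 8          # toy sizes for illustration
h = torch.randn(bsz, num_spins, dim)   # stand-in for the spin embeddings

# Tile h along two new axes so that entry (i, j) holds the concatenated pair [h_j ; h_i].
X2 = h.unsqueeze(1).repeat(1, num_spins, 1, 1)  # (bsz, N, N, dim), varies with j
Y2 = h.unsqueeze(2).repeat(1, 1, num_spins, 1)  # (bsz, N, N, dim), varies with i
Z = torch.cat([X2, Y2], -1)                     # (bsz, N, N, 2*dim)

# Toy stand-in for the edge network: one scalar coupling per ordered spin pair.
edge_mlp = nn.Sequential(nn.Linear(2 * dim, 2 * dim), nn.GELU(), nn.Linear(2 * dim, 1))
ext = edge_mlp(Z).squeeze(-1) / np.sqrt(num_spins * dim)  # (bsz, N, N)
print(ext.shape)  # torch.Size([2, 4, 4])

Since Z[:, i, j] and Z[:, j, i] differ in general, ext need not be symmetric; the J_symmetric branch at the end of the hunk symmetrizes J after the addition.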
16 changes: 16 additions & 0 deletions afem/modules.py
@@ -17,3 +17,19 @@ def forward(self, x):
         n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
         x = x / n * self.scale
         return x
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=1, dropout=0.1, dense=nn.Linear):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+        self.net = nn.Sequential(
+            dense(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            dense(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
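For reference, a small usage sketch of the new module: with the default mult=1 the hidden layer keeps the input width, and dim_out overrides the output width. The dim=8 value and shapes below are hypothetical, chosen to mirror the edge_nn = FeedForward(2*dim, dim_out=1, dropout=0.1) call in models.py:

import torch

ff = FeedForward(16, dim_out=1, dropout=0.1)   # hypothetical dim=8, so input width 2*dim=16
x = torch.randn(2, 4, 4, 16)                   # (bsz, N, N, 2*dim) pair features
print(ff(x).shape)                             # torch.Size([2, 4, 4, 1]); squeeze(-1) gives (bsz, N, N)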
