Skip to content

Commit

Permalink
Add more comments and examples README
Browse files Browse the repository at this point in the history
  • Loading branch information
mcbal committed Oct 14, 2021
1 parent 99bec96 commit 8f8a4fc
Show file tree
Hide file tree
Showing 6 changed files with 4,714 additions and 125 deletions.
3 changes: 3 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Examples

The scripts in this folder reproduce the plots appearing in the blog post [Transformers from Spin Models: Approximate Free Energy Minimization](https://mcbal.github.io/post/transformers-from-spin-models-approximate-free-energy-minimization/).
13 changes: 11 additions & 2 deletions examples/model_beta_sweep.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Sweep across inverse temperature `beta` for fixed inputs to a `VectorSpinModel` module
# and plot `phi` and its derivatives.
# and return plots of `phi` and its derivatives as an animated GIF.

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -13,10 +13,16 @@
beta_min, beta_max, beta_num_steps = -0.6, 1.0, 100 # log10-space
plot_values_cutoff = 500

# Fixed inputs and initial auxiliary variable.
x = (torch.randn(1, num_spins, dim) / np.sqrt(dim)).requires_grad_()
t0 = 0.5*torch.ones(1)


##############
# PLOT STUFF #
##############


def filter_array(a, threshold=plot_values_cutoff):
idx = np.where(np.abs(a) > threshold)
a[idx] = np.nan
Expand All @@ -27,17 +33,20 @@ def simple_update(frame, fig, axs):

print(f'{frame:.4f} / {10**beta_max}')

# Setup vector-spin model.
model = VectorSpinModel(
num_spins=num_spins,
dim=dim,
beta=frame,
)
out = model(x, t0=t0, return_afe=True, return_magnetizations=True)
# Run forward pass.
out = model(x, t0=t0, return_afe=True)

t_star = out.t_star[0][0].detach().numpy()
afe_star = out.afe[0][0].detach().numpy()
t_min, t_max, t_step = 0.0, 3.0, 0.001
t_range = torch.arange(t_min, t_max, t_step)[:, None]

phis = np.array(model._phi(t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
grad_phis = np.array(model._jac_phi(
t_range, x[:1, :, :].repeat(t_range.numel(), 1, 1)).detach().numpy())
Expand Down
6 changes: 3 additions & 3 deletions examples/model_fwd_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
# Run backward on sum of free energies across batch dimension.
out.afe.sum().backward()

###############################################################
# Plot internally-used function `phi(t)` and its derivatives. #
###############################################################
##############
# PLOT STUFF #
##############

if x.size(0) == 1 and t0.numel() == 1:

Expand Down
4 changes: 4 additions & 0 deletions examples/model_two_spins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
f'✨ t_star {out.t_star.detach()}\n✨ magnetizations: {out.magnetizations}\n✨ approximate free energy: {out.afe.detach()}'
)

##############
# PLOT STUFF #
##############


def filter_array(a, threshold=50):
idx = np.where(np.abs(a) > threshold)
Expand Down
120 changes: 0 additions & 120 deletions notebooks/computational_graphs.ipynb

This file was deleted.

Loading

0 comments on commit 8f8a4fc

Please sign in to comment.