Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions dwave/plugins/torch/models/boltzmann_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ def quasi_objective(
'each other.'
)
raise ValueError(err_msg)
# NOTE: this method relies on hidden units being disconnected. The calculations
# depend on this assumption in **two** ways. The obvious one is marginalization. The
# less obvious dependence is the linearity of expectation and sufficient statistics.
# Because hidden units are disconnected, we can average their spins before computing
# the sufficient statistics, which are then passed into the quasi objective function.
obs = self._compute_expectation_disconnected(s_observed)
elif kind == "sampling":
obs = self._approximate_expectation_sampling(
Expand Down Expand Up @@ -436,10 +441,10 @@ def _compute_effective_field(self, padded: torch.Tensor) -> torch.Tensor:
# by the corresponding edges. Transforming this contribution vector by a
# cumulative sum yields cumulative contributions to effective fields.
# Differencing removes the extra gobbledygook.
contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx]
contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx].detach()
cumulative_contribution = contribution.cumsum(1)
# Don't forget to add the linear fields!
h_eff = self._linear[self.hidden_idx] + cumulative_contribution[
h_eff = self._linear[self.hidden_idx].detach() + cumulative_contribution[
:, self._bin_idx
].diff(dim=1, prepend=torch.zeros(bs, device=padded.device).unsqueeze(1))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
Fix the (automatic) gradient computation when hidden units are present. The issue was that the
parameters, linear and quadratic weights, were used in the marginalization without being
detached from the computation graph. The fix was to detach the parameters when computing
effective fields.
44 changes: 42 additions & 2 deletions tests/test_boltzmann_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_compute_effective_field_unordered(self):
# = 0.13 * [-1] - 0.17 * [1] + 0.4 = .1

grbm._linear.data = torch.tensor([-0.1, -0.2, 0.4, 0.2])
grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17 ])
grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17])
padded = torch.tensor([[-1.0, float("nan"), float("nan"), 1.0]])
h_eff = grbm._compute_effective_field(padded)
self.assertTrue(torch.allclose(h_eff.data, torch.tensor([-0.5000, 0.1000]), atol=1e-6))
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_sample_return_sampleset(self):
self.assertEqual(4, len(sampleset.variables))
self.assertEqual(set(grbm.nodes), set(sampleset.variables))

def test_objective(self):
def test_quasi_objective(self):
# Create a triangle graph with an additional dangling vertex
self.nodes = list("abcd")
self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]]
Expand Down Expand Up @@ -356,6 +356,46 @@ def test_objective(self):
self.assertEqual(-1, grbm.quasi_objective(s1, s2).item())
self.assertEqual(-1, grbm.quasi_objective(s1, s3))

def test_quasi_objective_gradient_hidden_units(self):
grbm = GRBM([1, 2, 3],
[(1, 2), (1, 3), (2, 3)],
[1],
{1: 0.2, 2: 0.2, 3: 0.3},
{(1, 2): 0.2, (1, 3): 0.3, (2, 3): 0.6})
# Note: In the diagram below, linear biases are shown using <>
# and quadratic biases using ()
# (0.2)
# Model: v1 <0.2> ----- v2 <0.2>
# \ /
# (0.3) \ / (0.6)
# \ /
# v3 <0.3>
s_observed = torch.tensor([[1.0, -1.0]])
s_model = torch.tensor([[1.0, -1.0, 1.0]])
quasi = grbm.quasi_objective(s_observed, s_model, "exact-disc")
quasi.backward()
# Compute gradients manually
# Compute unnormalized density
# h1v1 + h2v2 + h3v3 + J23v2v3 + J12v1v2 + J13v1v3
q_plus = torch.exp(-torch.tensor(0.2 + 0.2 - 0.3 - 0.6 + 0.2 - 0.3))
q_minus = torch.exp(-torch.tensor(-0.2 - 0.2 + 0.3 - 0.6 + 0.2 - 0.3))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like this is not sorted according to h1v1 + h2v2 + h3v3 + J23v2v3 + J12v1v2 + J13v1v3:
I think it should be
q_plus = torch.exp(-torch.tensor(0.2 + 0.2 - 0.3 - 0.6 + 0.2 - 0.3))
q_minus = torch.exp(-torch.tensor(-0.2 + 0.2 - 0.3 - 0.6 - 0.2 + 0.3))
this won't change the results

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch, thanks!

# Normalize it
z_cond = q_plus + q_minus
p_plus = q_plus / z_cond
p_minus = q_minus / z_cond
# t = sufficient statistics = (v1 v2 v3 v1v2 v1v3 v2v3)
t_plus = torch.tensor([1, 1, -1, 1, -1, -1]).float()
t_minus = torch.tensor([-1, 1, -1, -1, 1, -1]).float()
t_model = torch.tensor([1, -1, 1, -1, 1, -1]).float()
# Compute expected stat
t_cond = t_plus*p_plus + t_minus*p_minus
grad = t_cond - t_model
grad_auto = torch.cat([grbm.linear.grad, grbm.quadratic.grad])
# NOTE: this test relies on the hidden units being disconnected. This assumption gives rise
# to linearity of expectation for the sufficient statistics, i.e., we can average the hidden
# spins first and then calculate the sufficient statistics of the averaged spins.
torch.testing.assert_close(grad, grad_auto)


if __name__ == "__main__":
unittest.main()