diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py
index 90aee25..92b5c60 100644
--- a/dwave/plugins/torch/models/boltzmann_machine.py
+++ b/dwave/plugins/torch/models/boltzmann_machine.py
@@ -396,6 +396,11 @@ def quasi_objective(
                     'each other.'
                 )
                 raise ValueError(err_msg)
+            # NOTE: this method relies on the hidden units being disconnected. The calculations
+            # depend on this assumption in two ways. The obvious one is marginalization. The
+            # less obvious one is linearity of expectation applied to the sufficient statistics:
+            # because the hidden units are disconnected, we can average their spins before
+            # computing the sufficient statistics that feed the quasi objective.
             obs = self._compute_expectation_disconnected(s_observed)
         elif kind == "sampling":
             obs = self._approximate_expectation_sampling(
@@ -436,10 +441,10 @@ def _compute_effective_field(self, padded: torch.Tensor) -> torch.Tensor:
         # by the corresponding edges. Transforming this contribution vector by a
         # cumulative sum yields cumulative contributions to effective fields.
         # Differencing removes the extra gobbledygook.
-        contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx]
+        contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx].detach()
         cumulative_contribution = contribution.cumsum(1)
         # Don't forget to add the linear fields!
-        h_eff = self._linear[self.hidden_idx] + cumulative_contribution[
+        h_eff = self._linear[self.hidden_idx].detach() + cumulative_contribution[
             :, self._bin_idx
         ].diff(dim=1, prepend=torch.zeros(bs, device=padded.device).unsqueeze(1))
diff --git a/releasenotes/notes/fix-grad-with-hidden-units-e74349975573db02.yaml b/releasenotes/notes/fix-grad-with-hidden-units-e74349975573db02.yaml
new file mode 100755
index 0000000..fbd5f66
--- /dev/null
+++ b/releasenotes/notes/fix-grad-with-hidden-units-e74349975573db02.yaml
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Fix the (automatic) gradient computation when hidden units are present. The issue was that
+    the model parameters (linear and quadratic weights) were used in the marginalization over
+    hidden units without being detached from the computation graph. The fix detaches these
+    parameters when computing effective fields.
\ No newline at end of file
diff --git a/tests/test_boltzmann_machine.py b/tests/test_boltzmann_machine.py
index 5881b3d..f4d2ca0 100644
--- a/tests/test_boltzmann_machine.py
+++ b/tests/test_boltzmann_machine.py
@@ -168,7 +168,7 @@ def test_compute_effective_field_unordered(self):
         # = 0.13 * [-1] - 0.17 * [1] + 0.4 = .1

         grbm._linear.data = torch.tensor([-0.1, -0.2, 0.4, 0.2])
-        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17 ])
+        grbm._quadratic.data = torch.tensor([-.15, -0.7, 0.15, 0.13, -0.17])
         padded = torch.tensor([[-1.0, float("nan"), float("nan"), 1.0]])
         h_eff = grbm._compute_effective_field(padded)
         self.assertTrue(torch.allclose(h_eff.data, torch.tensor([-0.5000, 0.1000]), atol=1e-6))
@@ -322,7 +322,7 @@ def test_sample_return_sampleset(self):
         self.assertEqual(4, len(sampleset.variables))
         self.assertEqual(set(grbm.nodes), set(sampleset.variables))

-    def test_objective(self):
+    def test_quasi_objective(self):
         # Create a triangle graph with an additional dangling vertex
         self.nodes = list("abcd")
         self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]]
@@ -356,6 +356,46 @@ def test_objective(self):
         self.assertEqual(-1, grbm.quasi_objective(s1, s2).item())
         self.assertEqual(-1, grbm.quasi_objective(s1, s3))

+    def test_quasi_objective_gradient_hidden_units(self):
+        grbm = GRBM([1, 2, 3],
+                    [(1, 2), (1, 3), (2, 3)],
+                    [1],
+                    {1: 0.2, 2: 0.2, 3: 0.3},
+                    {(1, 2): 0.2, (1, 3): 0.3, (2, 3): 0.6})
+        # Note: in the diagram below, linear biases are shown in <> and
+        #       quadratic biases in ().
+        #                     (0.2)
+        # Model:  v1 <0.2> --------- v2 <0.2>
+        #              \             /
+        #         (0.3) \           / (0.6)
+        #                \         /
+        #                 v3 <0.3>
+        s_observed = torch.tensor([[1.0, -1.0]])
+        s_model = torch.tensor([[1.0, -1.0, 1.0]])
+        quasi = grbm.quasi_objective(s_observed, s_model, "exact-disc")
+        quasi.backward()
+        # Compute the gradients manually.
+        # Unnormalized conditional densities of the hidden spin v1 = +/-1:
+        # h1*v1 + h2*v2 + h3*v3 + J23*v2*v3 + J12*v1*v2 + J13*v1*v3
+        q_plus = torch.exp(-torch.tensor(0.2 + 0.2 - 0.3 - 0.6 + 0.2 - 0.3))
+        q_minus = torch.exp(-torch.tensor(-0.2 - 0.2 + 0.3 - 0.6 + 0.2 - 0.3))
+        # Normalize them
+        z_cond = q_plus + q_minus
+        p_plus = q_plus / z_cond
+        p_minus = q_minus / z_cond
+        # t = sufficient statistics = (v1, v2, v3, v1*v2, v1*v3, v2*v3)
+        t_plus = torch.tensor([1, 1, -1, 1, -1, -1]).float()
+        t_minus = torch.tensor([-1, 1, -1, -1, 1, -1]).float()
+        t_model = torch.tensor([1, -1, 1, -1, 1, -1]).float()
+        # Expected sufficient statistics conditioned on the observed spins
+        t_cond = t_plus * p_plus + t_minus * p_minus
+        grad = t_cond - t_model
+        grad_auto = torch.cat([grbm.linear.grad, grbm.quadratic.grad])
+        # NOTE: this test relies on the hidden units being disconnected. That assumption
+        #       lets us use linearity of expectation on the sufficient statistics: average
+        #       the hidden spins first, then compute the sufficient statistics from them.
+        torch.testing.assert_close(grad, grad_auto)
+

 if __name__ == "__main__":
     unittest.main()
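For context on why the `.detach()` calls fix the gradient, here is a minimal sketch of the failure mode. It uses a hypothetical one-hidden-spin toy model, not the plugin's API: `w`, `v`, and `quasi_term` are made up for illustration, and only the overall idea (an un-detached parameter inside the conditional expectation of a hidden spin adds an unintended path to the autograd graph) is taken from the patch.

```python
# Toy illustration (assumed setup, not the dwave-pytorch implementation):
# a single weight w couples an observed spin v to a hidden spin h, with
# p(h | v) proportional to exp(-w*h*v), so E[h | v] = tanh(-w*v).
import torch

w = torch.tensor(0.3, requires_grad=True)
v = torch.tensor(1.0)

def quasi_term(detach_weight: bool) -> torch.Tensor:
    # The conditional mean of the hidden spin is built from the weight itself.
    w_eff = w.detach() if detach_weight else w
    h_mean = torch.tanh(-w_eff * v)
    # Energy-like term of a quasi objective: weight times the expected
    # sufficient statistic <h*v>; only this explicit w should get gradient.
    return w * h_mean * v

for detach_weight in (False, True):
    w.grad = None
    quasi_term(detach_weight).backward()
    print(f"detach={detach_weight}: dL/dw = {w.grad.item():+.6f}, "
          f"E[h|v]*v = {(torch.tanh(-w * v) * v).item():+.6f}")

# Without the detach, autograd also differentiates through tanh(-w*v) and reports
# dL/dw = E[h|v]*v - w*v**2*(1 - tanh(w*v)**2) rather than the expected sufficient
# statistic E[h|v]*v; detaching the parameters used to build the conditional
# expectation removes that spurious term, which is what the new test checks
# against a hand-computed gradient.
```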