Skip to content

Commit b34ba6c

Browse files
authored
Fixed norm const for SBVM (#3411)
* added lognorm terms for high conc sbvm * lint
1 parent 455f7b3 commit b34ba6c

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

pyro/distributions/sine_bivariate_von_mises.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class SineBivariateVonMises(TorchDistribution):
3535
This distribution is a submodel of the Bivariate von Mises distribution, called the Sine Distribution [2] in
3636
directional statistics.
3737
38-
3938
This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
4039
To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that
4140
avoid parameterizations where the distribution becomes bimodal; see note below.
@@ -44,10 +43,12 @@ class SineBivariateVonMises(TorchDistribution):
4443
4544
.. math::
4645
47-
\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1
46+
\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1
4847
49-
because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
50-
parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].
48+
because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
49+
`weighted_correlation` parameter with a skew away from one (e.g.,
50+
`TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
51+
should be in [-1,1].
5152
5253
.. note:: The correlation and weighted_correlation params are mutually exclusive.
5354
@@ -65,7 +66,7 @@ class SineBivariateVonMises(TorchDistribution):
6566
:param torch.Tensor psi_concentration: concentration of second angle
6667
:param torch.Tensor correlation: correlation between the two angles
6768
:param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
68-
to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
69+
to avoid bimodality (see note). The `weighted_correlation` should be in [-1,1].
6970
"""
7071

7172
arg_constraints = {
@@ -139,7 +140,13 @@ def norm_const(self):
139140
+ m * torch.log((corr**2).clamp(min=tiny))
140141
- m * torch.log(4 * torch.prod(conc, dim=-1))
141142
)
142-
fs += log_I1(m.max(), conc, 51).sum(-1)
143+
num_I1terms = torch.maximum(
144+
torch.tensor(501),
145+
torch.max(self.phi_concentration) + torch.max(self.psi_concentration),
146+
).int()
147+
148+
fs += log_I1(m.max(), conc, num_I1terms).sum(-1)
149+
143150
mfs = fs.max()
144151
norm_const = 2 * torch.log(torch.tensor(2 * pi)) + mfs + (fs - mfs).logsumexp(0)
145152
return norm_const.reshape(self.phi_loc.shape)

tests/distributions/test_sine_bivariate_von_mises.py

+13
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,16 @@ def guide(data):
130130
) # k == 'corr'
131131

132132
assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)
133+
134+
135+
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
136+
def test_sine_bivariate_von_mises_norm(conc):
137+
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
138+
num_samples = 500
139+
x = torch.linspace(-torch.pi, torch.pi, num_samples)
140+
y = torch.linspace(-torch.pi, torch.pi, num_samples)
141+
mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1)
142+
integral_torus = (
143+
torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2
144+
).sum()
145+
assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)

0 commit comments

Comments
 (0)