@@ -35,7 +35,6 @@ class SineBivariateVonMises(TorchDistribution):
35
35
This distribution is a submodel of the Bivariate von Mises distribution, called the Sine Distribution [2] in
36
36
directional statistics.
37
37
38
-
39
38
This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
40
39
To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that
41
40
avoid parameterizations where the distribution becomes bimodal; see note below.
@@ -44,10 +43,12 @@ class SineBivariateVonMises(TorchDistribution):
44
43
45
44
.. math::
46
45
47
- \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1
46
+ \frac{\rho^2 }{\kappa_1\kappa_2} \rightarrow 1
48
47
49
- because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
50
- parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].
48
+ because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
49
+ `weighted_correlation` parameter with a skew away from one (e.g.,
50
+ `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
51
+ should be in [-1,1].
51
52
52
53
.. note:: The correlation and weighted_correlation params are mutually exclusive.
53
54
@@ -65,7 +66,7 @@ class SineBivariateVonMises(TorchDistribution):
65
66
:param torch.Tensor psi_concentration: concentration of second angle
66
67
:param torch.Tensor correlation: correlation between the two angles
67
68
:param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
68
- to avoid bimodality (see note). The `weighted_correlation` should be in [0 ,1].
69
+ to avoid bimodality (see note). The `weighted_correlation` should be in [-1 ,1].
69
70
"""
70
71
71
72
arg_constraints = {
@@ -139,7 +140,13 @@ def norm_const(self):
139
140
+ m * torch .log ((corr ** 2 ).clamp (min = tiny ))
140
141
- m * torch .log (4 * torch .prod (conc , dim = - 1 ))
141
142
)
142
- fs += log_I1 (m .max (), conc , 51 ).sum (- 1 )
143
+ num_I1terms = torch .maximum (
144
+ torch .tensor (501 ),
145
+ torch .max (self .phi_concentration ) + torch .max (self .psi_concentration ),
146
+ ).int ()
147
+
148
+ fs += log_I1 (m .max (), conc , num_I1terms ).sum (- 1 )
149
+
143
150
mfs = fs .max ()
144
151
norm_const = 2 * torch .log (torch .tensor (2 * pi )) + mfs + (fs - mfs ).logsumexp (0 )
145
152
return norm_const .reshape (self .phi_loc .shape )
0 commit comments