Skip to content

Commit d6ebf5a

Browse files
authored
added note and assert that sbvm conc < 10k (#3412)
1 parent b34ba6c commit d6ebf5a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

pyro/distributions/sine_bivariate_von_mises.py

+12
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class SineBivariateVonMises(TorchDistribution):
5555
.. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for
5656
latent variables.
5757
58+
.. note:: Normalization remains accurate up to concentrations of 10,000.
59+
5860
** References: **
5961
1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002)
6062
2. Protein Bioinformatics and Mixtures of Bivariate von Mises Distributions for Angular Data,
@@ -108,6 +110,16 @@ def __init__(
108110
) = broadcast_all(
109111
phi_loc, psi_loc, phi_concentration, psi_concentration, correlation
110112
)
113+
114+
max_conc = torch.maximum(
115+
torch.max(phi_concentration), torch.max(psi_concentration)
116+
)
117+
assrt_hstr = (
118+
"Normalization of SineBiviateVonMises is inaccurate for"
119+
f"current max concentration ({max_conc} > 10,000)."
120+
)
121+
assert max_conc <= torch.tensor(10_000.0), assrt_hstr
122+
111123
self.phi_loc = phi_loc
112124
self.psi_loc = psi_loc
113125
self.phi_concentration = phi_concentration

tests/distributions/test_sine_bivariate_von_mises.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,15 @@ def guide(data):
132132
assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)
133133

134134

135-
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
135+
@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10_000.0, 10_001.0])
136136
def test_sine_bivariate_von_mises_norm(conc):
137+
if conc > 10_000.0:
138+
try:
139+
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
140+
pytest.fail()
141+
except AssertionError:
142+
return
143+
137144
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
138145
num_samples = 500
139146
x = torch.linspace(-torch.pi, torch.pi, num_samples)

0 commit comments

Comments
 (0)