@@ -393,9 +393,16 @@ def particle_covariance_mtx(cls, weights, locations):
393
393
# positive-semidefinite covariance matrix. If a negative eigenvalue
394
394
# is produced, we should warn the caller of this.
395
395
assert np .all (np .isfinite (cov ))
396
- if not np .all (la .eig (cov )[0 ] >= 0 ):
397
- warnings .warn ('Numerical error in covariance estimation causing positive semidefinite violation.' , ApproximationWarning )
398
-
396
+ vals , vecs = la .eig (cov )
397
+ small_vals = abs (vals ) < 1e-12
398
+ vals [small_vals ] = 0
399
+ if not np .all (vals >= 0 ):
400
+ warnings .warn (
401
+ 'Numerical error in covariance estimation causing positive semidefinite violation.' ,
402
+ ApproximationWarning
403
+ )
404
+ if np .any (small_vals ):
405
+ return (vecs * vals ) @ vecs .T .conj ()
399
406
return cov
400
407
401
408
def est_mean (self ):
@@ -1007,12 +1014,12 @@ def n_rvs(self):
1007
1014
1008
1015
def sample (self , n = 1 ):
1009
1016
return self .dist .rvs (size = n )[:, np .newaxis ]
1010
-
1017
+
1011
1018
class DirichletDistribution (Distribution ):
1012
1019
r"""
1013
1020
The dirichlet distribution, whose pdf at :math:`x` is proportional to
1014
1021
:math:`\prod_i x_i^{\alpha_i-1}`.
1015
-
1022
+
1016
1023
:param alpha: The list of concentration parameters.
1017
1024
"""
1018
1025
def __init__ (self , alpha ):
@@ -1021,7 +1028,7 @@ def __init__(self, alpha):
1021
1028
raise ValueError ('The input alpha must be a 1D list of concentration parameters.' )
1022
1029
1023
1030
self ._dist = st .dirichlet (alpha = self .alpha )
1024
-
1031
+
1025
1032
@property
1026
1033
def alpha (self ):
1027
1034
return self ._alpha
0 commit comments