diff --git a/pymc_extras/statespace/filters/kalman_filter.py b/pymc_extras/statespace/filters/kalman_filter.py index cf9ac4aa3..459ae91b8 100644 --- a/pymc_extras/statespace/filters/kalman_filter.py +++ b/pymc_extras/statespace/filters/kalman_filter.py @@ -393,7 +393,7 @@ def predict(a, P, c, T, R, Q) -> tuple[TensorVariable, TensorVariable]: .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods. 2nd ed, Oxford University Press, 2012. """ - a_hat = T.dot(a) + c + a_hat = T @ a + c P_hat = quad_form_sym(T, P) + quad_form_sym(R, Q) return a_hat, P_hat @@ -580,16 +580,16 @@ def update(self, a, P, y, d, Z, H, all_nan_flag): .. [1] Durbin, J., and S. J. Koopman. Time Series Analysis by State Space Methods. 2nd ed, Oxford University Press, 2012. """ - y_hat = d + Z.dot(a) + y_hat = d + Z @ a v = y - y_hat - PZT = P.dot(Z.T) + PZT = P.dot(Z.mT) F = Z.dot(PZT) + stabilize(H, self.cov_jitter) - K = pt.linalg.solve(F.T, PZT.T, assume_a="pos", check_finite=False).T + K = pt.linalg.solve(F.mT, PZT.mT, assume_a="pos", check_finite=False).mT I_KZ = pt.eye(self.n_states) - K.dot(Z) - a_filtered = a + K.dot(v) + a_filtered = a + K @ v P_filtered = quad_form_sym(I_KZ, P) + quad_form_sym(K, H) F_inv_v = pt.linalg.solve(F, v, assume_a="pos", check_finite=False) @@ -630,9 +630,9 @@ def predict(self, a, P, c, T, R, Q): a_hat = T.dot(a) + c Q_chol = pt.linalg.cholesky(Q, lower=True) - M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).T + M = pt.horizontal_stack(T @ P_chol, R @ Q_chol).mT R_decomp = pt.linalg.qr(M, mode="r") - P_chol_hat = R_decomp[: self.n_states, : self.n_states].T + P_chol_hat = R_decomp[..., : self.n_states, : self.n_states].mT return a_hat, P_chol_hat @@ -665,7 +665,7 @@ def update(self, a, P, y, d, Z, H, all_nan_flag): upper = pt.horizontal_stack(H_chol, Z @ P_chol) lower = pt.horizontal_stack(zeros, P_chol) A_T = pt.vertical_stack(upper, lower) - B = pt.linalg.qr(A_T.T, mode="r").T + B = pt.linalg.qr(A_T.mT, mode="r").mT F_chol = B[: self.n_endog, : self.n_endog] K_F_chol = B[self.n_endog :, : self.n_endog] @@ -677,6 +677,7 @@ def compute_non_degenerate(P_chol_filtered, F_chol, K_F_chol, v): inner_term = solve_triangular( F_chol, solve_triangular(F_chol, v, lower=True), lower=True ) + loss = (v.T @ inner_term).ravel() # abs necessary because we're not guaranteed a positive diagonal from the schur decomposition @@ -800,7 +801,7 @@ def kalman_step(self, y, a, P, c, d, T, Z, R, H, Q): obs_cov[-1], ) - P_filtered = stabilize(0.5 * (P_filtered + P_filtered.T), self.cov_jitter) + P_filtered = stabilize(0.5 * (P_filtered + P_filtered.mT), self.cov_jitter) a_hat, P_hat = self.predict(a=a_filtered, P=P_filtered, c=c, T=T, R=R, Q=Q) ll = -0.5 * ((pt.neq(ll_inner, 0).sum()) * MVN_CONST + ll_inner.sum()) diff --git a/pymc_extras/statespace/filters/utilities.py b/pymc_extras/statespace/filters/utilities.py index d61537b62..4f87991b3 100644 --- a/pymc_extras/statespace/filters/utilities.py +++ b/pymc_extras/statespace/filters/utilities.py @@ -1,7 +1,5 @@ import pytensor.tensor as pt -from pytensor.tensor.nlinalg import matrix_dot - from pymc_extras.statespace.utils.constants import JITTER_DEFAULT, NEVER_TIME_VARYING, VECTOR_VALUED @@ -48,12 +46,11 @@ def split_vars_into_seq_and_nonseq(params, param_names): def stabilize(cov, jitter=JITTER_DEFAULT): - # Ensure diagonal is non-zero cov = cov + pt.identity_like(cov) * jitter return cov def quad_form_sym(A, B): - out = matrix_dot(A, B, A.T) - return 0.5 * (out + out.T) + out = A @ B @ A.mT + return 0.5 * (out + out.mT)