Skip to content

Commit 06f21e0

Browse files
Added KF with sparse sites and related test (#18)
* Added KF with sparse sites and related test * Update tests/integration/test_kalman_filter_with_sparse_sites.py Co-authored-by: Vincent Adam <[email protected]> * Incorporated suggested changes * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam <[email protected]> * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam <[email protected]> * Incorporated PR changes * sparse observations saved * passing sparse sites rather than dense sites Co-authored-by: Vincent Adam <[email protected]>
1 parent 08dac0b commit 06f21e0

File tree

2 files changed

+188
-2
lines changed

2 files changed

+188
-2
lines changed

markovflow/kalman_filter.py

+119-2
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,8 @@ def _k_inv_post(self):
9696
# The emission matrix is tiled across the time_points, so for a time invariant matrix
9797
# this is equivalent to Gᵀ Σ⁻¹ G = (I_N ⊗ HᵀR⁻¹H),
9898
likelihood_precision = SymmetricBlockTriDiagonal(h_t_r_h)
99-
_k_inv_prior = self.prior_ssm.precision
10099
# K⁻¹ + GᵀΣ⁻¹G
101-
return _k_inv_prior + likelihood_precision
100+
return self._k_inv_prior + likelihood_precision
102101

103102
@property
104103
def _log_det_observation_precision(self):
@@ -495,3 +494,121 @@ def _log_det_observation_precision(self):
495494
def observations(self):
496495
""" Observation vector """
497496
return self.sites.means
497+
498+
499+
@tf_scope_class_decorator
500+
class KalmanFilterWithSparseSites(BaseKalmanFilter):
501+
r"""
502+
Performs a Kalman filter on a :class:`~markovflow.state_space_model.StateSpaceModel`
503+
and :class:`~markovflow.emission_model.EmissionModel`, with Gaussian sites, over a time grid.
504+
"""
505+
506+
def __init__(self, state_space_model: StateSpaceModel, emission_model: EmissionModel, sites: GaussianSites,
507+
num_grid_points: int, observations_index: tf.Tensor, observations: tf.Tensor):
508+
"""
509+
:param state_space_model: Parameterises the latent chain.
510+
:param emission_model: Maps the latent chain to the observations.
511+
:param sites: Gaussian sites over the observations.
512+
:param num_grid_points: number of grid points.
513+
:param observations_index: Index of the observations in the time grid with shape (N,).
514+
:param observations: Sparse observations with shape (N, output_dim).
515+
"""
516+
self.sites = sites
517+
self.observations_index = observations_index
518+
self.sparse_observations = observations
519+
self.grid_shape = tf.TensorShape((num_grid_points, 1))
520+
super().__init__(state_space_model, emission_model)
521+
522+
@property
523+
def _r_inv(self):
524+
"""
525+
Precisions of the observation model over the time grid.
526+
"""
527+
data_sites_precision = self.sites.precisions
528+
return self.sparse_to_dense(data_sites_precision, output_shape=self.grid_shape + (1,))
529+
530+
@property
531+
def _log_det_observation_precision(self):
532+
"""
533+
Sum of log determinant of the precisions of the observation model. It only calculates for the data_sites as
534+
other sites precision is anyways zero.
535+
"""
536+
return tf.reduce_sum(tf.linalg.logdet(self._r_inv_data), axis=-1)
537+
538+
@property
539+
def observations(self):
540+
""" Sparse observation vector """
541+
return self.sparse_observations
542+
543+
@property
544+
def _r_inv_data(self):
545+
"""
546+
Precisions of the observation model for only the data sites.
547+
"""
548+
return self.sites.precisions
549+
550+
def sparse_to_dense(self, tensor: tf.Tensor, output_shape: tf.TensorShape) -> tf.Tensor:
551+
"""
552+
Convert a sparse tensor to a dense one on the basis of observations index, output tensor is of the output_shape.
553+
"""
554+
return tf.scatter_nd(self.observations_index, tensor, output_shape)
555+
556+
def dense_to_sparse(self, tensor: tf.Tensor) -> tf.Tensor:
557+
"""
558+
Convert a dense tensor to a sparse one on the basis of observations index.
559+
"""
560+
tensor_shape = tensor.shape
561+
expand_dims = len(tensor_shape) == 3
562+
563+
tensor = tf.gather_nd(tf.reshape(tensor, (-1, 1)), self.observations_index)
564+
if expand_dims:
565+
tensor = tf.expand_dims(tensor, axis=-1)
566+
return tensor
567+
568+
def log_likelihood(self) -> tf.Tensor:
569+
r"""
570+
Construct a TensorFlow function to compute the likelihood.
571+
572+
For more mathematical details, look at the log_likelihood function of the parent class.
573+
The main difference from the parent class are that the vector of observations is now sparse.
574+
575+
:return: The likelihood as a scalar tensor (we sum over the `batch_shape`).
576+
"""
577+
# K⁻¹ + GᵀΣ⁻¹G = LLᵀ.
578+
l_post = self._k_inv_post.cholesky
579+
num_data = self.observations_index.shape[0]
580+
581+
# Hμ [..., num_transitions + 1, output_dim]
582+
marginal = self.emission.project_state_to_f(self.prior_ssm.marginal_means)
583+
584+
# y = obs - Hμ [..., num_transitions + 1, output_dim]
585+
disp = self.sparse_to_dense(self.observations, marginal.shape) - marginal
586+
disp_data = self.sparse_observations - self.dense_to_sparse(marginal)
587+
588+
# cst is the constant term for a gaussian log likelihood
589+
cst = (
590+
-0.5 * np.log(2 * np.pi) * tf.cast(self.emission.output_dim * num_data, default_float())
591+
)
592+
593+
term1 = -0.5 * tf.reduce_sum(
594+
input_tensor=tf.einsum("...op,...p,...o->...o", self._r_inv_data, disp_data, disp_data), axis=[-1, -2]
595+
)
596+
597+
# term 2 is: ½|L⁻¹(GᵀΣ⁻¹)y|²
598+
# (GᵀΣ⁻¹)y [..., num_transitions + 1, state_dim]
599+
obs_proj = self._back_project_y_to_state(disp)
600+
601+
# ½|L⁻¹(GᵀΣ⁻¹)y|² [...]
602+
term2 = 0.5 * tf.reduce_sum(
603+
input_tensor=tf.square(l_post.solve(obs_proj, transpose_left=False)), axis=[-1, -2]
604+
)
605+
606+
## term 3 is: ½log |K⁻¹| - log |L| + ½ log |Σ⁻¹|
607+
# where log |Σ⁻¹| = num_data * log|R⁻¹|
608+
term3 = (
609+
0.5 * self.prior_ssm.log_det_precision()
610+
- l_post.abs_log_det()
611+
+ 0.5 * self._log_det_observation_precision
612+
)
613+
614+
return tf.reduce_sum(cst + term1 + term2 + term3)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
import tensorflow as tf
3+
import pytest
4+
from gpflow.config import default_float
5+
6+
from markovflow.kernels.matern import Matern12
7+
from markovflow.mean_function import LinearMeanFunction
8+
from markovflow.models.gaussian_process_regression import GaussianProcessRegression
9+
from markovflow.kalman_filter import KalmanFilterWithSparseSites, UnivariateGaussianSitesNat
10+
from markovflow.likelihoods import MultivariateGaussian
11+
12+
@pytest.fixture(
13+
name="time_step_homogeneous", params=[(0.01, True), (0.01, False), (0.001, True), (0.001, False)],
14+
)
15+
def _time_step_homogeneous_fixture(request):
16+
return request.param
17+
18+
19+
@pytest.fixture(name="kalman_gpr_setup")
20+
def _setup(batch_shape, time_step_homogeneous):
21+
"""
22+
Create a Gaussian Process model and an equivalent kalman filter model
23+
with more latent states than observations.
24+
FIXME: Currently batch_shape isn't used.
25+
"""
26+
dt, homogeneous = time_step_homogeneous
27+
28+
time_grid = np.arange(0.0, 1.0, dt)
29+
if not homogeneous:
30+
time_grid = np.sort(np.random.choice(time_grid, 50, replace=False))
31+
32+
time_points = time_grid[::10]
33+
observations = np.sin(12 * time_points[..., None]) + np.random.randn(len(time_points), 1) * 0.1
34+
35+
input_data = (
36+
tf.constant(time_points, dtype=default_float()),
37+
tf.constant(observations, dtype=default_float()),
38+
)
39+
40+
observation_covariance = 1.0 # Same as GPFlow default
41+
kernel = Matern12(lengthscale=1.0, variance=1.0, output_dim=observations.shape[-1])
42+
kernel.set_state_mean(tf.random.normal((1,), dtype=default_float()))
43+
gpr_model = GaussianProcessRegression(
44+
input_data=input_data,
45+
kernel=kernel,
46+
mean_function=LinearMeanFunction(1.1),
47+
chol_obs_covariance=tf.constant([[np.sqrt(observation_covariance)]], dtype=default_float()),
48+
)
49+
50+
prior_ssm = kernel.state_space_model(time_grid)
51+
emission_model = kernel.generate_emission_model(time_grid)
52+
observations_index = tf.where(tf.equal(time_grid[..., None], time_points))[:, 0][..., None]
53+
54+
observations -= gpr_model.mean_function(time_points)
55+
56+
nat1 = observations / observation_covariance
57+
nat2 = (-0.5 / observation_covariance) * tf.ones_like(nat1)[..., None]
58+
lognorm = tf.zeros_like(nat1)
59+
sites = UnivariateGaussianSitesNat(nat1=nat1, nat2=nat2, log_norm=lognorm)
60+
61+
kf_sparse_sites = KalmanFilterWithSparseSites(prior_ssm, emission_model, sites, time_grid.shape[0],
62+
observations_index, observations)
63+
64+
return gpr_model, kf_sparse_sites
65+
66+
def test_kalman_loglikelihood(with_tf_random_seed, kalman_gpr_setup):
67+
gpr_model, kf_sparse_sites = kalman_gpr_setup
68+
69+
np.testing.assert_allclose(gpr_model.log_likelihood(), kf_sparse_sites.log_likelihood())

0 commit comments

Comments
 (0)