@@ -96,9 +96,8 @@ def _k_inv_post(self):
96
96
# The emission matrix is tiled across the time_points, so for a time invariant matrix
97
97
# this is equivalent to Gᵀ Σ⁻¹ G = (I_N ⊗ HᵀR⁻¹H),
98
98
likelihood_precision = SymmetricBlockTriDiagonal (h_t_r_h )
99
- _k_inv_prior = self .prior_ssm .precision
100
99
# K⁻¹ + GᵀΣ⁻¹G
101
- return _k_inv_prior + likelihood_precision
100
+ return self . _k_inv_prior + likelihood_precision
102
101
103
102
@property
104
103
def _log_det_observation_precision (self ):
@@ -495,3 +494,121 @@ def _log_det_observation_precision(self):
495
494
def observations (self ):
496
495
""" Observation vector """
497
496
return self .sites .means
497
+
498
+
499
+ @tf_scope_class_decorator
500
+ class KalmanFilterWithSparseSites (BaseKalmanFilter ):
501
+ r"""
502
+ Performs a Kalman filter on a :class:`~markovflow.state_space_model.StateSpaceModel`
503
+ and :class:`~markovflow.emission_model.EmissionModel`, with Gaussian sites, over a time grid.
504
+ """
505
+
506
+ def __init__ (self , state_space_model : StateSpaceModel , emission_model : EmissionModel , sites : GaussianSites ,
507
+ num_grid_points : int , observations_index : tf .Tensor , observations : tf .Tensor ):
508
+ """
509
+ :param state_space_model: Parameterises the latent chain.
510
+ :param emission_model: Maps the latent chain to the observations.
511
+ :param sites: Gaussian sites over the observations.
512
+ :param num_grid_points: number of grid points.
513
+ :param observations_index: Index of the observations in the time grid with shape (N,).
514
+ :param observations: Sparse observations with shape (N, output_dim).
515
+ """
516
+ self .sites = sites
517
+ self .observations_index = observations_index
518
+ self .sparse_observations = observations
519
+ self .grid_shape = tf .TensorShape ((num_grid_points , 1 ))
520
+ super ().__init__ (state_space_model , emission_model )
521
+
522
+ @property
523
+ def _r_inv (self ):
524
+ """
525
+ Precisions of the observation model over the time grid.
526
+ """
527
+ data_sites_precision = self .sites .precisions
528
+ return self .sparse_to_dense (data_sites_precision , output_shape = self .grid_shape + (1 ,))
529
+
530
+ @property
531
+ def _log_det_observation_precision (self ):
532
+ """
533
+ Sum of log determinant of the precisions of the observation model. It only calculates for the data_sites as
534
+ other sites precision is anyways zero.
535
+ """
536
+ return tf .reduce_sum (tf .linalg .logdet (self ._r_inv_data ), axis = - 1 )
537
+
538
+ @property
539
+ def observations (self ):
540
+ """ Sparse observation vector """
541
+ return self .sparse_observations
542
+
543
+ @property
544
+ def _r_inv_data (self ):
545
+ """
546
+ Precisions of the observation model for only the data sites.
547
+ """
548
+ return self .sites .precisions
549
+
550
+ def sparse_to_dense (self , tensor : tf .Tensor , output_shape : tf .TensorShape ) -> tf .Tensor :
551
+ """
552
+ Convert a sparse tensor to a dense one on the basis of observations index, output tensor is of the output_shape.
553
+ """
554
+ return tf .scatter_nd (self .observations_index , tensor , output_shape )
555
+
556
+ def dense_to_sparse (self , tensor : tf .Tensor ) -> tf .Tensor :
557
+ """
558
+ Convert a dense tensor to a sparse one on the basis of observations index.
559
+ """
560
+ tensor_shape = tensor .shape
561
+ expand_dims = len (tensor_shape ) == 3
562
+
563
+ tensor = tf .gather_nd (tf .reshape (tensor , (- 1 , 1 )), self .observations_index )
564
+ if expand_dims :
565
+ tensor = tf .expand_dims (tensor , axis = - 1 )
566
+ return tensor
567
+
568
+ def log_likelihood (self ) -> tf .Tensor :
569
+ r"""
570
+ Construct a TensorFlow function to compute the likelihood.
571
+
572
+ For more mathematical details, look at the log_likelihood function of the parent class.
573
+ The main difference from the parent class are that the vector of observations is now sparse.
574
+
575
+ :return: The likelihood as a scalar tensor (we sum over the `batch_shape`).
576
+ """
577
+ # K⁻¹ + GᵀΣ⁻¹G = LLᵀ.
578
+ l_post = self ._k_inv_post .cholesky
579
+ num_data = self .observations_index .shape [0 ]
580
+
581
+ # Hμ [..., num_transitions + 1, output_dim]
582
+ marginal = self .emission .project_state_to_f (self .prior_ssm .marginal_means )
583
+
584
+ # y = obs - Hμ [..., num_transitions + 1, output_dim]
585
+ disp = self .sparse_to_dense (self .observations , marginal .shape ) - marginal
586
+ disp_data = self .sparse_observations - self .dense_to_sparse (marginal )
587
+
588
+ # cst is the constant term for a gaussian log likelihood
589
+ cst = (
590
+ - 0.5 * np .log (2 * np .pi ) * tf .cast (self .emission .output_dim * num_data , default_float ())
591
+ )
592
+
593
+ term1 = - 0.5 * tf .reduce_sum (
594
+ input_tensor = tf .einsum ("...op,...p,...o->...o" , self ._r_inv_data , disp_data , disp_data ), axis = [- 1 , - 2 ]
595
+ )
596
+
597
+ # term 2 is: ½|L⁻¹(GᵀΣ⁻¹)y|²
598
+ # (GᵀΣ⁻¹)y [..., num_transitions + 1, state_dim]
599
+ obs_proj = self ._back_project_y_to_state (disp )
600
+
601
+ # ½|L⁻¹(GᵀΣ⁻¹)y|² [...]
602
+ term2 = 0.5 * tf .reduce_sum (
603
+ input_tensor = tf .square (l_post .solve (obs_proj , transpose_left = False )), axis = [- 1 , - 2 ]
604
+ )
605
+
606
+ ## term 3 is: ½log |K⁻¹| - log |L| + ½ log |Σ⁻¹|
607
+ # where log |Σ⁻¹| = num_data * log|R⁻¹|
608
+ term3 = (
609
+ 0.5 * self .prior_ssm .log_det_precision ()
610
+ - l_post .abs_log_det ()
611
+ + 0.5 * self ._log_det_observation_precision
612
+ )
613
+
614
+ return tf .reduce_sum (cst + term1 + term2 + term3 )
0 commit comments