Hidden Markov Models
Hidden Markov Models (HMMs) are latent variable models for sequential data. Like the mixture models from the previous chapter, HMMs have discrete latent states. Unlike mixture models, the discrete latent states of an HMM are not independent: the state at time $t$ depends on the state at time $t-1$.
Recall the basic Gaussian mixture model,
$$
\begin{aligned}
z_t &\sim \mathrm{Cat}(\mbpi), \\
\mbx_t \mid z_t &\sim \mathcal{N}(\mbmu_{z_t}, \mbSigma_{z_t}),
\end{aligned}
$$
where
- $z_t \in \{1, \ldots, K\}$ is a latent mixture assignment,
- $\mbx_t \in \reals^D$ is an observed data point, and
- $\mbpi \in \Delta_{K-1}$, $\mbmu_k \in \reals^D$, and $\mbSigma_k \in \reals_{\succeq 0}^{D \times D}$ are parameters.
(Here we've switched to indexing data points by $t$ in anticipation of the sequential models to come.)
Let $\mbTheta = (\mbpi, \{\mbmu_k, \mbSigma_k\}_{k=1}^K)$ denote the full set of parameters.
:::{admonition} Exercise
:class: tip
Draw the graphical model for a GMM.
:::
Recall the EM algorithm for mixture models,
- E step: Compute the posterior distribution
  $$
  \begin{aligned}
  q(\mbz_{1:T}) &= p(\mbz_{1:T} \mid \mbx_{1:T}; \mbTheta) \\
  &= \prod_{t=1}^T p(z_t \mid \mbx_t; \mbTheta) \\
  &= \prod_{t=1}^T q_t(z_t)
  \end{aligned}
  $$
- M step: Maximize the ELBO with respect to $\mbTheta$,
  $$
  \begin{aligned}
  \mathcal{L}(\mbTheta) &= \mathbb{E}_{q(\mbz_{1:T})}\left[\log p(\mbx_{1:T}, \mbz_{1:T}; \mbTheta) - \log q(\mbz_{1:T}) \right] \\
  &= \mathbb{E}_{q(\mbz_{1:T})}\left[\log p(\mbx_{1:T}, \mbz_{1:T}; \mbTheta) \right] + c.
  \end{aligned}
  $$
For exponential family mixture models, the M-step only requires expected sufficient statistics.
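For example, here is a minimal NumPy sketch of one EM iteration for a Gaussian mixture: the E step computes responsibilities, and the M step re-estimates parameters from the resulting expected sufficient statistics. (Function and variable names are illustrative, not from any particular library.)

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def gmm_em_step(X, pi, mus, Sigmas):
    """One EM iteration for a GMM. X: (T, D), pi: (K,), mus: (K, D), Sigmas: (K, D, D)."""
    T, _ = X.shape
    K = len(pi)

    # E step: responsibilities q_t(z_t = k) = p(z_t = k | x_t; Theta).
    log_r = np.stack([np.log(pi[k]) + multivariate_normal.logpdf(X, mus[k], Sigmas[k])
                      for k in range(K)], axis=1)                  # (T, K)
    r = np.exp(log_r - logsumexp(log_r, axis=1, keepdims=True))

    # M step: closed-form updates from expected sufficient statistics.
    Nk = r.sum(axis=0)                                             # sum_t q_t(k)
    pi = Nk / T
    mus = (r.T @ X) / Nk[:, None]                                  # weighted means
    Sigmas = np.stack([(r[:, k, None] * (X - mus[k])).T @ (X - mus[k]) / Nk[k]
                       for k in range(K)])                         # weighted covariances
    return pi, mus, Sigmas
```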
Hidden Markov Models
Hidden Markov Models (HMMs) are like mixture models with temporal dependencies between the mixture assignments.
The corresponding graphical model is a chain of latent states with one emission per state, so the joint distribution factors as,
$$
p(\mbz_{1:T}, \mbx_{1:T}) = p(z_1) \prod_{t=2}^T p(z_t \mid z_{t-1}) \prod_{t=1}^T p(\mbx_t \mid z_t).
$$
We call this an HMM because the hidden states follow a Markov chain,
$$
p(z_{t+1} \mid z_{1:t}) = p(z_{t+1} \mid z_t).
$$
An HMM consists of three components:
- Initial distribution: $z_1 \sim \mathrm{Cat}(\mbpi_0)$
- Transition matrix: $z_t \sim \mathrm{Cat}(\mbP_{z_{t-1}})$, where $\mbP \in [0,1]^{K \times K}$ is a row-stochastic transition matrix with rows $\mbP_k$.
- Emission distribution: $\mbx_t \sim p(\cdot \mid \boldsymbol{\theta}_{z_t})$
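Putting the three components together, here is a minimal NumPy sketch of the HMM's generative process, assuming Gaussian emissions for concreteness (all names here are illustrative):

```python
import numpy as np

def sample_hmm(pi0, P, mus, Sigmas, T, seed=0):
    """Sample a length-T trajectory (z_{1:T}, x_{1:T}) from an HMM with Gaussian emissions."""
    rng = np.random.default_rng(seed)
    K, D = mus.shape
    zs = np.zeros(T, dtype=int)
    xs = np.zeros((T, D))

    # z_1 ~ Cat(pi0), then z_t ~ Cat(P[z_{t-1}]) for t = 2, ..., T.
    zs[0] = rng.choice(K, p=pi0)
    xs[0] = rng.multivariate_normal(mus[zs[0]], Sigmas[zs[0]])
    for t in range(1, T):
        zs[t] = rng.choice(K, p=P[zs[t - 1]])
        xs[t] = rng.multivariate_normal(mus[zs[t]], Sigmas[zs[t]])
    return zs, xs
```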
We are interested in questions like:
- What are the predictive distributions $p(z_{t+1} \mid \mbx_{1:t})$?
- What is the posterior marginal distribution $p(z_t \mid \mbx_{1:T})$?
- What is the posterior pairwise marginal distribution $p(z_t, z_{t+1} \mid \mbx_{1:T})$?
- What is the posterior mode $z_{1:T}^\star = \mathop{\mathrm{arg\,max}} p(z_{1:T} \mid \mbx_{1:T})$?
- How can we sample the posterior $p(\mbz_{1:T} \mid \mbx_{1:T})$ of an HMM?
- What is the marginal likelihood $p(\mbx_{1:T})$?
- How can we learn the parameters of an HMM?
:::{admonition} Exercise
:class: tip
On the surface, what makes these inference problems harder than in the simple mixture model case?
:::
The predictive distributions give the probability of the latent state $z_{t+1}$ given the observations up to time $t$. They are proportional to the joint probabilities,
$$
\begin{aligned}
p(z_{t+1} = k \mid \mbx_{1:t}) &\propto p(z_{t+1} = k, \mbx_{1:t}) \\
&= \sum_{j=1}^K p(z_t = j, \mbx_{1:t-1}) \, p(\mbx_t \mid z_t = j) \, p(z_{t+1} = k \mid z_t = j).
\end{aligned}
$$
We call $\alpha_{t,k} = p(z_t = k, \mbx_{1:t-1})$ the forward messages. They obey the recursion
$$
\alpha_{t+1,k} = \sum_{j=1}^K \alpha_{t,j} \, l_{t,j} \, P_{jk},
$$
where $l_{t,j} = p(\mbx_t \mid z_t = j)$ is the likelihood of the observation at time $t$ under state $j$, with base case $\alpha_{1,k} = \pi_{0,k}$.
We can also write these recursions in a vectorized form. Let
$$
\mbalpha_t = \begin{bmatrix} \alpha_{t,1} \\ \vdots \\ \alpha_{t,K} \end{bmatrix}
\quad \text{and} \quad
\mbl_t = \begin{bmatrix} p(\mbx_t \mid z_t = 1) \\ \vdots \\ p(\mbx_t \mid z_t = K) \end{bmatrix}
$$
both be vectors in $\reals_+^K$. Then the recursion becomes
$$
\mbalpha_{t+1} = \mbP^\top (\mbalpha_t \odot \mbl_t),
$$
where $\odot$ denotes the elementwise (Hadamard) product.
Finally, to get the predictive distributions we just have to normalize,
$$
p(z_{t+1} = k \mid \mbx_{1:t}) = \frac{\alpha_{t+1,k}}{\sum_{j=1}^K \alpha_{t+1,j}}.
$$
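The vectorized recursion translates almost line for line into code. Here is a rough sketch, assuming a precomputed likelihood matrix `liks[t, k] = p(x_t | z_t = k)` (illustrative names; see the numerically stable version later in this section before using it on long sequences):

```python
import numpy as np

def predictive_probs(pi0, P, liks):
    """Forward messages and predictive distributions for an HMM.

    pi0:  (K,)   initial distribution
    P:    (K, K) row-stochastic transition matrix
    liks: (T, K) likelihoods, liks[t, k] = p(x_t | z_t = k)  (0-indexed time)
    """
    T, K = liks.shape
    alphas = np.zeros((T, K))
    alphas[0] = pi0                                    # alpha_{1,k} = pi_{0,k}
    for t in range(T - 1):
        alphas[t + 1] = P.T @ (alphas[t] * liks[t])    # alpha_{t+1} = P^T (alpha_t ⊙ l_t)
    # Normalizing row t gives p(z_{t+1} = k | x_{1:t}); the first row is just the prior over z_1.
    return alphas / alphas.sum(axis=1, keepdims=True)
```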
:::{admonition} Question
:class: tip
What does the normalizing constant tell us?
:::
The posterior marginal distributions give the probability of the latent state $z_t$ given all of the observations, $\mbx_{1:T}$. Using the conditional independencies of the HMM,
$$
\begin{aligned}
p(z_t = k \mid \mbx_{1:T}) &\propto p(z_t = k, \mbx_{1:t-1}, \mbx_t, \mbx_{t+1:T}) \\
&= p(z_t = k, \mbx_{1:t-1}) \, p(\mbx_t \mid z_t = k) \, p(\mbx_{t+1:T} \mid z_t = k) \\
&= \alpha_{t,k} \, l_{t,k} \, \beta_{t,k},
\end{aligned}
$$
where we have introduced the backward messages $\beta_{t,k} = p(\mbx_{t+1:T} \mid z_t = k)$.
The backward messages can be computed recursively too,
$$
\beta_{t,k} = \sum_{j=1}^K P_{kj} \, l_{t+1,j} \, \beta_{t+1,j}.
$$
For the base case, let $\beta_{T,k} = 1$ for all $k$.
Let
$$
\boldsymbol{\beta}_t = \begin{bmatrix} \beta_{t,1} \\ \vdots \\ \beta_{t,K} \end{bmatrix}
$$
be a vector in $\reals_+^K$. Then the backward recursion can also be vectorized,
$$
\begin{aligned}
\boldsymbol{\beta}_{t} &= \mbP(\boldsymbol{\beta}_{t+1} \odot \mbl_{t+1}).
\end{aligned}
$$
Combining the forward and backward messages, we now have everything we need to compute the posterior marginal,
$$
p(z_t = k \mid \mbx_{1:T}) = \frac{\alpha_{t,k} \, l_{t,k} \, \beta_{t,k}}{\sum_{j=1}^K \alpha_{t,j} \, l_{t,j} \, \beta_{t,j}}.
$$
We just derived the forward-backward algorithm for HMMs!
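Here is a rough sketch of the full forward-backward pass as derived above, again with a precomputed likelihood matrix and unnormalized messages (fine for short sequences; the normalized version discussed below is what you would use in practice):

```python
import numpy as np

def forward_backward(pi0, P, liks):
    """Posterior marginals p(z_t = k | x_{1:T}) for an HMM, using unnormalized messages."""
    T, K = liks.shape

    # Forward pass: alpha_{t,k} = p(z_t = k, x_{1:t-1}).
    alphas = np.zeros((T, K))
    alphas[0] = pi0
    for t in range(T - 1):
        alphas[t + 1] = P.T @ (alphas[t] * liks[t])

    # Backward pass: beta_{t,k} = p(x_{t+1:T} | z_t = k), with beta_{T,k} = 1.
    betas = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        betas[t] = P @ (betas[t + 1] * liks[t + 1])

    # Combine: p(z_t = k | x_{1:T}) ∝ alpha_{t,k} * l_{t,k} * beta_{t,k}.
    posterior = alphas * liks * betas
    return posterior / posterior.sum(axis=1, keepdims=True)
```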
:::{admonition} Exercise
:class: tip
If the forward messages represent the predictive probabilities $p(z_t = k \mid \mbx_{1:t-1})$ (up to normalization), what do the backward messages $\beta_{t,k}$ represent?
:::
:::{admonition} Exercise
:class: tip
Use the forward and backward messages to compute the posterior pairwise marginals $p(z_t, z_{t+1} \mid \mbx_{1:T})$.
:::
If you're working with long time series, especially in 32-bit floating point, you need to be careful.
The messages involve products of probabilities, which can quickly underflow or overflow.
There's a simple fix though: after each step, re-normalize the messages so that they sum to one. That is, replace
$$
\mbalpha_{t+1} = \mbP^\top (\mbalpha_t \odot \mbl_t)
$$
with
$$
\begin{aligned}
\overline{\mbalpha}_{t+1} &= \frac{1}{A_t} \mbP^\top (\overline{\mbalpha}_t \odot \mbl_t), \\
A_t &= \sum_{k=1}^K \sum_{j=1}^K P_{jk} \, \overline{\alpha}_{t,j} \, l_{t,j}
= \sum_{j=1}^K \overline{\alpha}_{t,j} \, l_{t,j} \quad \text{(since $\textstyle\sum_k P_{jk} = 1$)}.
\end{aligned}
$$
This leads to a nice interpretation: the normalized messages are predictive probabilities,
$\overline{\alpha}_{t+1,k} = p(z_{t+1}=k \mid \mbx_{1:t})$,
and the normalizing constants are predictive likelihoods, $A_t = p(\mbx_t \mid \mbx_{1:t-1})$. In particular, the marginal likelihood is the product of the normalizing constants, $p(\mbx_{1:T}) = \prod_{t=1}^T A_t$.
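In code, the normalized forward pass might look like the following sketch, which also accumulates the log marginal likelihood $\log p(\mbx_{1:T}) = \sum_{t=1}^T \log A_t$ along the way (illustrative names, not a particular library's API):

```python
import numpy as np

def forward_pass_normalized(pi0, P, liks):
    """Normalized forward messages and the log marginal likelihood log p(x_{1:T})."""
    T, K = liks.shape
    alpha_bar = pi0                        # alpha_bar_{1,k} = p(z_1 = k)
    alpha_bars = np.zeros((T, K))
    alpha_bars[0] = alpha_bar
    log_marginal = 0.0
    for t in range(T - 1):
        unnormalized = P.T @ (alpha_bar * liks[t])
        A_t = unnormalized.sum()           # A_t = p(x_t | x_{1:t-1})
        alpha_bar = unnormalized / A_t     # alpha_bar_{t+1,k} = p(z_{t+1} = k | x_{1:t})
        alpha_bars[t + 1] = alpha_bar
        log_marginal += np.log(A_t)
    # Account for the final observation: A_T = sum_j alpha_bar_{T,j} l_{T,j}.
    log_marginal += np.log(alpha_bar @ liks[T - 1])
    return alpha_bars, log_marginal
```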
EM for Hidden Markov Models
Now we can put it all together. To perform EM in an HMM,
- E step: Compute the posterior distribution
  \begin{align*}
  q(\mbz_{1:T}) &= p(\mbz_{1:T} \mid \mbx_{1:T}; \mbTheta).
  \end{align*}
  (Really, run the forward-backward algorithm to get posterior marginals and pairwise marginals.)
- M step: Maximize the ELBO with respect to $\mbTheta$,
  \begin{align*}
  \mathcal{L}(\mbTheta) &= \mathbb{E}_{q(\mbz_{1:T})}\left[\log p(\mbx_{1:T}, \mbz_{1:T}; \mbTheta) \right] + c \\
  &= \mathbb{E}_{q(\mbz_{1:T})}\left[\sum_{k=1}^K \mathbb{I}[z_1=k]\log \pi_{0,k} \right]
  + \mathbb{E}_{q(\mbz_{1:T})}\left[\sum_{t=1}^{T-1} \sum_{i=1}^K \sum_{j=1}^K \mathbb{I}[z_t=i, z_{t+1}=j]\log P_{i,j} \right] \\
  &\qquad + \mathbb{E}_{q(\mbz_{1:T})}\left[\sum_{t=1}^T \sum_{k=1}^K \mathbb{I}[z_t=k]\log p(\mbx_t; \theta_k) \right]
  \end{align*}
For exponential family observations, the M-step only requires expected sufficient statistics.
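Concretely, the expectations in the M step reduce to sums over the E-step marginals. Here is a sketch of the closed-form updates for the initial distribution and transition matrix from the posterior marginals and pairwise marginals; the emission updates follow the same weighted pattern as in the mixture-model case (names are illustrative):

```python
import numpy as np

def m_step_discrete(posteriors, pairwise):
    """Closed-form M step for pi0 and P from posterior (pairwise) marginals.

    posteriors: (T, K),      posteriors[t, k]   = p(z_t = k | x_{1:T})
    pairwise:   (T-1, K, K), pairwise[t, i, j]  = p(z_t = i, z_{t+1} = j | x_{1:T})
    """
    pi0 = posteriors[0]                             # E[I[z_1 = k]]
    counts = pairwise.sum(axis=0)                   # expected transition counts
    P = counts / counts.sum(axis=1, keepdims=True)  # normalize rows
    return pi0, P
```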
:::{admonition} Questions
:class: tip
- How can we sample the posterior?
- How can we find the posterior mode?
- How can we choose the number of states?
- What if my transition matrix is sparse?
:::
HMMs add temporal dependencies to the latent states of mixture models. They're a simple yet powerful model for sequential data.
The emission distribution can be extended in many ways. For example, we could include temporal dependencies in the emissions via an autoregressive HMM, or condition on external covariates as in an input-output HMM.
As with mixture models, we can derive efficient stochastic EM algorithms for HMMs, which keep rolling averages of sufficient statistics across mini-batches (e.g., individual trials from a collection of sequences).
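A sketch of the kind of rolling-average update this involves is below; the dictionary of statistics and the step-size schedule are illustrative assumptions, not a prescribed recipe:

```python
def update_rolling_stats(avg_stats, batch_stats, step_size):
    """Exponentially weighted running average of expected sufficient statistics.

    avg_stats, batch_stats: dicts mapping statistic names to arrays of matching shape.
    step_size: scalar in (0, 1], e.g., a decaying schedule like (iteration + 1) ** -0.7.
    """
    return {name: (1 - step_size) * avg_stats[name] + step_size * batch_stats[name]
            for name in avg_stats}
```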
It's always good to implement models and algorithms yourself at least once, like you will in Homework 3. Going forward, you should check out our implementations in Dynamax and SSM!