Expectation Maximization
This post covers the Expectation Maximization (EM) algorithm, a popular heuristic to (approximately) compute maximum likelihood estimates when dealing with unobserved / latent data. We will motivate and derive the EM’s inner mechanisms before making them explicit on two classical examples (Gaussian Mixture parameter estimation and Hidden Markov Model identification).
Deriving the EM
Setting and notations
We will introduce the EM through a toy example. Consider the following data generating process: $$ Z \sim p_\nu \quad \text{ and } X \sim p_\mu(\cdot \vert Z) \; . $$ Both $Z$ and $X$ are assumed to be real-valued. The variables $\nu$ and $\mu$, which respectively parametrize the marginal distribution of $Z$ and the conditional distribution of $X$ given $Z$, are assumed to live in some Euclidean space. We only observe $X$ – the variable $Z$ is hidden, or latent. Our goal is to estimate $\theta = (\nu, \mu)$ jointly. The Maximum Likelihood Estimator (MLE) is defined as: $$ \begin{aligned} \hat{\theta}&\in \argmax_{\theta} p(X\vert \theta)\;, & \\ &= \argmax_{\theta} \int_{z} p(X, z\vert \theta)dz\;, & (\text{law of total probability})\\ &= \argmax_{\theta} \int_{z} p_\mu(X\vert z)p_\nu(z)dz\;. \end{aligned} $$ As we’ll see in our examples, this objective is typically non-convex – in contrast to similar settings where there are no hidden variables. The Expectation Maximization is a heuristic algorithm designed to optimize such a non-convex landscape. When it applies, it is often preferred over gradient-based approaches – even though most modern libraries would combine the two.
Before moving on, we will introduce the negative log-likelihood: $$ \boxed{ \mathcal{L}(\theta) = - \log \int_{z} p_\mu(X\vert z)p_\nu(z)dz \;, } $$ so that $\hat{\theta} \in \argmin_\theta \mathcal{L}(\theta)$.
Upper-bounding the negative log-likelihood
Much like a gradient-based algorithm, the EM produces a sequence of estimates $\{\theta_t\}_{t\geq 1}$. At every iteration $t$, it constructs, based on $\theta_t$, an upper-bound on the true loss $\mathcal{L}(\theta)$. The minimizer of this upper-bound will be $\theta_{t+1}$. The goal of this section is to derive said upper-bound.
Let us define: $$ { \ell(\theta\vert\theta_t) = \mathcal{L}(\theta_t) -\int_{z} \log\left(\frac{p(X, z \vert \theta)}{p(X, z \vert \theta_t)}\right)p(z \vert X, \theta_t) dz \; . } $$ Then, $\ell(\theta\vert\theta_t)$ is a tangent upper-bound on $\mathcal{L}(\theta)$, meaning: $$ \mathcal{L}(\theta) \leq \ell(\theta\vert\theta_t) \text{ and } \mathcal{L}(\theta_t) = \ell(\theta_t\vert\theta_t)\; . $$
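To see why both properties hold, note that $p(X, z\vert\theta_t) = p(z\vert X, \theta_t)\,p(X\vert\theta_t)$, so the bound can be rewritten as $$ \ell(\theta\vert\theta_t) = -\int_{z} \log\left(\frac{p(X, z\vert\theta)}{p(z\vert X, \theta_t)}\right)p(z\vert X, \theta_t)dz\;. $$ Jensen's inequality, applied to the convex function $-\log$, then gives $$ \mathcal{L}(\theta) = -\log \int_{z} \frac{p(X, z\vert\theta)}{p(z\vert X, \theta_t)}p(z\vert X, \theta_t)dz \leq -\int_{z} \log\left(\frac{p(X, z\vert\theta)}{p(z\vert X, \theta_t)}\right)p(z\vert X, \theta_t)dz = \ell(\theta\vert\theta_t)\;. $$ Tangency follows directly from the definition: at $\theta = \theta_t$ the log-ratio $\log\left(p(X, z\vert\theta_t)/p(X, z\vert\theta_t)\right)$ vanishes, so that $\ell(\theta_t\vert\theta_t) = \mathcal{L}(\theta_t)$.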
The EM’s next iterate will be $\theta_{t+1} \in \argmin_{\theta}\; \ell(\theta\vert\theta_{t})$. Observe that this ensures that $\mathcal{L}(\theta_{t+1})\leq \mathcal{L}(\theta_{t})$ – in other words, we have guaranteed improvement. Indeed: $$ \begin{aligned} \mathcal{L}(\theta_{t+1}) &\leq \ell(\theta_{t+1} \vert \theta_t) \;, \\ &\overset{(i)}{\leq} \ell(\theta_{t} \vert \theta_t) \;, \\ &= \mathcal{L}(\theta_t)\;, \end{aligned} $$ where $(i)$ uses that $ \ell(\theta_{t+1} \vert \theta_t) = \min_\theta \ell(\theta\vert\theta_t)$.
We can now write down a first, bare-bones description of the EM:
For every $t\geq 1$: $$ \theta_{t+1} \in \argmin_\theta \ell(\theta\vert \theta_t) \; . $$
This guarantees improvement at every iteration. Furthermore, it is rather easy to show that, under mild conditions, the sequence $\{\theta_t\}_t$ converges to a stationary point of $\mathcal{L}$. The name of this heuristic is still rather obscure, though; why do we call this Expectation Maximization?
Expectation Maximisation
Let’s dive deeper into the minimization of $ \ell(\theta\vert \theta_t)$. Removing terms that are constant w.r.t $\theta$, we obtain: $$ \begin{aligned} \argmin_\theta \ell(\theta\vert \theta_t) &= \argmin_\theta\mathcal{L}(\theta_t) -\int_{z} \log\left(\frac{p(X, z \vert \theta)}{p(X, z \vert \theta_t)}\right)p(z \vert X, \theta_t) dz \;, \\ &= \argmin_\theta -\int_{z} \log\left(p(X, z \vert \theta)\right)p(z \vert X, \theta_t) dz \;,\\ &= \argmax_\theta \left\{\mathcal{Q}(\theta\vert \theta_t):= \mathbb{E}\left[\log p(X, z \vert \theta) \middle\vert X, \theta_t \right]\right\}\;. \end{aligned} $$
The quantity $p(x, z \vert \theta)$ is usually called the complete-data likelihood, as opposed to the incomplete-data likelihood $p(x\vert \theta)$. The conditional expectation of its logarithm, given the observations and the current model estimate $\theta_t$, is denoted $\mathcal{Q}(\theta\vert \theta_t):=\mathbb{E}\left[\log p(X, z \vert \theta) \middle\vert X, \theta_t \right]$ and referred to as the auxiliary log-likelihood.
The first step in completing an EM iteration is to materialise the auxiliary log-likelihood, which requires computing the conditional $p(z \vert X, \theta_t)$. This is the Expectation step. The second step consists in maximizing the auxiliary log-likelihood w.r.t $\theta$; this is the Maximization step.
For every $t\geq 1$:
[Expectation] compute $p(z \vert X, \theta_t)$ and: $$ \mathbb{E}\left[\log p(X, z \vert \theta) \middle\vert X, \theta_t \right] = \int_{z} \log\left(p(X, z \vert \theta)\right)p(z \vert X, \theta_t) dz \; . $$ [Maximization] $$ \theta_{t+1} \in \argmax_\theta \left\{\mathcal{Q}(\theta\vert \theta_t)= \mathbb{E}\left[\log p(X, z \vert \theta) \middle\vert X, \theta_t \right] \right\}\; . $$
Fitting a Gaussian Mixture Model
Consider the data-generating process of a Gaussian Mixture Model (GMM). Formally, let $\alpha_1, \ldots, \alpha_K \in \mathbb{R}^d$ and $\Sigma_1, \ldots, \Sigma_K \succeq 0$ be, respectively, the means and covariance matrices of $K$ Gaussian random variables. A new sample $X$ from the GMM is drawn by first sampling its hidden variable $Z\in\{1, \ldots, K\}$, representing which Gaussian component will generate $X$. We then generate $X\sim \mathcal{N}(\alpha_Z, \Sigma_Z)$. Observe that since the latent variable $Z$ is discrete, the distribution $p_\nu$ can be parametrized by a vector living in the $K$-dimensional simplex: $$ \nu \in \Delta(K) := \{ \omega \in \mathbb{R}^K, \; \omega_i \geq 0 \, \forall i \in\{1, \ldots, K\} \,, \; \sum_{i=1}^K \omega_i = 1 \} \; , $$ so that $p_\nu(z) = \nu_z$ for any $z\in \{ 1, \ldots, K\}$.
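To make this generative process concrete, here is a minimal NumPy sketch of a GMM sampler (the function name and array shapes are illustrative conventions, not part of the model above):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_gmm(n, nu, alpha, Sigma):
    """Draw n samples: z ~ Categorical(nu), then x ~ N(alpha_z, Sigma_z).

    nu:    (K,)      mixture weights, an element of the simplex Delta(K)
    alpha: (K, d)    component means
    Sigma: (K, d, d) component covariances
    """
    z = rng.choice(len(nu), size=n, p=nu)          # hidden component labels
    x = np.stack([rng.multivariate_normal(alpha[k], Sigma[k]) for k in z])
    return x, z                                    # in practice only x is observed
```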
Let $X_1, \ldots, X_n$ be $n$ i.i.d samples from the GMM. Recall that we wish to jointly estimate: $$ \theta := ( \nu_1, \ldots, \nu_K, \alpha_1, \ldots, \alpha_K, \Sigma_1, \ldots, \Sigma_K) \; . $$ The negative log-likelihood of the model: $$ \mathcal{L}(\theta) = -\sum_{i=1}^n \log\left(\sum_{z=1}^K \nu_z \mathcal{N}(x_i \vert \alpha_z, \Sigma_z) \right) \; , $$ is easily shown to be a non-convex function of $\theta$. We therefore turn to the EM procedure, of which we now make each step explicit. In what follows, let $\theta_t$ denote the EM’s current iterate.
Expectation Step
We must compute, for each sample, the conditional distribution $p(z_i\vert x_i, \theta_t)$. To reduce clutter we will use the short-hand $\pi_k^i := p(z_i = k\vert x_i, \theta_t)$. Observe that: $$ \begin{aligned} \pi_k^i &= p(z_i=k\vert x_i, \theta_t) \; ,\\ &\propto p(z_i=k\vert \theta_t)p(x_i\vert z_i=k, \theta_t) &(\text{Bayes rule}) \; , \\ &\propto \nu_k^t \cdot \mathcal{N}(x_i\vert \alpha_k^t, \Sigma_k^t) \end{aligned} $$ where $\nu_k^t$, $\alpha_k^t$ and $\Sigma_k^t$ are our current estimates of the $k$-th component’s weight, mean and covariance (included in $\theta_t$). Normalizing the distribution yields the following expression, concluding the Expectation step: $$ \boxed{ \pi_k^i = \frac{ \nu_k^t \cdot \mathcal{N}(x_i\vert \alpha_k^t, \Sigma_k^t)} {\sum_{\ell=1}^K \nu_{\ell}^t \cdot \mathcal{N}(x_i\vert \alpha_{\ell}^t, \Sigma_{\ell}^t)}\; . } $$
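In code, the Expectation step is just the boxed formula applied to every sample. A minimal sketch using NumPy and SciPy (the array shapes are assumptions on how the data and parameters are stored):

```python
import numpy as np
from scipy.stats import multivariate_normal

def e_step(X, nu, alpha, Sigma):
    """Responsibilities pi[i, k] = p(z_i = k | x_i, theta_t).

    X:     (n, d)    observations
    nu:    (K,)      current mixture weights
    alpha: (K, d)    current component means
    Sigma: (K, d, d) current component covariances
    """
    n, K = X.shape[0], nu.shape[0]
    pi = np.zeros((n, K))
    for k in range(K):
        # unnormalized posterior: nu_k^t * N(x_i | alpha_k^t, Sigma_k^t)
        pi[:, k] = nu[k] * multivariate_normal.pdf(X, mean=alpha[k], cov=Sigma[k])
    # normalize each row so the responsibilities sum to one over the K components
    return pi / pi.sum(axis=1, keepdims=True)
```

For poorly separated components or high-dimensional data one would work with log-densities and a log-sum-exp normalisation instead; the plain version above is enough to illustrate the step.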
Maximisation step
Let us now write down the auxiliary log-likelihood: $$ \begin{aligned} \mathcal{Q}(\theta\vert\theta_t) &= \sum_{i=1}^n \left(\sum_{k=1}^K p(z_i = k\vert x_i , \theta_t)\log p(x_i, z_i=k\vert \theta)\right) \;, \\ &= \sum_{i=1}^n \left(\sum_{k=1}^K \pi_k^i \log \left(p(x_i, z_i=k\vert \theta)\right)\right) \;, &(\text{using shorthand}) \\ &= \sum_{i=1}^n \left(\sum_{k=1}^K \pi_k^i \log p(x_i\vert z_i=k, \theta) + \pi_k^i \log p(z_i=k\vert \theta) \right) \;, &(\text{chain rule}) \\ &= \sum_{i=1}^n \sum_{k=1}^K \pi_k^i \log \mathcal{N}(x_i\vert \alpha_k, \Sigma_k) + \sum_{i=1}^n \sum_{k=1}^K \pi_k^i \log \nu_k \;. &(\text{re-arranging}) \\ \end{aligned} $$ This objective is much better behaved than $\mathcal{L}$ – each parameter block appears in its own easy-to-optimise term, so that gradient-based maximisation would already be principled. Our luck doesn’t stop there; we actually have a closed form for $\theta_{t+1} \in \argmax_\theta \mathcal{Q}(\theta\vert \theta_t)$: $$ \boxed{ \begin{aligned} \nu_k^{t+1} &= \frac{\sum_{i=1}^n \pi_k^i}{\sum_{i=1}^n \sum_{\ell=1}^K \pi_\ell^i} \;, \\ \alpha_k^{t+1} &= \frac{1}{\sum_{i=1}^n \pi_k^i} \sum_{i=1}^n \pi_k^i x_i\;, \\ \Sigma_k^{t+1} &= \frac{1}{\sum_{i=1}^n \pi_k^i} \sum_{i=1}^n \pi_k^i (x_i - \alpha_k^{t+1})(x_i - \alpha_k^{t+1}) ^\top \;, \end{aligned} } $$ for all $k\in\{1, \ldots, K\}$.
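The closed-form updates translate directly into a Maximization step; alternating it with the `e_step` sketch above gives a complete, if bare-bones, GMM-EM loop. Again a sketch under the same assumed array shapes:

```python
def m_step(X, pi):
    """Closed-form argmax of Q(. | theta_t): new weights, means, covariances.

    X:  (n, d) observations
    pi: (n, K) responsibilities returned by e_step
    """
    n, d = X.shape
    Nk = pi.sum(axis=0)                      # (K,) effective counts, sum_i pi_k^i
    nu = Nk / Nk.sum()                       # new mixture weights
    alpha = (pi.T @ X) / Nk[:, None]         # (K, d) responsibility-weighted means
    Sigma = np.zeros((len(Nk), d, d))
    for k in range(len(Nk)):
        diff = X - alpha[k]                  # (n, d) data centered on component k
        Sigma[k] = (pi[:, k, None] * diff).T @ diff / Nk[k]
    return nu, alpha, Sigma

# One EM run simply alternates the two steps from some initial guess:
# for _ in range(100):
#     pi = e_step(X, nu, alpha, Sigma)
#     nu, alpha, Sigma = m_step(X, pi)
```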
Learning a finite Hidden Markov Model
In this section we consider the problem of learning the parameters of a finite HMM. Let $\mathcal{Z} = \{1, \ldots, S\}$ be the state space and $\mathcal{X} = \{1, \ldots, m\}$ the observation space. Recall the HMM dynamics; for $k\geq 1$: $$ \begin{aligned} x_k &\sim p_\nu(\cdot\vert z_k)\; ,\\ z_{k+1} &\sim p_\mu(\cdot\vert z_k) \;, \\ \end{aligned} $$ where $p_\mu$ describes the state dynamics and $p_\nu$ the observation kernel. For simplicity, we assume that we have access to only one trajectory of observations $(x_1, \ldots, x_n)\in\mathcal{X}^n$ to learn from. Our goal is to learn the HMM parameters $\theta = (\mu, \nu)$ jointly. Gradient descent on the negative log-likelihood: $$ \mathcal{L}(\theta) = -\log p_\theta(x_1, \ldots, x_n)\;, $$ is a valid approach, as gradients can be computed in a filtering fashion – see [Krishnamurthy §4.2]. However, the EM is often the method of choice for bootstrapping this optimisation process. Below, we detail both the Expectation and Maximisation steps at iteration $t$.
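As with the GMM, it can help to see the generative process in code before detailing the two steps. A minimal sketch (the 0-indexed states and symbols, and a fixed known initial state, are assumptions the post does not specify):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_hmm(n, mu, nu, z0=0):
    """Roll out one trajectory of length n from a finite HMM.

    mu: (S, S) transition matrix, mu[z, z'] = p(z_{k+1} = z' | z_k = z)
    nu: (S, m) emission matrix,   nu[z, x]  = p(x_k = x   | z_k = z)
    z0: initial state (not modelled in the post; fixed here by assumption)
    """
    S, m = nu.shape
    z = np.zeros(n, dtype=int)
    x = np.zeros(n, dtype=int)
    z[0] = z0
    for k in range(n):
        x[k] = rng.choice(m, p=nu[z[k]])              # x_k ~ p_nu(. | z_k)
        if k + 1 < n:
            z[k + 1] = rng.choice(S, p=mu[z[k]])      # z_{k+1} ~ p_mu(. | z_k)
    return x, z                                        # only x is observed
```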
Expectation
We will need to compute the conditional distribution $p(z_k \vert x_{1:n}, \theta_t)$ for every $k=1, \ldots, n$. This is a typical exercise of fixed-interval smoothing. We provide below the recursive implementation of an optimal fixed-interval smoother; for details and derivations, we refer the interested reader to [Krishnamurthy §3.3.5]. Let us denote $\pi_{k\vert n}(z) = p(z_k=z \vert x_{1:n}, \theta_t)$. We easily obtain that: $$ \pi_{k\vert n}(z) = \frac{\pi_k(z)\beta_{k\vert n}(z)}{\sum_{z'\in\mathcal{Z}}\pi_k(z')\beta_{k\vert n}(z')}\;. $$ Above, $\pi_k(z) = p(z_k=z\vert x_{1:k}, \theta_t)$ is computed via a forward filter: $$ \pi_k(z) = \eta^{-1} p_{\nu_t}(x_k\vert z_k=z) \sum_{z'\in\mathcal{Z}} p_{\mu_t}(z_k=z \vert z_{k-1}=z')\pi_{k-1}(z') $$ where $\eta$ is a normalisation constant, omitted here to reduce clutter. Further, $\beta_{k\vert n}(z) = p(x_{k+1:n}\vert z_k=z, \theta_t)$ is computed via a backward recursion: $$ \beta_{k\vert n}(z) = \sum_{z'\in\mathcal{Z}} p_{\nu_t}(x_{k+1}\vert z_{k+1}=z') p_{\mu_t}(z_{k+1}=z'\vert z_k=z)\beta_{k+1\vert n}(z') $$ initialised at $\beta_{n\vert n}(z)=1$. Another smoothed estimate we will need, $\pi_{k\vert n}(z, z') := p(z_{k-1}=z, z_k=z'\vert x_{1:n}, \theta_t)$, is computed similarly.
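The forward filter, backward recursion and smoothed marginals above map directly onto a forward-backward pass. A NumPy sketch, where 0-indexed states and symbols and a uniform initial state distribution are assumptions:

```python
import numpy as np

def smooth(x, mu, nu):
    """Fixed-interval smoothing for a finite HMM via a forward-backward pass.

    x:  (n,)   observed symbols in {0, ..., m-1}
    mu: (S, S) transition matrix, mu[z, z'] = p(z_{k+1} = z' | z_k = z)
    nu: (S, m) emission matrix,   nu[z, x]  = p(x_k = x   | z_k = z)
    Returns gamma[k, z]  = p(z_k = z | x_{1:n})  (the smoothed marginals)
    and     xi[k, z, z'] = p(z_k = z, z_{k+1} = z' | x_{1:n}).
    """
    x = np.asarray(x)
    n, S = len(x), mu.shape[0]

    # forward filter pi_k(z) = p(z_k = z | x_{1:k}), renormalized at each step
    pi = np.zeros((n, S))
    pi[0] = nu[:, x[0]]                      # uniform initial state (assumption)
    pi[0] /= pi[0].sum()
    for k in range(1, n):
        pred = mu.T @ pi[k - 1]              # one-step prediction of z_k
        pi[k] = nu[:, x[k]] * pred
        pi[k] /= pi[k].sum()

    # backward pass beta_{k|n}(z) = p(x_{k+1:n} | z_k = z), rescaled for stability
    beta = np.ones((n, S))
    for k in range(n - 2, -1, -1):
        beta[k] = mu @ (nu[:, x[k + 1]] * beta[k + 1])
        beta[k] /= beta[k].sum()

    # smoothed marginals and pairwise smoothed marginals
    gamma = pi * beta
    gamma /= gamma.sum(axis=1, keepdims=True)
    xi = np.zeros((n - 1, S, S))
    for k in range(n - 1):
        xi[k] = pi[k][:, None] * mu * (nu[:, x[k + 1]] * beta[k + 1])[None, :]
        xi[k] /= xi[k].sum()
    return gamma, xi
```

Every EM iteration calls this smoother with the current estimates $(\mu_t, \nu_t)$.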
With those computations out of the way, we can now materialise the auxiliary log-likelihood: $$ \mathcal{Q}(\theta\vert \theta_t) = \sum_{k=1}^n \sum_{z\in\mathcal{Z}}\pi_{k\vert n}(z)\log p_\nu(x_{k}\vert z_{k}=z) + \sum_{k=2}^n \sum_{z, z'\in\mathcal{Z}} \pi_{k\vert n}(z, z')\log p_\mu(z_{k}=z'\vert z_{k-1}=z)\;. $$
Maximisation
We are now ready to maximise $\mathcal{Q}(\theta\vert \theta_t)$. Thanks to the finite nature of the HMM, the parametrisations $\mu$ and $\nu$ of the transition kernel and observation likelihood, respectively, can simply be stochastic matrices: $$ \begin{aligned} \mu \in \big\{ \mathbf{P} \in \mathbb{R}^{S\times S}, \mathbf{P}_{ij}\geq 0 \; \text{ and } \sum_{j=1}^S \mathbf{P}_{ij} = 1 \big\} \;, \\ \nu \in \big\{ \mathbf{P} \in \mathbb{R}^{S\times m}, \mathbf{P}_{ij}\geq 0 \; \text{ and } \sum_{j=1}^m \mathbf{P}_{ij} = 1 \big\} \;. \end{aligned} $$ Concretely, for every $z, z'\in\mathcal{Z}$ and $x\in\mathcal{X}$: $$ \begin{aligned} \mu_{zz'} &= p(z_{k+1}=z' \vert z_{k}=z) \; , \\ \nu_{zx} &= p(x_k = x \vert z_{k}=z) \; . \end{aligned} $$ Rewriting the auxiliary log-likelihood under this convention yields: $$ \mathcal{Q}(\theta\vert \theta_t) = \sum_{k=1}^n \sum_{z\in\mathcal{Z}}\pi_{k\vert n}(z)\log \nu_{zx_k}+ \sum_{k=2}^n \sum_{z, z'\in\mathcal{Z}} \pi_{k\vert n}(z, z')\log \mu_{zz'}\;. $$ Both parameters $\mu$ and $\nu$ can therefore be updated independently. Solving each program under the stochastic matrix constraint yields, for every $z, z'\in\mathcal{Z}$ and $x\in\mathcal{X}$: $$ \begin{aligned} [\mu_{t+1}]_{zz'} &= \frac{\sum_{k=2}^n \pi_{k\vert n}(z, z')}{\sum_{k=1}^{n-1} \pi_{k\vert n}(z)}\; , \\ [\nu_{t+1}]_{zx} &= \frac{\sum_{k=1}^n \pi_{k\vert n}(z) \mathbf{1}[x_k=x]}{\sum_{k=1}^n \pi_{k\vert n}(z)}\; . \end{aligned} $$
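In code, the Maximisation step reduces to normalised expected counts computed from the output of the `smooth` sketch above; a sketch under the same conventions:

```python
import numpy as np

def m_step_hmm(x, gamma, xi, m):
    """Re-estimate transition and emission matrices from smoothed marginals.

    x:     (n,)        observed symbols in {0, ..., m-1}
    gamma: (n, S)      smoothed marginals p(z_k = z | x_{1:n})
    xi:    (n-1, S, S) pairwise marginals p(z_k = z, z_{k+1} = z' | x_{1:n})
    """
    x = np.asarray(x)
    S = gamma.shape[1]
    # transitions: expected transition counts / expected visits to z before time n
    mu_new = xi.sum(axis=0) / gamma[:-1].sum(axis=0)[:, None]
    # emissions: expected emissions of each symbol from z / expected visits to z
    nu_new = np.zeros((S, m))
    for sym in range(m):
        nu_new[:, sym] = gamma[x == sym].sum(axis=0)
    nu_new /= gamma.sum(axis=0)[:, None]
    return mu_new, nu_new

# One EM iteration for the HMM then reads:
# gamma, xi = smooth(x, mu, nu)
# mu, nu = m_step_hmm(x, gamma, xi, nu.shape[1])
```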
Resources
The HMM part of this blog-post is heavily inspired by [Krishnamurthy §4].