Variational Inference in POMDPs

The goal of this post is to explore from first principles the learning of belief models in partially observable MDPs. We will start with a quick refresher on variational inference and then apply it to state estimation in POMDPs. Specifically, we will derive the update rule used to train Dreamer-like world models.

Variational Inference

In this post, we are interested in latent variable models. We consider pairs of random variables $(x, z)$ where $z$ denotes the latent variable and $x$ the observed one. Inference refers to the estimation of $p(z\vert x)$. Often, it does not exist in closed form, and the best one can do is try to approximate it. Given some family of distributions $\mathcal{P}$, variational inference is about finding the best approximation of $p(z\vert x)$ within $\mathcal{P}$: $$ q^\star \in \argmin_{q\in\mathcal{P}} \text{KL}(q(z\vert x) \, \| \,p(z\vert x)) \;. $$

Maximum-likelihood

$\quad$ Our post about the Expectation-Maximisation algorithm can provide a good warm-up here.

We will be interested in variational inference for maximum likelihood estimation in latent variable models. Assuming the relevant distributions are parametrised by some $\theta$, we want to maximise the observation likelihood $p_\theta(x)$. The marginal $p_\theta(x)$ can be obtained by integrating out the latent variable, via the law of total probability: $$ \begin{aligned} \log p_\theta(x) &= \log \int_{z} p_\theta(x\vert z)p_\theta(z)dz \;. \end{aligned} $$ This integral rarely exists in closed form. We saw in our EM post how this objective could be optimised nonetheless, provided the posterior $p_\theta(z\vert x)$ is known. Sadly, this is again a questionable assumption – for instance, it fails whenever $p_\theta(x\vert z)$ is represented by a neural network. The way forward involves variational inference: introducing a variational distribution $q_\phi(z\vert x)$. Indeed, we have:

$$ \tag{1} \log p_\theta(x) = \text{KL}(q_\phi(z\vert x) \, \| \, p_\theta(z\vert x) ) + \underbrace{\mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(x\vert z)p_\theta(z)}{q_\phi(z\vert x)}\right]}_{\mathcal{V}(\theta, \phi)} \;. $$

Proof
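A short derivation: using Bayes' rule, $p_\theta(z\vert x) = p_\theta(x\vert z)p_\theta(z) / p_\theta(x)$, we can expand the KL term: $$ \begin{aligned} \text{KL}(q_\phi(z\vert x) \,\|\, p_\theta(z\vert x)) &= \mathbb{E}_{q_\phi}\left[\log\frac{q_\phi(z\vert x)}{p_\theta(z\vert x)}\right] = \mathbb{E}_{q_\phi}\left[\log\frac{q_\phi(z\vert x)\, p_\theta(x)}{p_\theta(x\vert z)\, p_\theta(z)}\right] \\ &= \log p_\theta(x) - \mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(x\vert z)\, p_\theta(z)}{q_\phi(z\vert x)}\right] = \log p_\theta(x) - \mathcal{V}(\theta, \phi)\;, \end{aligned} $$ and rearranging yields (1).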

Because the Kullback-Leibler divergence is non-negative, we have $ \log p_\theta(x) \geq \mathcal{V}(\theta, \phi)\;. $ For this reason, $\mathcal{V}$ is often called the variational lower-bound, or the evidence lower-bound (some communities call $\log p_\theta(x)$ the evidence). It does not require computing a potentially intractable integral, and we can derive efficient unbiased estimators for it (see below). That’s it, really: we can now focus on this objective, relying on the fact that it lower-bounds the log-likelihood (and hoping it will not be too far off).

The EM revisited

Estimators

Let us rewrite the variational lower-bound as follows: $$ \mathcal{V}(\theta, \phi) = \mathbb{E}_{q_\phi}\left[\, \log p_\theta(x\vert z)\right] - \text{KL}(q_\phi(z\vert x) \,\|\, p_{\theta}(z)) $$
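This follows directly from the definition of $\mathcal{V}$ in (1), by splitting the logarithm: $$ \mathcal{V}(\theta, \phi) = \mathbb{E}_{q_\phi}\left[\log p_\theta(x\vert z)\right] + \mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(z)}{q_\phi(z\vert x)}\right]\;, $$ where the second expectation is exactly $-\text{KL}(q_\phi(z\vert x)\,\|\,p_{\theta}(z))$.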

The KL often exists in closed form – e.g. when $q_\phi$ and the prior $p_{\theta}(z)$ are Gaussians. We make this assumption from now on. Further, we also remove the dependency on $\theta$ from the latter, and simply consider this term as a regulariser for $\phi$ – for instance, setting $p(z) = \mathcal{N}(z\vert 0, 1)$.
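As a concrete example, when $q_\phi(z\vert x) = \mathcal{N}(z\vert \mu_\phi(x), \sigma_\phi^2(x))$ and $p(z) = \mathcal{N}(z\vert 0, 1)$, the regulariser reads: $$ \text{KL}(q_\phi(z\vert x) \,\|\, p(z)) = \frac{1}{2}\left(\mu_\phi^2(x) + \sigma_\phi^2(x) - \log\sigma_\phi^2(x) - 1\right)\;. $$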

The first term can be more annoying. Gradients w.r.t. $\theta$ are uneventful and can easily be approximated via sample averages. Gradients w.r.t. $\phi$ are slightly more annoying, as getting unbiased estimators will typically require the use of the log-derivative trick. Even though unbiased, they come with high variance. Instead, it is common to use the reparametrisation trick. If our choice for $q_\phi$ allows it (e.g. a location-scale distribution), we will write samples from it as a direct transformation of some random variable $\varepsilon$ drawn from a “simple” distribution. The usual example is Gaussian reparametrisation; if $q_\phi(z\vert x) = \mathcal{N}(z\vert \mu_\phi(x), \sigma^2)$ then one can write $z = \mu_\phi(x) + \sigma\varepsilon$ with $\varepsilon \sim\mathcal{N}(0, 1)$. In general, one writes $z=g_\phi(x, \varepsilon)$. This allows us to write: $$ \nabla_{\theta, \phi} \, \mathcal{V}(\theta, \phi) = \nabla_{\theta, \phi}\mathbb{E}_{\varepsilon}[\log p_\theta(x \vert g_\phi(x, \varepsilon))]- \nabla_{\theta, \phi} \text{KL}(q_\phi(z\vert x) \,\|\, p(z))\; , $$

for which sample-average estimators are straightforward.
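Below is a minimal sketch of this estimator in PyTorch, assuming a Gaussian encoder for $q_\phi$ and a unit-variance Gaussian decoder for $p_\theta(x\vert z)$; `encoder` and `decoder` are hypothetical modules standing in for whatever networks one would actually use.

```python
# A minimal sketch of the reparametrised estimator of V(theta, phi).
# `encoder` and `decoder` are hypothetical torch.nn.Modules (assumptions, not from the post):
# the encoder maps x to the mean and log-variance of q_phi(z|x); the decoder maps z to the
# mean of a unit-variance Gaussian p_theta(x|z).
import torch
from torch.distributions import Normal, kl_divergence


def elbo(x, encoder, decoder):
    mu, log_var = encoder(x)                                 # parameters of q_phi(z|x)
    q = Normal(mu, torch.exp(0.5 * log_var))                 # q_phi(z|x)
    z = q.rsample()                                          # reparametrised sample: z = mu + sigma * eps
    log_lik = Normal(decoder(z), 1.0).log_prob(x).sum(-1)    # log p_theta(x|z), summed over features
    kl = kl_divergence(q, Normal(0.0, 1.0)).sum(-1)          # closed-form KL against the N(0, 1) prior
    return (log_lik - kl).mean()                             # sample-average estimate of V(theta, phi)
```

Minimising `-elbo(x, encoder, decoder)` with any stochastic optimiser then yields the gradient estimates above, for both $\theta$ (the decoder) and $\phi$ (the encoder).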

Variational auto-encoders

Learning in POMDPs

$\quad$ Take a look at the POMDP post for a refresher.

Setting

We now turn to using variational inference for belief estimation in a POMDP denoted $\mathcal{M}=(\mathcal{S}, \mathcal{A}, \mathcal{O}, p, q, r)$. One approach to solving $\mathcal{M}$ is to express the equivalent belief MDP, where the state is replaced by the belief $b_t = \mathbb{P}(s_t \vert o_{1:t}, a_{1:t-1})$. Estimating the belief requires both the transition kernel $p$ and the emission kernel $q$ – which are unknown in an RL setting. Model-based RL approaches attempt to learn predictive models for the belief, which can then be used for explicit planning. We will focus here only on the former: learning the belief model, leaving planning aside.

$\quad$ Be prepared for overloaded notations when it comes to distributions. To limit confusion, learned distributions will be subscripted by a parameter (e.g. $p_\theta$) while ground-truth ones are not (e.g. $p$).

Learning belief models

We wish to learn a probabilistic model $q_\phi$ for the belief from a trajectory $\{o_{1:t}, a_{1:t-1}\}$. To be useful at planning time, we require said model to be causal. Formally, we will use a filtering posterior: $$ \tag{3} q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1}) := q_\phi(s_1)\prod_{k=2}^t q_\phi(s_k \vert o_{1:k}, a_{1:k-1})\;. $$

It is unclear how to learn $q_\phi$, given that the states $\{s_{i}\}_{i=1}^t$ are not observed. As in the previous section, it will be introduced as a variational distribution via a maximum-likelihood scheme. Indeed, we set out to maximise the likelihood of observing $\{o_i\}_{i=1}^t$, conditioned on $\{a_i\}_{i=1}^{t-1}$, under some model $p_\theta$. Following the protocol detailed in the last section, one can establish the following variational bound:

$$ \tag{4} \log p_\theta(o_{1:t}\vert a_{1:t-1}) \geq \sum_{k=1}^t \mathbb{E}_{q_\phi}\Big[\log p_\theta(o_k \vert s_k) - \text{KL}(q_\phi(s_k\vert o_{1:k}, a_{1:k-1}) \, \| \, p_\theta(s_k\vert s_{k-1}, a_{k-1}))\Big]\;. $$

Proof
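A sketch of the derivation, using Jensen's inequality and the Markov factorisation of the generative model, $p_\theta(o_{1:t}, s_{1:t}\vert a_{1:t-1}) = \prod_{k=1}^t p_\theta(o_k\vert s_k)\, p_\theta(s_k\vert s_{k-1}, a_{k-1})$ (with the convention that $p_\theta(s_1\vert s_0, a_0)$ denotes the prior over $s_1$), together with the filtering posterior (3): $$ \begin{aligned} \log p_\theta(o_{1:t}\vert a_{1:t-1}) &= \log \mathbb{E}_{q_\phi}\left[\frac{p_\theta(o_{1:t}, s_{1:t}\vert a_{1:t-1})}{q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1})}\right] \\ &\geq \mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(o_{1:t}, s_{1:t}\vert a_{1:t-1})}{q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1})}\right] \\ &= \sum_{k=1}^t \mathbb{E}_{q_\phi}\left[\log p_\theta(o_k\vert s_k) + \log\frac{p_\theta(s_k\vert s_{k-1}, a_{k-1})}{q_\phi(s_k\vert o_{1:k}, a_{1:k-1})}\right]\;, \end{aligned} $$ where the inequality is Jensen's. Taking the inner expectation over $s_k \sim q_\phi(\cdot\vert o_{1:k}, a_{1:k-1})$ in the second term yields the KL of (4).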

Let us dissect (4). Observe that we maintain two models: one for the transition kernel, one for the emission kernel. The first term of the r.h.s. maximises the likelihood of each observation under the emission model $p_\theta(o_k\vert s_k)$. The second term keeps the belief model consistent with the one-step predictions of the transition model $p_{\theta}(s_k\vert s_{k-1}, a_{k-1})$.

The variational distribution (the belief model) can be seen as an encoder for the trajectory $\{o_{1:t}, a_{1:t-1}\}$. We use it to generate a sequence of latent states $s_{1:t}$ sampled according to (3). This yields an estimator for (4); by the reparametrisation trick, we also obtain efficient estimators for its gradients.

Multi-step consistency

The belief model is often built using a recurrent model, so that $ q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1}) := q_\phi(s_1)\prod_{k=2}^t q_\phi(s_k \vert s_{k-1}, o_{k}, a_{k-1})\;. $ Altogether, this is what powers modern model-based approaches like the Dreamer family of models and algorithms. The lower bound (4) already appears in [2]. Follow-up papers mostly refine the belief model itself (not the learning algorithm) – e.g. by discretising the belief space, enabling deterministic belief updates, etc.
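To make this concrete, here is a minimal sketch (in the same spirit as the estimator above) of how the bound (4) can be estimated with such a recurrent belief model. The modules `posterior`, `transition` and `decoder` are hypothetical stand-ins, loosely mimicking the recurrent state-space models of [2] rather than their actual implementation.

```python
# A minimal sketch of an estimator of the bound (4) for a single trajectory.
# Hypothetical modules (assumptions, not the actual Dreamer/PlaNet architecture):
#   posterior(s_prev, a_prev, o) -> mean, std of q_phi(s_k | s_{k-1}, o_k, a_{k-1})
#   transition(s_prev, a_prev)   -> mean, std of p_theta(s_k | s_{k-1}, a_{k-1})
#   decoder(s)                   -> mean of a unit-variance Gaussian p_theta(o_k | s_k)
import torch
from torch.distributions import Independent, Normal, kl_divergence


def sequence_elbo(obs, actions, posterior, transition, decoder, state_dim):
    # obs: (t, obs_dim) observations o_{1:t}; actions: (t - 1, act_dim) actions a_{1:t-1}.
    t = obs.shape[0]
    s = torch.zeros(state_dim)           # dummy s_0; transition(s_0, a_0) plays the role of the prior over s_1
    a = torch.zeros(actions.shape[-1])   # dummy a_0
    bound = 0.0
    for k in range(t):
        q_mean, q_std = posterior(s, a, obs[k])
        q = Independent(Normal(q_mean, q_std), 1)          # q_phi(s_k | o_{1:k}, a_{1:k-1})
        p_mean, p_std = transition(s, a)
        p = Independent(Normal(p_mean, p_std), 1)          # p_theta(s_k | s_{k-1}, a_{k-1})
        s = q.rsample()                                    # reparametrised belief sample
        recon = Independent(Normal(decoder(s), 1.0), 1)    # p_theta(o_k | s_k)
        bound = bound + recon.log_prob(obs[k]) - kl_divergence(q, p)
        if k < t - 1:
            a = actions[k]                                 # a_k becomes the "previous action" at step k + 1
    return bound                                           # single-sample estimate of the r.h.s. of (4)
```

Maximising this quantity over a dataset of trajectories trains the emission, transition and belief models jointly.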

References

The variational inference part of this blog post is a condensed version of:

[1] Auto-Encoding Variational Bayes. Kingma and Welling, 2013.

Its application to POMDPs is taken from:

[2] Learning Latent Dynamics for Planning from Pixels. Hafner et al., 2019.