Variational Inference in POMDPs
The goal of this post is to explore from first principles the learning of belief models in partially observable MDPs. We will start with a quick refresher on variational inference and then apply it to state estimation in POMDPs. Specifically, we will derive the update rule used to train Dreamer-like world models.
Variational Inference
In this post, we are interested in latent variable models. We consider pairs of random variables $(x, z)$ where $z$ denotes the latent variable and $x$ the observed one. Inference refers to the estimation of $p(z\vert x)$. Often, this posterior does not exist in closed form, and the best one can do is try to approximate it. Given some family of distributions $\mathcal{P}$, variational inference is about finding the best approximation of $p(z\vert x)$ within $\mathcal{P}$: $$ q^\star \in \argmin_{q\in\mathcal{P}} \text{KL}(q(z\vert x) \, \| \,p(z\vert x)) \;. $$
Maximum-likelihood
We will be interested in variational inference for maximum-likelihood estimation in latent variable models. Assuming the relevant distributions are parametrised by some $\theta$, we want to maximise the observation likelihood $p_\theta(x)$. The marginal $p_\theta(x)$ could be obtained via integration by the law of total probability: $$ \begin{aligned} \log p_\theta(x) &= \log \int_{z} p_\theta(x\vert z)p_\theta(z)dz \;. \end{aligned} $$ This integral rarely exists in closed form. We saw in our EM post how this objective could be optimised nonetheless, provided the posterior $p_\theta(z\vert x)$ is known. Sadly, this is again a questionable assumption – for instance, it fails whenever $p_\theta(x\vert z)$ is represented by a neural network. The way forward involves variational inference: introducing a variational distribution $q_\phi(z\vert x)$. Indeed, we have:
$$ \tag{1} \log p_\theta(x) = \text{KL}(q_\phi(z\vert x) \, \| \, p_\theta(z\vert x) ) + \underbrace{\mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(x\vert z)p_\theta(z)}{q_\phi(z\vert x)}\right]}_{\mathcal{V}(\theta, \phi)} \;. $$
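To see where (1) comes from, expand the KL term and apply Bayes' rule, $p_\theta(z\vert x) = p_\theta(x\vert z)p_\theta(z)/p_\theta(x)$: $$ \begin{aligned} \text{KL}(q_\phi(z\vert x) \, \| \, p_\theta(z\vert x)) &= \mathbb{E}_{q_\phi}\left[\log\frac{q_\phi(z\vert x)}{p_\theta(z\vert x)}\right] = \mathbb{E}_{q_\phi}\left[\log\frac{q_\phi(z\vert x)\, p_\theta(x)}{p_\theta(x\vert z)p_\theta(z)}\right] \\ &= \log p_\theta(x) - \mathbb{E}_{q_\phi}\left[\log\frac{p_\theta(x\vert z)p_\theta(z)}{q_\phi(z\vert x)}\right] = \log p_\theta(x) - \mathcal{V}(\theta, \phi)\;, \end{aligned} $$ where $\log p_\theta(x)$ leaves the expectation since it does not depend on $z$. Rearranging gives (1).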
Because the Kullback-Leibler divergence is non-negative, we have $ \log p_\theta(x) \geq \mathcal{V}(\theta, \phi)\;. $ For this reason, $\mathcal{V}$ is often called the variational lower-bound, or the evidence lower-bound (some communities call $\log p_\theta(x)$ the evidence). It does not require computing a potentially intractable integral, and we can derive efficient unbiased estimators for it (see below). That’s it, really: we can now focus on this objective, relying on the fact that it lower-bounds the log-likelihood (and hoping it will not be too far off).
Estimators
Let us rewrite the variational lower-bound as follows: $$ \mathcal{V}(\theta, \phi) = \mathbb{E}_{q_\phi}\left[\, \log p_\theta(x\vert z)\right] - \text{KL}(q_\phi(z\vert x) \,\|\, p_{\theta}(z)) $$
The KL often exists in closed form – e.g. when $q_\phi$ and the marginal $p_{\theta}(z)$ are Gaussians. We make this assumption from now on. Further, we also drop the dependency on $\theta$ from the latter, and simply consider this term as a regulariser for $\phi$ – for instance, setting $p_{\theta}(z) = \mathcal{N}(z\vert 0, 1)$.
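For reference, with a standard Gaussian prior and a diagonal Gaussian posterior, the closed-form expression is (per latent dimension, summed over dimensions): $$ \text{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\big) = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - \log \sigma^2 - 1\right)\;. $$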
The first term is more annoying. Gradients w.r.t. $\theta$ are uneventful and can easily be approximated via a sample average. Gradients w.r.t. $\phi$ are more delicate, as obtaining unbiased estimators typically requires the log-derivative trick; even though unbiased, these come with high variance. Instead, it is common to use the reparametrisation trick. If our choice for $q_\phi$ allows (e.g. a location-scale distribution), we will write samples from it as a direct transformation of some random variable $\varepsilon$, drawn from a “simple” distribution. The usual example is Gaussian reparametrisation: if $q_\phi(z\vert x) = \mathcal{N}(z\vert \mu_\phi(x), \sigma^2)$ then one can write $z = \mu_\phi(x) + \sigma\varepsilon$ with $\varepsilon \sim\mathcal{N}(0, 1)$. In general, one writes $z=g_\phi(x, \varepsilon)$. This allows us to write: $$ \nabla_{\theta, \phi} \, \mathcal{V}(\theta, \phi) = \nabla_{\theta, \phi}\mathbb{E}_{\varepsilon}[\log p_\theta(x \vert g_\phi(x, \varepsilon))]- \nabla_{\theta, \phi} \text{KL}(q_\phi(z\vert x) \,\|\, p(z))\; , $$
for which sample-average estimators are straightforward.
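To make this concrete, here is a minimal PyTorch sketch of a one-sample reparametrised estimator of $\mathcal{V}(\theta, \phi)$. The `encoder` and `decoder` networks, the shapes, and the unit-variance Gaussian likelihood are illustrative assumptions, not a prescribed architecture:

```python
import torch
from torch.distributions import Normal, kl_divergence

def elbo(x, encoder, decoder):
    # encoder(x) returns the mean and log-std of q_phi(z | x);
    # decoder(z) returns the mean of a unit-variance Gaussian p_theta(x | z).
    mu, log_std = encoder(x)
    q = Normal(mu, log_std.exp())
    z = q.rsample()  # reparametrised sample: z = mu + sigma * eps, eps ~ N(0, 1)
    log_lik = Normal(decoder(z), 1.0).log_prob(x).sum(dim=-1)  # log p_theta(x | z)
    prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # p(z) = N(0, 1)
    kl = kl_divergence(q, prior).sum(dim=-1)  # closed-form Gaussian KL
    return (log_lik - kl).mean()

# Minimising -elbo(x, encoder, decoder) by gradient descent yields the
# sample-average gradient estimators w.r.t. both theta and phi.
```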
Learning in POMDPs
Setting
We now turn to using variational inference for belief estimation in a POMDP denoted $\mathcal{M}=(\mathcal{S}, \mathcal{A}, \mathcal{O}, p, q, r)$. One approach to solving $\mathcal{M}$ is to express the equivalent belief MDP, where the state is replaced by the belief $b_t = \mathbb{P}(s_t \vert o_{1:t}, a_{1:t-1})$. Estimating the belief requires both the transition kernel $p$ and the emission kernel $q$ – which are unknown in an RL setting. Model-based RL approaches therefore learn a predictive belief model, which can then be used for explicit planning. We will focus here only on the former (learning the model), leaving planning aside.
Learning belief models
We wish to learn a probabilistic model $q_\phi$ for the belief from a trajectory $\{o_{1:t}, a_{1:t-1}\}$. To be useful at planning time, we require said model to be causal. Formally, we will use a filtering posterior: $$ \tag{3} q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1}) := q_\phi(s_1\vert o_1)\prod_{k=2}^t q_\phi(s_k \vert o_{1:k}, a_{1:k-1})\;. $$
It is unclear how to learn $q_\phi$ given that the states $\{s_{i}\}_{i=1}^t$ are not observed. As in the previous section, $q_\phi$ will be introduced as a variational distribution within a maximum-likelihood scheme. Indeed, we set out to maximise the likelihood of observing $\{o_i\}_{i=1}^t$, conditioned on $\{a_i\}_{i=1}^{t-1}$, under some model $p_\theta$. Following the protocol detailed in the last section, one can establish the following variational bound:
$$ \tag{4} \log p_\theta(o_{1:t}\vert a_{1:t-1}) \geq \sum_{k=1}^t \mathbb{E}_{q_\phi}\Big[\log p_\theta(o_k \vert s_k) - \text{KL}(q_\phi(s_k\vert o_{1:k}, a_{1:k-1}) \, \| \, p_\theta(s_k\vert s_{k-1}, a_{k-1}))\Big]\;. $$
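Concretely, (4) follows from the same decomposition as (1) with $x = o_{1:t}$ and $z = s_{1:t}$ (conditioning on the actions throughout), using the POMDP factorisation $p_\theta(o_{1:t}, s_{1:t}\vert a_{1:t-1}) = \prod_{k=1}^t p_\theta(o_k\vert s_k)\, p_\theta(s_k\vert s_{k-1}, a_{k-1})$ together with the filtering posterior (3): $$ \mathcal{V}(\theta, \phi) = \mathbb{E}_{q_\phi}\left[\sum_{k=1}^t \log p_\theta(o_k\vert s_k) + \log\frac{p_\theta(s_k\vert s_{k-1}, a_{k-1})}{q_\phi(s_k\vert o_{1:k}, a_{1:k-1})}\right]\;, $$ and grouping the log-ratio at each time step into a KL term gives (4), with the convention $p_\theta(s_1\vert s_0, a_0) = p_\theta(s_1)$.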
Let us dissect (4). Observe that we maintain two models: one for the transition kernel, one for the emission kernel. The first term on the r.h.s. maximises the observation likelihood $p_\theta(o_k\vert s_k)$. The second term keeps the belief model consistent with the dynamics, penalising its divergence from the transition model $p_{\theta}(s_k\vert s_{k-1}, a_{k-1})$.
The variational distribution (the belief model) can be seen as an encoder for the trajectory $\{o_{1:t}, a_{1:t-1}\}$. We use it to generate a sequence of latent states $s_{1:t}$ sampled according to (3). This yields an estimator for (4); by the reparametrisation trick, we also obtain efficient estimators for its gradients.
The belief model is often built using a recurrent model, so that $ q_\phi(s_{1:t}\vert o_{1:t}, a_{1:t-1}) := q_\phi(s_1\vert o_1)\prod_{k=2}^t q_\phi(s_k \vert s_{k-1}, o_{k}, a_{k-1})\;. $ Altogether, this is what powers modern model-based approaches like the Dreamer family of models and algorithms. The lower bound (4) already appears in [2]. Follow-up papers mostly refine the belief model itself (not the learning algorithm) – e.g., by discretising the belief space, adding deterministic components to the belief update, etc.
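To tie everything together, below is a minimal PyTorch sketch of such a recurrent belief model trained by maximising an estimator of (4). The architecture (a GRU cell, diagonal Gaussians, a unit-variance observation likelihood, a zero initial state) is an illustrative assumption, not the exact PlaNet/Dreamer architecture:

```python
import torch
from torch import nn
from torch.distributions import Normal, kl_divergence

class BeliefModel(nn.Module):
    def __init__(self, obs_dim, act_dim, state_dim, hidden_dim=128):
        super().__init__()
        # Deterministic recurrent core summarising (s_{k-1}, o_k, a_{k-1}).
        self.cell = nn.GRUCell(state_dim + obs_dim + act_dim, hidden_dim)
        self.posterior = nn.Linear(hidden_dim, 2 * state_dim)            # q_phi(s_k | s_{k-1}, o_k, a_{k-1})
        self.transition = nn.Linear(state_dim + act_dim, 2 * state_dim)  # p_theta(s_k | s_{k-1}, a_{k-1})
        self.decoder = nn.Linear(state_dim, obs_dim)                     # mean of p_theta(o_k | s_k)
        self.state_dim, self.hidden_dim = state_dim, hidden_dim

    @staticmethod
    def _gaussian(params):
        mean, log_std = params.chunk(2, dim=-1)
        return Normal(mean, log_std.exp())

    def loss(self, obs, act):
        """Negative of the bound (4) for a batch of trajectories.

        obs: (batch, T, obs_dim); act: (batch, T, act_dim), where act[:, k]
        stands for a_{k-1} (with a dummy first action a_0).
        """
        batch, T, _ = obs.shape
        h = obs.new_zeros(batch, self.hidden_dim)
        s = obs.new_zeros(batch, self.state_dim)  # dummy s_0: p(s_1 | s_0, a_0) plays the role of p(s_1)
        total = 0.0
        for k in range(T):
            prior = self._gaussian(self.transition(torch.cat([s, act[:, k]], dim=-1)))
            h = self.cell(torch.cat([s, obs[:, k], act[:, k]], dim=-1), h)
            post = self._gaussian(self.posterior(h))
            s = post.rsample()  # reparametrised sample from the filtering posterior
            log_lik = Normal(self.decoder(s), 1.0).log_prob(obs[:, k]).sum(dim=-1)
            kl = kl_divergence(post, prior).sum(dim=-1)
            total = total + (log_lik - kl).mean()
        return -total  # minimise with any gradient-based optimiser
```

Summing the per-step terms over $k$ and averaging over a batch of trajectories gives the training loss; gradients flow through the sampled states thanks to `rsample`, exactly as in the reparametrised estimator of the previous section.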
References
The variational inference part of this blog post is a condensed version of:
[1] Auto-Encoding Variational Bayes. Kingma and Welling, 2013.
Its application to POMDPs is taken from:
[2] Learning Latent Dynamics for Planning from Pixels. Hafner et al., 2019.