VFE approximation for Gaussian processes, the gory details

Mon 20 August 2018

This post gives the VFE Gaussian process derivation in detail. The implementation details are given in another post.

\(\newcommand{\X}{\mathbf{X}}\) \(\newcommand{\Z}{\mathbf{Z}}\) \(\newcommand{\W}{\mathbf{W}}\) \(\newcommand{\Wt}{\mathbf{W}^T}\) \(\newcommand{\M}{\mathbf{M}}\) \(\renewcommand{\S}{\mathbf{S}}\) \(\newcommand{\Minv}{\mathbf{M}^{-1}}\) \(\newcommand{\U}{\mathbf{U}}\) \(\newcommand{\Vt}{\mathbf{V}^T}\) \(\newcommand{\I}{\mathbf{I}}\) \(\newcommand{\C}{\mathbf{C}}\) \(\renewcommand{\L}{\mathbf{L}}\) \(\newcommand{\Sig}{\boldsymbol{\Sigma}}\) \(\newcommand{\A}{\mathbf{A}}\) \(\newcommand{\T}{\mathbf{T}}\) \(\newcommand{\G}{\mathbf{G}}\) \(\newcommand{\R}{\mathbf{R}}\) \(\newcommand{\Linv}{\mathbf{L}^{-1}}\) \(\newcommand{\K}{\mathbf{K}}\) \(\newcommand{\Kxx}{\mathbf{K}_{xx}}\) \(\newcommand{\Kxs}{\mathbf{K}_{x*}}\) \(\newcommand{\Ksx}{\mathbf{K}_{*x}}\) \(\newcommand{\Kss}{\mathbf{K}_{**}}\) \(\newcommand{\Kzs}{\mathbf{K}_{z*}}\) \(\newcommand{\Ksz}{\mathbf{K}_{*z}}\) \(\newcommand{\Kzz}{\mathbf{K}_{zz}}\) \(\newcommand{\Kxz}{\mathbf{K}_{xz}}\) \(\newcommand{\Kzx}{\mathbf{K}_{zx}}\) \(\newcommand{\Kzzinv}{\mathbf{K}_{zz}^{-1}}\) \(\newcommand{\Qxx}{\mathbf{Q}_{xx}}\) \(\newcommand{\Qxs}{\mathbf{Q}_{x*}}\) \(\newcommand{\Qsx}{\mathbf{Q}_{*x}}\) \(\newcommand{\Qsz}{\mathbf{Q}_{*z}}\) \(\newcommand{\Qzs}{\mathbf{Q}_{z*}}\) \(\newcommand{\yv}{\mathbf{y}}\) \(\newcommand{\xv}{\mathbf{x}}\) \(\newcommand{\sv}{\mathbf{s}}\) \(\newcommand{\bv}{\mathbf{b}}\) \(\newcommand{\mv}{\mathbf{m}}\) \(\newcommand{\ev}{\mathbf{e}}\) \(\newcommand{\tv}{\mathbf{t}}\) \(\newcommand{\zv}{\mathbf{z}}\) \(\newcommand{\uv}{\mathbf{u}}\) \(\newcommand{\fv}{\mathbf{f}}\) \(\newcommand{\fvx}{\mathbf{f}_x}\) \(\newcommand{\fvs}{\mathbf{f}_*}\) \(\newcommand{\fvz}{\mathbf{f}_z}\) \(\newcommand{\fvxzbar}{\bar{\mathbf{f}}_{xz}}\) \(\newcommand{\xvs}{\mathbf{x}_*}\) \(\newcommand{\epsv}{\boldsymbol\epsilon}\) \(\newcommand{\muv}{\boldsymbol\mu}\) \(\newcommand{\bzero}{\mathbf{0}}\) \(\newcommand{\normal}[3]{\mathcal{N}\left(#1 \mid #2 \,,\, #3\right)}\) \(\newcommand{\KL}[2]{\mathrm{KL}\left[\, #1 \, || \, #2 \, \right]}\) \(\newcommand{\Exp}[2]{\mathbb{E}_{#1}\left[#2\right]}\) \(\newcommand{\Tr}[1]{\text{Tr}\left[ #1 \right]}\) \(\newcommand{\Cov}[1]{\mathbb{C}\text{ov}\left[ #1 \right]}\) \(\DeclareMathOperator*{\argmin}{\mathbf{argmin}\,}\)

Throughout, we consider the covariance or mean function hyperparameters \(\theta\) to be conditioned on implicitly, unless stated otherwise. The standard Gaussian process results we will use repeatedly are

$$ \begin{aligned} p(\fvx) &= \normal{\fvx}{\bzero}{\Kxx} \\ p(\yv \mid \fvx) &= \normal{\yv}{\fvx}{\sigma^2 \I} \\ p(\fvx \mid \yv) &= \normal{\fvx}{\Kxx [\Kxx + \sigma^2 \I]^{-1} \yv}{\Kxx - \Kxx [\Kxx + \sigma^2 \I]^{-1} \Kxx} \\ p(\fvs \mid \yv) &= \normal{\fvs}{\Ksx [\Kxx + \sigma^2 \I]^{-1} \yv}{\Kss - \Ksx [\Kxx + \sigma^2 \I]^{-1} \Kxs} \\ p(\fvs \mid \fvx) &= \normal{\fvs}{\Ksx \Kxx^{-1} \fvx}{\Kss - \Ksx \Kxx^{-1} \Kxs} \\ \end{aligned} $$
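
Before approximating anything, it may help to see the exact predictive written out numerically. Below is a minimal NumPy sketch, assuming a squared-exponential kernel and toy one-dimensional data; all names here are illustrative, not taken from the implementation post.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    """Squared-exponential covariance between 1-d input vectors a and b."""
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

rng = np.random.default_rng(0)
x = np.linspace(0, 5, 50)                        # training inputs
y = np.sin(x) + 0.1 * rng.standard_normal(50)    # noisy observations
x_star = np.linspace(0, 5, 100)                  # test inputs
sigma2 = 0.1**2                                  # noise variance

Kxx = rbf_kernel(x, x)
Ksx = rbf_kernel(x_star, x)
Kss = rbf_kernel(x_star, x_star)

# p(f* | y) = N(Ksx [Kxx + sigma^2 I]^{-1} y, Kss - Ksx [Kxx + sigma^2 I]^{-1} Kxs)
A = Kxx + sigma2 * np.eye(len(x))
mean = Ksx @ np.linalg.solve(A, y)
cov = Kss - Ksx @ np.linalg.solve(A, Ksx.T)
# 'mean' and 'cov' are the parameters of the exact GP predictive p(f_* | y)
```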

The derivation starts with the equation for the GP predictive distribution, \(p(\fvs \mid \yv)\). First write it as the marginal distribution of the joint \(p(\fvs, \fvx \mid \yv)\),

\begin{align*} p(\fvs \mid \yv) &= \int p(\fvs, \fvx \mid \yv) d\fvx \\ &= \int p(\fvs \mid \fvx, \yv) p(\fvx \mid \yv) d\fvx \\ &= \int p(\fvs \mid \fvx) p(\fvx \mid \yv) d\fvx \,. \end{align*}

First we factor the joint distribution, then drop \(\yv\) from the conditioning, since \(\yv\) carries no information about \(\fvs\) beyond what \(\fvx\) already contains. Both distributions in the integrand are known.

The VFE approximation introduces a set of \(m\) inducing inputs \(\zv\), with corresponding GP function values \(\fvz\), and replaces the exact posterior \(p(\fvx, \fvz \mid \yv)\) with a variational approximation, which we denote \(q(\fvx, \fvz)\). The approximate predictive distribution is then

\begin{equation}\label{eq:gpcond} \tilde{p}(\fvs \mid \yv) = \iint p(\fvs \mid \fvx, \fvz) q(\fvx, \fvz) d\fvz d\fvx \,. \end{equation}

The tilde denotes that the distribution is an approximation. This approximation is exact when the Kullback-Leibler divergence (KL) is equal to zero,

\begin{equation} \text{If}\hspace{2mm} \KL{q(\fvx, \fvz)}{p(\fvx, \fvz \mid \yv)} = 0\,, \hspace{1mm} \text{then} \hspace{2mm} q(\fvx, \fvz) = p(\fvx, \fvz \mid \yv) \end{equation}

The larger the KL is, the worse the approximation becomes. We choose to factor the variational posterior as

\begin{align} q(\fvx, \fvz) &= q(\fvx \mid \fvz) q(\fvz) \\ &= p(\fvx \mid \fvz) q(\fvz) \,. \end{align}

We choose \(q(\fvx \mid \fvz) = p(\fvx \mid \fvz)\), which is known, and take \(q(\fvz) = \normal{\fvz}{\boldsymbol\mu}{\mathbf{A}}\), a multivariate normal with unknown mean and covariance. After rearranging and simplifying the KL divergence, we will find the values of \(\boldsymbol\mu\) and \(\mathbf{A}\) that minimize it. So, our goal is to find

\begin{equation} \argmin_{q, \zv} \KL{q(\fvx, \fvz)}{p(\fvx, \fvz \mid \yv)} \,. \end{equation}

We want to find the inducing point locations \(\zv\), together with the mean and covariance of the variational distribution \(q(\fvz)\), that minimize the KL divergence.
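
To make the objects being optimized concrete: in an uncollapsed implementation one would keep \(\zv\), \(\boldsymbol\mu\), and \(\mathbf{A}\) as explicit free parameters, with a Cholesky factor keeping \(\mathbf{A}\) positive definite. A minimal sketch with illustrative shapes only; in the collapsed derivation below, \(\boldsymbol\mu\) and \(\mathbf{A}\) will instead be eliminated analytically.

```python
import numpy as np

# Free quantities of the variational optimization, for m = 10 inducing points
# (shapes are illustrative only).
m = 10

Z = np.linspace(0, 5, m)        # inducing input locations z
mu = np.zeros(m)                # mean of q(f_z)
L_A = np.eye(m)                 # Cholesky factor of the covariance of q(f_z)
A = L_A @ L_A.T                 # A is positive definite by construction
```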

Next, we expand and then simplify the expression for this KL divergence,

\begin{align} \KL{q(\fvx, \fvz)}{p(\fvx, \fvz \mid \yv)} &= \Exp{q(\fvx, \fvz)}{\log \frac{q(\fvx, \fvz)}{p(\fvx, \fvz \mid \yv)}} \\ &= \Exp{q(\fvx, \fvz)}{\log q(\fvx, \fvz) } - \Exp{q(\fvx, \fvz)}{\log p(\fvx, \fvz \mid \yv)} \\ &= \Exp{q(\fvx, \fvz)}{\log q(\fvx, \fvz) } - \Exp{q(\fvx, \fvz)}{\log \frac{p(\fvx, \fvz, \yv)}{p(\yv)}} \\ &= \Exp{q(\fvx, \fvz)}{\log q(\fvx, \fvz) } - \Exp{q(\fvx, \fvz)}{\log p(\fvx, \fvz, \yv)} + \log p(\yv) \,. \end{align}

Next, we rearrange to have \(\log p(\yv)\) on the left hand side,

\begin{align} \log p(\yv) &= \Exp{q(\fvx, \fvz)}{\log p(\fvx, \fvz, \yv)} - \Exp{q(\fvx, \fvz)}{\log q(\fvx, \fvz) } + \KL{q(\fvx, \fvz)}{p(\fvx, \fvz \mid \yv)} \\ &\geq \Exp{q(\fvx, \fvz)}{\log p(\fvx, \fvz, \yv)} - \Exp{q(\fvx, \fvz)}{\log q(\fvx, \fvz) } \end{align}

Since the KL divergence is greater than or equal to zero, dropping it gives a lower bound on the log marginal likelihood \(\log p(\yv)\). The expression on the right-hand side is the evidence lower bound (ELBO), which we can simplify further,

\begin{align} \log p(\yv) &\geq \Exp{q(\fvx, \fvz)}{\log p(\fvx, \fvz, \yv)} - \Exp{q(\fvx, \fvz) }{\log q(\fvx, \fvz) } \\ &\geq \Exp{q(\fvx, \fvz)}{\log \frac{ p(\fvx, \fvz, \yv) }{ q(\fvx, \fvz) } } \\ &\geq \Exp{q(\fvx, \fvz)}{\log \frac{ p(\yv \mid \fvx, \fvz) p(\fvz) } { q(\fvz) } } \\ &\geq \Exp{p(\fvx \mid \fvz)q(\fvz)}{\log \frac{ p(\yv \mid \fvx) p(\fvz) } { q(\fvz) } } \\ &\geq \Exp{p(\fvx \mid \fvz)q(\fvz)}{\log p(\yv \mid \fvx) } + \Exp{p(\fvx \mid \fvz)q(\fvz)}{\log \frac{ p(\fvz) }{ q(\fvz) }} \\ \end{align}

where in the fourth line we use \(q(\fvx, \fvz) = p(\fvx \mid \fvz) q(\fvz)\) and write \(p(\yv \mid \fvx, \fvz) = p(\yv \mid \fvx)\), since knowing the inducing values \(\fvz\) in addition to \(\fvx\) provides no extra information about \(\yv\). Then we rewrite the expected value as an integral, and define \(F(q, \zv)\) as the variational lower bound, so \(F(q, \zv) \geq \log p(\yv)\).

\begin{align} F(q, \zv) &= \iint \left[ \log p(\yv \mid \fvx) + \log \frac{p(\fvz)}{q(\fvz)} \right] p(\fvx \mid \fvz)q(\fvz) d\fvx d\fvz \\ &= \int \left[ \int \log(p(\yv \mid \fvx))p(\fvx \mid \fvz) d\fvx + \int p(\fvx \mid \fvz) \log \frac{p(\fvz)}{q(\fvz)} d \fvx \right] q(\fvz) d\fvz \\ &= \int \left[ \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} + \int p(\fvx \mid \fvz) \log \frac{p(\fvz)}{q(\fvz)} d \fvx \right] q(\fvz) d\fvz \\ &= \int \left[ \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} + \log \frac{p(\fvz)}{q(\fvz)} \right] q(\fvz) d\fvz \\ \end{align}

The expectation inside the integrand, \(\Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)}\), can be computed analytically. Recall that \(p(\yv \mid \fvx) = \normal{\yv}{\fvx}{\sigma^2 \I}\). We compute this expectation next,

\begin{equation} \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} = \Exp{p(\fvx \mid \fvz)}{ -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2}(\yv - \fvx)^{T} (\sigma^2 \I)^{-1} (\yv - \fvx) } \,. \end{equation}

The trick here is to rewrite the quadratic term as a trace. We use two trace identities. The first is that the trace of a scalar (viewed as a \(1\times1\) matrix) is the scalar itself, \(\Tr{c} = c\). The second is the cyclic property, \(\Tr{\mathbf{A}\mathbf{B}} = \Tr{\mathbf{B}\mathbf{A}}\), for appropriately sized matrices. Written as a trace, the quadratic term is,

\begin{align} \Tr{(\yv - \fvx)^{T} (\sigma^2 \I)^{-1} (\yv - \fvx)} &= \Tr{(\sigma^2 \I)^{-1}(\yv - \fvx)(\yv - \fvx)^{T}} \\ &= \frac{1}{\sigma^2}\Tr{(\yv - \fvx)(\yv - \fvx)^{T}} \\ &= \frac{1}{\sigma^2}\Tr{ \yv\yv^T - 2\yv\fvx^{T} + \fvx\fvx^{T}} \end{align}

Plugging this back into the expectation and distributing the expected value,

\begin{align} \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} &= \Exp{p(\fvx \mid \fvz)}{ -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Tr{ \yv\yv^T - 2\yv\fvx^{T} + \fvx\fvx^{T}}} \\ &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Tr{\yv\yv^T - 2\yv \Exp{p(\fvx \mid \fvz)}{\fvx}^{T} + \Exp{p(\fvx \mid \fvz)}{\fvx \fvx^T}} \,. \end{align}
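
As an aside, the two trace identities used here are easy to verify numerically. A quick sketch with arbitrary vectors (all names illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 5
sigma2 = 0.3
y = rng.standard_normal(n)
f = rng.standard_normal(n)
Sinv = np.eye(n) / sigma2                            # (sigma^2 I)^{-1}

quad = (y - f) @ Sinv @ (y - f)                      # scalar quadratic form
as_trace = np.trace(Sinv @ np.outer(y - f, y - f))   # cyclically reordered trace
expanded = np.trace(np.outer(y, y) - 2 * np.outer(y, f) + np.outer(f, f)) / sigma2

print(np.allclose(quad, as_trace), np.allclose(quad, expanded))  # True True
```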

The distribution \(p(\fvx \mid \fvz) = \normal{\fvx}{\Kxz\Kzzinv\fvz}{\Kxx - \Kxz\Kzzinv\Kzx}\) is known. To simplify the notation, we define \(\fvxzbar \equiv \Kxz\Kzzinv\fvz\), and \(\Qxx \equiv \Kxz\Kzzinv\Kzx\). More generally, we define \(\mathbf{Q}_{ab} = \K_{az}\Kzzinv\K_{zb}\). Using this notation, \(p(\fvx \mid \fvz) = \normal{\fvx}{\fvxzbar}{\Kxx - \Qxx}\). Immediately, we see that \(\Exp{p(\fvx \mid \fvz)}{\fvx} = \fvxzbar\). Then, note that

\begin{align} \Cov{\fvx\,, \fvx} &= \Exp{}{(\fvx - \Exp{}{\fvx})(\fvx - \Exp{}{\fvx})^T} = \Exp{}{\fvx\fvx^T} - \Exp{}{\fvx}\Exp{}{\fvx^{T}} \\ &= \Exp{}{(\fvx - \fvxzbar)(\fvx - \fvxzbar)^T} = \Exp{}{\fvx\fvx^T} - \fvxzbar\fvxzbar^T \\ \end{align}

Since \(\Cov{\fvx\,, \fvx} = \Kxx - \Qxx\) under \(p(\fvx \mid \fvz)\), we have \(\Exp{}{\fvx\fvx^T} = \Kxx - \Qxx + \fvxzbar\fvxzbar^T\). Plugging these results back in, the expression simplifies considerably,

\begin{align} \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Tr{\yv\yv^T - 2\yv \Exp{p(\fvx \mid \fvz)}{\fvx}^{T} + \Exp{p(\fvx \mid \fvz)}{\fvx \fvx^T}} \\ &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Tr{\yv\yv^T - 2\yv \fvxzbar^T + \fvxzbar\fvxzbar^T + \Kxx - \Qxx } \\ &= -\frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\Tr{\yv\yv^T - 2\yv \fvxzbar^T + \fvxzbar\fvxzbar^T} - \frac{1}{2\sigma^2}\Tr{\Kxx - \Qxx } \\ &= \log \normal{\yv}{\fvxzbar}{\sigma^2 \I} - \frac{1}{2\sigma^2}\Tr{\Kxx - \Qxx } \,. \end{align}
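
This closed form can be sanity checked by Monte Carlo: draw \(\fvx \sim p(\fvx \mid \fvz)\), average \(\log p(\yv \mid \fvx)\) over the draws, and compare to the expression in the last line. A rough NumPy sketch, assuming the same toy squared-exponential kernel as earlier (illustrative only):

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

def log_normal_iso(y, mean, sigma2):
    """log N(y | mean, sigma^2 I)."""
    n = len(y)
    return -0.5 * n * np.log(2 * np.pi * sigma2) - np.sum((y - mean) ** 2) / (2 * sigma2)

rng = np.random.default_rng(2)
x = np.linspace(0, 5, 20)
z = np.linspace(0, 5, 5)
sigma2 = 0.1
y = np.sin(x) + np.sqrt(sigma2) * rng.standard_normal(len(x))
f_z = np.sin(z)                                    # a fixed value of f_z to condition on

Kxx = rbf_kernel(x, x)
Kxz = rbf_kernel(x, z)
Kzz = rbf_kernel(z, z) + 1e-10 * np.eye(len(z))    # jitter for numerical stability

f_xz_bar = Kxz @ np.linalg.solve(Kzz, f_z)         # mean of p(f_x | f_z)
Qxx = Kxz @ np.linalg.solve(Kzz, Kxz.T)
cov_xz = Kxx - Qxx + 1e-10 * np.eye(len(x))        # covariance of p(f_x | f_z)

# Monte Carlo estimate of E_{p(f_x | f_z)}[log p(y | f_x)]
samples = rng.multivariate_normal(f_xz_bar, cov_xz, size=50000)
sq = np.sum((y - samples) ** 2, axis=1)
mc = np.mean(-0.5 * len(x) * np.log(2 * np.pi * sigma2) - sq / (2 * sigma2))

# Closed form: log N(y | f_xz_bar, sigma^2 I) - Tr[Kxx - Qxx] / (2 sigma^2)
closed = log_normal_iso(y, f_xz_bar, sigma2) - np.trace(Kxx - Qxx) / (2 * sigma2)
print(mc, closed)   # agree to within Monte Carlo error
```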

Now that we have computed the expectation \(\Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)}\), we plug it back into our expression for the marginal likelihood lower bound and simplify further,

\begin{align} F(q, \zv) &= \int \left[ \Exp{p(\fvx \mid \fvz)}{\log p(\yv \mid \fvx)} + \log \frac{p(\fvz)}{q(\fvz)} \right] q(\fvz) d\fvz \\ &= \int \left[ \log \normal{\yv}{\fvxzbar}{\sigma^2 \I} - \frac{1}{2\sigma^2}\Tr{\Kxx - \Qxx } + \log \frac{p(\fvz)}{q(\fvz)} \right] q(\fvz) d\fvz \\ &= \int q(\fvz) \log\left(\frac{ \normal{\yv}{\fvxzbar}{\sigma^2 \I} p(\fvz)}{q(\fvz)}\right) d\fvz - \frac{1}{2\sigma^2}\Tr{\Kxx - \Qxx } \,. \end{align}

The term on the left is, up to an additive constant, a negative KL divergence. Specifically, note that \(\KL{q}{p} = -\int q \log\frac{p}{q}\), which equals zero when \(q = p\). Therefore, \(F(q, \zv)\) is maximized when \(q(\fvz) \propto \normal{\yv}{\fvxzbar}{\sigma^2\I}p(\fvz)\). Taking this product and normalizing (both factors are Gaussian in \(\fvz\), since \(\fvxzbar = \Kxz\Kzzinv\fvz\)) yields the optimal distribution,

\begin{equation} q^*(\fvz) = \normal{\fvz}{\frac{1}{\sigma^2}\Kzz\boldsymbol\Sigma^{-1}\Kzx\yv}{\Kzz\boldsymbol\Sigma^{-1}\Kzz} \,, \end{equation}

where \(\boldsymbol\Sigma = \Kzz + \frac{1}{\sigma^2}\Kzx\Kxz\). Plugging this in, and computing the integral over \(\fvz\) analytically (the integrand is a product of normals), yields the final result for the marginal likelihood lower bound,

\begin{equation} \log p(\yv) \geq \log \normal{\yv}{\bzero}{\Qxx + \sigma^2 \I} - \frac{1}{2\sigma^2}\Tr{\Kxx - \Qxx } \,. \end{equation}

This is identical to the standard Gaussian process marginal likelihood, except for the additional trace term. When \(\Tr{\Kxx - \Qxx} = 0\), the two are equal, and the approximation is exact for that choice of \(q(\fvz)\) and inducing point input locations.
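
The bound is cheap to evaluate directly. Here is a minimal NumPy sketch of the collapsed bound, written naively with an \(n \times n\) solve for clarity (an efficient implementation would use the Woodbury identity discussed at the end of this post); the kernel and helper names are assumptions for illustration:

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

def vfe_lower_bound(x, y, z, sigma2, lengthscale=1.0, variance=1.0):
    """Collapsed VFE bound: log N(y | 0, Qxx + sigma^2 I) - Tr[Kxx - Qxx] / (2 sigma^2)."""
    n = len(x)
    Kxx = rbf_kernel(x, x, lengthscale, variance)
    Kxz = rbf_kernel(x, z, lengthscale, variance)
    Kzz = rbf_kernel(z, z, lengthscale, variance) + 1e-10 * np.eye(len(z))
    Qxx = Kxz @ np.linalg.solve(Kzz, Kxz.T)

    C = Qxx + sigma2 * np.eye(n)
    _, logdet = np.linalg.slogdet(C)
    log_marg = -0.5 * (n * np.log(2 * np.pi) + logdet + y @ np.linalg.solve(C, y))
    return log_marg - np.trace(Kxx - Qxx) / (2 * sigma2)

# example usage on toy data: this quantity is maximized over z and hyperparameters
rng = np.random.default_rng(3)
x = np.linspace(0, 5, 100)
y = np.sin(x) + 0.1 * rng.standard_normal(len(x))
z = np.linspace(0, 5, 10)
print(vfe_lower_bound(x, y, z, sigma2=0.01))
```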

Lastly, we must return to the original Gaussian process conditional, Eq. \(\ref{eq:gpcond}\), which motivated our approximation. The final series of steps is to plug in the optimal \(q^*(\fvz)\) and simplify,

\begin{align} \tilde{p}(\fvs \mid \yv) &= \iint p(\fvs \mid \fvx, \fvz) q(\fvx, \fvz) d\fvz d\fvx \\ &= \iint p(\fvs \mid \fvx, \fvz) p(\fvx \mid \fvz) q^*(\fvz) d\fvz d\fvx \\ &= \int p(\fvs \mid \fvz) q^*(\fvz) d\fvz \\ &= \normal{\fvs}{\frac{1}{\sigma^2} \Ksz \boldsymbol\Sigma^{-1}\Kzx \yv}{\Kss - \Ksz\Kzzinv\Kzs + \Ksz\boldsymbol\Sigma^{-1}\Kzs} \\ &= \normal{\fvs}{\Qsx[\Qxx + \sigma^2\I]^{-1}\yv}{\Kss - \Qsx[\Qxx + \sigma^2 \I]^{-1}\Qxs} \,. \end{align}
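
The last two lines claim two equivalent forms of the approximate predictive. A small NumPy check, on toy data with an assumed squared-exponential kernel (illustrative only), that the \(\boldsymbol\Sigma\) form and the \(\mathbf{Q}\) form agree:

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

rng = np.random.default_rng(4)
x = np.linspace(0, 5, 60)
z = np.linspace(0, 5, 8)
x_star = np.linspace(0, 5, 25)
sigma2 = 0.05
y = np.sin(x) + np.sqrt(sigma2) * rng.standard_normal(len(x))

Kxz = rbf_kernel(x, z)
Kzz = rbf_kernel(z, z) + 1e-10 * np.eye(len(z))
Ksz = rbf_kernel(x_star, z)
Kss = rbf_kernel(x_star, x_star)

# Sigma form: only m x m matrices are inverted
Sigma = Kzz + Kxz.T @ Kxz / sigma2
mean_sigma = Ksz @ np.linalg.solve(Sigma, Kxz.T @ y) / sigma2
cov_sigma = (Kss - Ksz @ np.linalg.solve(Kzz, Ksz.T)
             + Ksz @ np.linalg.solve(Sigma, Ksz.T))

# Q form: written with an n x n solve, as in the last line above
Qxx = Kxz @ np.linalg.solve(Kzz, Kxz.T)
Qsx = Ksz @ np.linalg.solve(Kzz, Kxz.T)
C = Qxx + sigma2 * np.eye(len(x))
mean_q = Qsx @ np.linalg.solve(C, y)
cov_q = Kss - Qsx @ np.linalg.solve(C, Qsx.T)

print(np.allclose(mean_sigma, mean_q), np.allclose(cov_sigma, cov_q))  # True True
```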

The matrix inversions in the final expression above can be arranged to act only on \(m \times m\) matrices, instead of \(n \times n\), where \(m < n\). Overall, the VFE approximation has a computational cost of \(\mathcal{O}(nm^2)\) and memory usage of \(\mathcal{O}(nm)\). This is accomplished via the Woodbury matrix identity, which allows us to rewrite

\begin{equation} (\sigma^2\I + \Qxx)^{-1} = (\sigma^2 \I)^{-1} - (\sigma^2 \I)^{-1} \Kxz (\Kzz + \Kzx (\sigma^2 \I)^{-1} \Kxz)^{-1} \Kzx (\sigma^2 \I)^{-1} \,. \end{equation}

The only inverses now required are of \(m \times m\) matrices: \(\Kzzinv\) inside the \(\mathbf{Q}\) matrices, and \((\Kzz + \frac{1}{\sigma^2}\Kzx\Kxz)^{-1} = \boldsymbol\Sigma^{-1}\).
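
This identity is also easy to verify numerically; a quick sketch with the same toy kernel (illustrative only):

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0, variance=1.0):
    sqdist = (a[:, None] - b[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

x = np.linspace(0, 5, 40)
z = np.linspace(0, 5, 6)
sigma2 = 0.1
n, m = len(x), len(z)

Kxz = rbf_kernel(x, z)
Kzz = rbf_kernel(z, z) + 1e-10 * np.eye(m)
Qxx = Kxz @ np.linalg.solve(Kzz, Kxz.T)

# Direct n x n inverse of (sigma^2 I + Qxx)
direct = np.linalg.inv(sigma2 * np.eye(n) + Qxx)

# Woodbury form: only the m x m matrix (Kzz + Kzx Kxz / sigma^2) is inverted
inner = Kzz + Kxz.T @ Kxz / sigma2
woodbury = np.eye(n) / sigma2 - Kxz @ np.linalg.solve(inner, Kxz.T) / sigma2**2

print(np.allclose(direct, woodbury))  # True
```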