Notes for “Auto-Encoding Variational Bayes”
My notes on the “Auto-Encoding Variational Bayes” paper. Feel free to ask questions on my Telegram channel.
Method
Problem scenario
Consider a dataset \(\mathbf{X}=\{\mathbf{x[1]}, \mathbf{x[2]}, \dots, \mathbf{x[N]}\}\) of \(N\) independent and identically distributed (i.i.d.) samples. We assume that the data are generated by a random process involving an unobserved (latent) random variable \(\mathbf{z}\). The process consists of two steps:
- \(\mathbf{z[i]}\) is generated from some prior distribution \(\mathbf{p_{\theta*}(z)}\) 
- \(\mathbf{x[i]}\) is generated from some conditional distribution \(\mathbf{p_{\theta*}(x|z)}\) 
Unfortunately, the true parameters \(\mathbf{ \theta* }\) and the latent variables \(\mathbf{z[i]}\) are unknown to us. Additionally, the true posterior \(\mathbf{p_\theta(z|x)}\) is intractable, so we approximate it with \(\mathbf{ q_\phi(z|x) }\). The model \(\mathbf{ q_\phi(z|x) }\) is a probabilistic encoder parameterized by \(\phi\). Similarly, \(\mathbf{p_\theta(x|z)}\) is a probabilistic decoder parameterized by \(\mathbf{\theta}\).
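The two-step generative process above can be sketched in code. This is only a minimal illustration, assuming a one-dimensional standard Gaussian prior and a hypothetical linear-Gaussian conditional; the names `sample_pair`, `weight`, and `noise_std` are made up for this example and do not come from the paper.

```python
import random


def sample_pair(rng, weight=2.0, noise_std=0.5):
    """Draw one (z, x) pair by ancestral sampling.

    Assumption: p(z) = N(0, 1) and p(x|z) = N(weight * z, noise_std**2).
    Both choices are illustrative stand-ins for the true prior and conditional.
    """
    z = rng.gauss(0.0, 1.0)               # step 1: z ~ p(z)
    x = rng.gauss(weight * z, noise_std)  # step 2: x ~ p(x|z)
    return z, x


rng = random.Random(0)
pairs = [sample_pair(rng) for _ in range(5)]
dataset = [x for _, x in pairs]  # only x is observed; z stays latent
```

In practice we only ever see `dataset`; the `z` values are discarded here to mirror the fact that the latent variables are unknown to us.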
The variational lower bound
Ideally, we would like to maximize the marginal log-likelihood of the dataset \(\mathbf{X}\):
\[ \mathbf{ \log p_\theta(x[1], \cdots,x[N])=\sum_{i=1}^N{\log p_\theta(x[i])} } \]
where each term can be rewritten as:
\[ \mathbf{ \log p_\theta(x[i]) = D_{KL}(q_\phi(z|x[i])||p_\theta(z|x[i])) + \mathcal{L}(\theta, \phi, x[i]) } \]
\[ \begin{aligned} \log p_\theta(\mathbf{x}) &= \log p_\theta(\mathbf{x})\int{ \mathbf{q_\phi(z|x)}d\mathbf{z}} \qquad \text{(Multiplying by 1)}\\ &= \int{\log p_\theta(\mathbf{x}) \mathbf{q_\phi(z|x)}d\mathbf{z}}\\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log p_\theta(\mathbf{x}) \right] \qquad \text{(Definition of Expectation)}\\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{p_\theta(\mathbf{x, z})}{p_\theta(\mathbf{z|x})} \right] \qquad \text{(The chain rule of probability)} \\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{p_\theta(\mathbf{x, z})}{p_\theta(\mathbf{z|x})} \cdot \frac{\mathbf{q_\phi(z|x)}}{\mathbf{q_\phi(z|x)}} \right]\\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{ \mathbf{q_\phi(z|x)} }{ p_\theta(\mathbf{z|x}) } \right] + \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{ p_\theta(\mathbf{x, z}) }{ \mathbf{q_\phi(z|x)} } \right]\\ &= D_{KL}( \mathbf{q_\phi(z|x)} || p_\theta(\mathbf{z|x}) ) + \mathcal{L}(\theta, \phi, \mathbf{x}) \end{aligned} \]
We used the chain rule of probability:
\[ p(\mathbf{x})=\frac{p(\mathbf{x,z})}{p(\mathbf{z|x})} \]
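The first term is the KL divergence between the approximate and the true posterior. Since the KL divergence is non-negative, the second term \(\mathcal{L}(\theta, \phi, \mathbf{x})\) is a lower bound on \(\log p_\theta(\mathbf{x})\), called the (variational) lower bound. Ideally, we would like to maximize the marginal likelihood and minimize the KL divergence. It is enough to maximize the lower bound w.r.t. both parameters \(\theta\) and \(\phi\). Since \(\log p_\theta(\mathbf{x})\) does not depend on \(\phi\), the two terms sum to a value that is constant in \(\phi\), so maximizing the lower bound w.r.t. \(\phi\) also minimizes the KL divergence, while maximizing it w.r.t. \(\theta\) pushes up a lower bound on the marginal likelihood.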
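The decomposition of \(\log p_\theta(\mathbf{x})\) into a KL term and a lower bound can be checked numerically on a model small enough to compute exactly. The sketch below assumes a hypothetical toy model with a binary latent variable and a single fixed observation; all the numbers are made up for illustration.

```python
import math

# Hypothetical toy model with a binary latent z and one fixed observation x.
prior = [0.3, 0.7]       # p(z)
likelihood = [0.9, 0.2]  # p(x | z) evaluated at the observed x
q = [0.6, 0.4]           # an arbitrary approximate posterior q(z | x)

joint = [pz * px for pz, px in zip(prior, likelihood)]  # p(x, z)
evidence = sum(joint)                                   # p(x)
posterior = [j / evidence for j in joint]               # p(z | x)

# KL(q || p(z|x)) and the lower bound E_q[log p(x, z) - log q(z|x)].
kl = sum(qz * math.log(qz / pz) for qz, pz in zip(q, posterior))
elbo = sum(qz * math.log(j / qz) for qz, j in zip(q, joint))

log_evidence = math.log(evidence)
gap = log_evidence - (kl + elbo)  # zero up to floating-point rounding
```

Setting `q` equal to `posterior` would make `kl` zero and the bound tight, which is the sense in which a better approximate posterior gives a better bound.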
The variational lower bound, also called the evidence lower bound (ELBO), can also be rewritten as:
\[ \begin{aligned} \log p_\theta(\mathbf{x[i]}) &\ge \mathcal{L}(\theta, \phi; \mathbf{x[i]})\\ &= \mathbb{E}_{q_\phi(\mathbf{z|x[i]})} \left[ -\log q_\phi(\mathbf{z|x[i]}) + \log p_\theta(\mathbf{x[i], z}) \right]\\ &= -D_{KL}(q_\phi(\mathbf{z|x[i]})||p_\theta(\mathbf{z})) + \mathbb{E}_{q_\phi(\mathbf{z|x[i]})} \left[ \log p_\theta(\mathbf{x[i]|z}) \right] \end{aligned} \]
\[ \begin{aligned} \mathcal{L}(\theta, \phi; \mathbf{x}) &=\mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{ p_\theta(\mathbf{x, z}) }{ \mathbf{q_\phi(z|x)} } \right]\\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{ p_\theta(\mathbf{x|z})p_\theta(\mathbf{z}) }{ \mathbf{q_\phi(z|x)} } \right]\\ &= \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log \frac{ p_\theta(\mathbf{z}) }{ \mathbf{q_\phi(z|x)} } \right] + \mathbb{E}_{\mathbf{q_\phi(z|x)}} \left[ \log p_\theta(\mathbf{x|z}) \right]\\ &= -D_{KL}(q_\phi(\mathbf{z|x})||p_\theta(\mathbf{z})) + \mathbb{E}_{q_\phi(\mathbf{z|x})} \left[ \log p_\theta(\mathbf{x|z}) \right] \end{aligned} \]
The reparameterization trick
Sampling from \(\mathbf{q_\phi(z|x)}\) is a stochastic operation, so gradients w.r.t. \(\phi\) cannot be propagated through it directly. We can use an alternative method for generating samples from \(\mathbf{q_\phi(z|x)}\), i.e., the reparameterization trick. We can often express the random variable \(\mathbf{z}\) as a deterministic transformation \(\mathbf{z=g_\phi(\epsilon, x)}\), where \(\epsilon\) is an auxiliary noise variable whose distribution does not depend on \(\phi\), and \(\mathbf{g_\phi}\) is a function parameterized by \(\phi\).
The approximate posterior \(\mathbf{q_\phi(z|x)}\) is commonly chosen to be a multivariate Gaussian with diagonal covariance, and the prior is often a standard Gaussian distribution. A sample from the approximate posterior can then be written as a shifted and scaled noise sample:
\[ \mathbf{ q_\phi(z|x)=\mathcal{N}(z; \mu, \sigma^2 I), \qquad z = \mu+\sigma\cdot\epsilon } \]
where \(\mathbf{ \epsilon \sim \mathcal{N}(\epsilon; 0, I) }\), the product is element-wise, and we can choose \(g_\phi(\epsilon, x)=\mu(x) + \sigma(x) \cdot \epsilon\).
Therefore, by the reparameterization trick, sampling from an arbitrary Gaussian distribution can be performed by sampling from a standard Gaussian, then scaling the result by the target standard deviation and shifting it by the target mean. This transformation is deterministic and differentiable w.r.t. the mean and the standard deviation.
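As a rough illustration of the trick, the sketch below draws reparameterized samples for a scalar Gaussian and checks with finite differences that, for a fixed noise value, the sample is differentiable in the parameters. The values of `mu` and `sigma` are hypothetical encoder outputs chosen for this example.

```python
import random


def reparameterize(mu, sigma, eps):
    """Deterministic map g(eps) = mu + sigma * eps for a scalar Gaussian."""
    return mu + sigma * eps


rng = random.Random(0)
mu, sigma = 1.5, 0.5  # hypothetical encoder outputs for one input x

eps_samples = [rng.gauss(0.0, 1.0) for _ in range(20000)]
z_samples = [reparameterize(mu, sigma, e) for e in eps_samples]

# The samples should follow N(mu, sigma^2).
mean_z = sum(z_samples) / len(z_samples)
var_z = sum((z - mean_z) ** 2 for z in z_samples) / len(z_samples)

# For a fixed eps, z is differentiable in the parameters:
# dz/dmu = 1 and dz/dsigma = eps (checked here by finite differences).
eps0 = eps_samples[0]
h = 1e-6
base = reparameterize(mu, sigma, eps0)
dz_dmu = (reparameterize(mu + h, sigma, eps0) - base) / h
dz_dsigma = (reparameterize(mu, sigma + h, eps0) - base) / h
```

The randomness lives entirely in `eps`, which is why gradients can flow to `mu` and `sigma` through an ordinary deterministic function.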
Variational Auto-Encoder
We use a neural network for the probabilistic encoder \(\mathbf{ q_\phi(z|x)}\), and the parameters \(\phi\) and \(\theta\) are optimized jointly. We also assume that:
- \(p_\theta(\mathbf{z})=\mathcal{N}(\mathbf{z; 0, I})\) - the prior over the latent variables is a standard Gaussian 
- \(p_\theta(\mathbf{x|z})\) is a multivariate Gaussian (in case of real-valued data) or Bernoulli (in case of binary data) 
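- \(q_\phi(\mathbf{z|x})\) is a multivariate Gaussian with diagonal covariance, motivated by the assumption that the true posterior is approximately Gaussian with an approximately diagonal covariance: \(\log q_\phi(\mathbf{z|x[i]})=\log \mathcal{N}(\mathbf{z; \mu[i], \sigma[i]^2I})\) 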
We use the reparameterization trick to sample from the approximate posterior using \(\mathbf{z[i, l]}=g_\phi( \mathbf{\epsilon[l], x[i]} )=\mathbf{\mu[i]+\sigma[i]\cdot \epsilon[l]}\), where \(\mathbf{\epsilon[l]} \sim \mathcal{N}(\mathbf{0, I})\). In this model both \(\mathbf{p_\theta(z)}\) and \(q_\phi(\mathbf{z|x})\) are Gaussian; hence we can compute the KL divergence in closed form without estimation:
\[ \mathcal{L}(\mathbf{\theta, \phi, x[i]}) \simeq \frac{1}{2}\sum_{j=1}^J{\left( 1+\log(\sigma[i][j]^2) - \mu[i][j]^2 - \sigma[i][j]^2 \right)} + \frac{1}{L}\sum_{l=1}^L{ \log p_\theta(\mathbf{x}[i] | \mathbf{z}[i][l]) } \]
In the above equation, only the expected reconstruction term \(\mathbb{E}_{q_\phi(\mathbf{z|x[i]})} \left[ \log p_\theta(\mathbf{x[i]|z})\right]\) requires estimation by sampling, since the KL-divergence term is integrated analytically.
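The analytic KL term can be derived directly. As a short sketch, drop the datapoint index and write \(q_\phi(\mathbf{z|x})=\mathcal{N}(\mathbf{z; \mu, \sigma^2 I})\) with \(J\) latent dimensions, so that under \(q_\phi\) each component satisfies \(\mathbb{E}[z[j]^2]=\mu[j]^2+\sigma[j]^2\):
\[ \begin{aligned} \int q_\phi(\mathbf{z|x}) \log p_\theta(\mathbf{z}) d\mathbf{z} &= -\frac{J}{2}\log(2\pi) - \frac{1}{2}\sum_{j=1}^J{\left( \mu[j]^2 + \sigma[j]^2 \right)}\\ \int q_\phi(\mathbf{z|x}) \log q_\phi(\mathbf{z|x}) d\mathbf{z} &= -\frac{J}{2}\log(2\pi) - \frac{1}{2}\sum_{j=1}^J{\left( 1 + \log(\sigma[j]^2) \right)}\\ -D_{KL}(q_\phi(\mathbf{z|x})||p_\theta(\mathbf{z})) &= \int q_\phi(\mathbf{z|x}) \left( \log p_\theta(\mathbf{z}) - \log q_\phi(\mathbf{z|x}) \right) d\mathbf{z}\\ &= \frac{1}{2}\sum_{j=1}^J{\left( 1+\log(\sigma[j]^2) - \mu[j]^2 - \sigma[j]^2 \right)} \end{aligned} \]
The last line is the first sum in the estimator, and it is zero exactly when \(\mu[j]=0\) and \(\sigma[j]=1\) for every \(j\).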
TODO: for now see the KL divergence between two multivariate normal distributions
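The estimator in this section can be sketched end to end for binary data. This is a minimal illustration, not the paper's implementation: it assumes a hypothetical one-layer Bernoulli decoder, and `weights`, `bias`, `mu`, and `sigma` are made-up values standing in for trained decoder parameters and encoder outputs.

```python
import math
import random


def neg_kl_closed_form(mu, sigma):
    """-KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(1.0 + math.log(s * s) - m * m - s * s
                     for m, s in zip(mu, sigma))


def bernoulli_log_lik(x, z, weights, bias):
    """log p(x|z) for binary x under a hypothetical one-layer Bernoulli decoder."""
    total = 0.0
    for x_d, w_d, b_d in zip(x, weights, bias):
        logit = sum(w * z_j for w, z_j in zip(w_d, z)) + b_d
        # log sigmoid(logit) if x_d == 1, log(1 - sigmoid(logit)) if x_d == 0
        signed = logit if x_d == 1 else -logit
        total += -math.log1p(math.exp(-signed))
    return total


def elbo_estimate(x, mu, sigma, weights, bias, num_samples, rng):
    """KL term in closed form plus an L-sample reconstruction estimate."""
    recon = 0.0
    for _ in range(num_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in mu]
        z = [m + s * e for m, s, e in zip(mu, sigma, eps)]  # reparameterized sample
        recon += bernoulli_log_lik(x, z, weights, bias)
    return neg_kl_closed_form(mu, sigma) + recon / num_samples


rng = random.Random(0)
x = [1, 0, 1]                                     # one binary datapoint
mu, sigma = [0.3, -0.2], [0.8, 1.1]               # hypothetical encoder outputs
weights = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.7]]  # hypothetical decoder weights
bias = [0.0, -0.1, 0.2]

estimate = elbo_estimate(x, mu, sigma, weights, bias, num_samples=10, rng=rng)
```

Only the reconstruction term varies from call to call; the KL term is deterministic given `mu` and `sigma`, which is one reason the closed form is preferred when it is available.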