Tutorial: Variational Bayes on Bayesian Gaussian density estimation


This notebook is based on the worked example in the Wikipedia entry on variational Bayesian methods.

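We assume the data are drawn i.i.d. from a Gaussian distribution with unknown mean $\mu$ and precision $\tau$ (the inverse of the variance). We place conjugate prior distributions on these unknowns, i.e. the mean also follows a Gaussian distribution while the precision follows a gamma distribution. In other words:

$$
\begin{aligned}
\mu &\sim \mathcal{N}\left(\mu_0, (\lambda_0\tau)^{-1}\right) \\
\tau &\sim \operatorname{Gamma}(a_0, b_0) \\
x_1, \dots, x_N &\sim \mathcal{N}\left(\mu, \tau^{-1}\right) \quad \text{(i.i.d.)}
\end{aligned}
$$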

We are given $N$ data points $\mathbf{X} = \{x_1, \dots, x_N\}$, and our goal is to infer the posterior distribution $q(\mu, \tau) = p(\mu, \tau \mid x_1, \ldots, x_N)$ of the parameters $\mu$ and $\tau$.

The hyperparameters $\mu_0$, $\lambda_0$, $a_0$ and $b_0$ are fixed, given values. They can be set to small positive numbers to give broad prior distributions, indicating ignorance about $\mu$ and $\tau$.
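To make the generative model concrete, here is a minimal sketch that draws $\tau$, then $\mu$ given $\tau$, then the data. It assumes NumPy's `Generator` API; the function name `sample_model` and the hyperparameter values are illustrative only. Note that NumPy's `gamma` takes a *scale* ($1/b_0$, not the rate $b_0$) and `normal` takes a *standard deviation* (the square root of the inverse precision).

```python
import numpy as np

def sample_model(mu_0, lambda_0, a_0, b_0, N, rng):
    """Hypothetical helper: draw (mu, tau, x) from the conjugate Gaussian model.

    tau ~ Gamma(shape=a_0, rate=b_0)   -> NumPy scale = 1/b_0
    mu  ~ N(mu_0, 1/(lambda_0 * tau))  -> NumPy std = 1/sqrt(lambda_0 * tau)
    x_n ~ N(mu, 1/tau), i.i.d.         -> NumPy std = 1/sqrt(tau)
    """
    tau = rng.gamma(shape=a_0, scale=1.0 / b_0)
    mu = rng.normal(mu_0, 1.0 / np.sqrt(lambda_0 * tau))
    x = rng.normal(mu, 1.0 / np.sqrt(tau), size=N)
    return mu, tau, x

rng = np.random.default_rng(0)
mu, tau, x = sample_model(mu_0=0.0, lambda_0=1.0, a_0=2.0, b_0=2.0, N=100_000, rng=rng)
# With many points, the sample mean and the inverse sample variance
# should be close to the drawn mu and tau.
print(mu, x.mean(), tau, 1.0 / x.var())
```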

The joint probability

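The joint probability of all variables can be rewritten as

$$p(\mathbf{X}, \mu, \tau) = p(\mathbf{X} \mid \mu, \tau)\, p(\mu \mid \tau)\, p(\tau)$$

where the individual factors are

$$
\begin{aligned}
p(\mathbf{X} \mid \mu, \tau) &= \prod_{n=1}^N \mathcal{N}\left(x_n \mid \mu, \tau^{-1}\right) \\
p(\mu \mid \tau) &= \mathcal{N}\left(\mu \mid \mu_0, (\lambda_0\tau)^{-1}\right) \\
p(\tau) &= \operatorname{Gamma}(\tau \mid a_0, b_0)
\end{aligned}
$$

where

$$
\begin{aligned}
\mathcal{N}(x \mid \mu, \sigma^2) &= \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}} \\
\operatorname{Gamma}(\tau \mid a, b) &= \frac{1}{\Gamma(a)}\, b^a \tau^{a-1} e^{-b\tau}
\end{aligned}
$$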
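A minimal sketch of the log joint density, written directly from the three factors (Gaussian parameterized by its variance, gamma by shape $a$ and rate $b$). The helper names `log_normal`, `log_gamma_pdf` and `log_joint` are hypothetical, and only NumPy and the standard library are assumed.

```python
import math
import numpy as np

def log_normal(x, mean, var):
    """ln N(x | mean, var), where var is the variance."""
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

def log_gamma_pdf(tau, a, b):
    """ln Gamma(tau | a, b) with shape a and rate b: b^a tau^(a-1) e^(-b tau) / Gamma(a)."""
    return a * math.log(b) + (a - 1) * math.log(tau) - b * tau - math.lgamma(a)

def log_joint(x, mu, tau, mu_0, lambda_0, a_0, b_0):
    """ln p(X, mu, tau) = ln p(X | mu, tau) + ln p(mu | tau) + ln p(tau)."""
    log_lik = np.sum(log_normal(x, mu, 1.0 / tau))                 # data term, variance 1/tau
    log_prior_mu = log_normal(mu, mu_0, 1.0 / (lambda_0 * tau))    # variance 1/(lambda_0 tau)
    log_prior_tau = log_gamma_pdf(tau, a_0, b_0)
    return float(log_lik + log_prior_mu + log_prior_tau)

x = np.array([1.2, 0.7, 1.9])
print(log_joint(x, mu=1.0, tau=2.0, mu_0=0.0, lambda_0=1.0, a_0=1.0, b_0=1.0))
```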

Factorized approximation

Assume that $q(\mu,\tau) = q(\mu)q(\tau)$, i.e. that the posterior distribution factorizes into independent factors for $\mu$ and $\tau$. This type of assumption underlies the variational Bayesian method. The true posterior does not actually factor this way (in this simple case it is known to be a Gaussian-gamma distribution), so the result we obtain will be an approximation.

Derivation of q($\mu$)

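Under this assumption, the optimal factor for $\mu$ is obtained by taking the expectation of the log joint probability with respect to $q(\tau)$ and collecting everything that does not depend on $\mu$ into a constant. Then

$$
\begin{aligned}
\ln q_\mu^+(\mu) &= \operatorname{E}_{\tau}\left[\ln p(\mathbf{X} \mid \mu, \tau) + \ln p(\mu \mid \tau) + \ln p(\tau)\right] + C \\
&= \operatorname{E}_{\tau}[\ln p(\mathbf{X} \mid \mu, \tau)] + \operatorname{E}_{\tau}[\ln p(\mu \mid \tau)] + \operatorname{E}_{\tau}[\ln p(\tau)] + C \\
&= \operatorname{E}_{\tau}\left[\ln \prod_{n=1}^N \mathcal{N}\left(x_n \mid \mu, \tau^{-1}\right)\right] + \operatorname{E}_{\tau}\left[\ln \mathcal{N}\left(\mu \mid \mu_0, (\lambda_0\tau)^{-1}\right)\right] + C_2 \\
&= \operatorname{E}_{\tau}\left[\ln \prod_{n=1}^N \sqrt{\frac{\tau}{2\pi}}\, e^{-\frac{(x_n-\mu)^2\tau}{2}}\right] + \operatorname{E}_{\tau}\left[\ln \sqrt{\frac{\lambda_0\tau}{2\pi}}\, e^{-\frac{(\mu-\mu_0)^2\lambda_0\tau}{2}}\right] + C_2 \\
&= \operatorname{E}_{\tau}\left[\sum_{n=1}^N \left(\frac{1}{2}(\ln\tau - \ln 2\pi) - \frac{(x_n-\mu)^2\tau}{2}\right)\right] + \operatorname{E}_{\tau}\left[\frac{1}{2}(\ln\lambda_0 + \ln\tau - \ln 2\pi) - \frac{(\mu-\mu_0)^2\lambda_0\tau}{2}\right] + C_2 \\
&= \operatorname{E}_{\tau}\left[\sum_{n=1}^N -\frac{(x_n-\mu)^2\tau}{2}\right] + \operatorname{E}_{\tau}\left[-\frac{(\mu-\mu_0)^2\lambda_0\tau}{2}\right] + \operatorname{E}_{\tau}\left[\sum_{n=1}^N \frac{1}{2}(\ln\tau - \ln 2\pi)\right] + \operatorname{E}_{\tau}\left[\frac{1}{2}(\ln\lambda_0 + \ln\tau - \ln 2\pi)\right] + C_2 \\
&= \operatorname{E}_{\tau}\left[\sum_{n=1}^N -\frac{(x_n-\mu)^2\tau}{2}\right] + \operatorname{E}_{\tau}\left[-\frac{(\mu-\mu_0)^2\lambda_0\tau}{2}\right] + C_3 \\
&= -\frac{\operatorname{E}_{\tau}[\tau]}{2}\left\{\sum_{n=1}^N (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right\} + C_3
\end{aligned}
$$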

In the above derivation, $C$, $C_2$ and $C_3$ refer to values that are constant with respect to $\mu$. Note that the term $\operatorname{E}_{\tau}[\ln p(\tau)]$ is not a function of $\mu$ and will have the same value regardless of the value of $\mu$. Hence in line 3 we can absorb it into the constant term at the end. We do the same thing in line 7.

The last line is simply a quadratic polynomial in $\mu$ with a negative leading coefficient (because $\operatorname{E}_{\tau}[\tau] > 0$). Since this is the logarithm of $q_\mu^+(\mu)$, $q_\mu^+(\mu)$ itself is a Gaussian distribution.

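With a certain amount of tedious math (expanding the squares inside the braces, separating out and grouping the terms involving $\mu$ and $\mu^2$, and completing the square over $\mu$), we can derive the parameters of the Gaussian distribution:

$$
\begin{aligned}
\ln q_\mu^+(\mu) &= -\frac{\operatorname{E}_{\tau}[\tau]}{2}\left\{(\lambda_0+N)\mu^2 - 2\left(\lambda_0\mu_0 + \sum_{n=1}^N x_n\right)\mu\right\} + C_4 \\
&= -\frac{(\lambda_0+N)\operatorname{E}_{\tau}[\tau]}{2}\left(\mu - \frac{\lambda_0\mu_0 + \sum_{n=1}^N x_n}{\lambda_0+N}\right)^2 + C_5
\end{aligned}
$$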

Note that all of the above steps can be shortened by using the formula for the sum of two quadratics.

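In other words:

$$
\begin{aligned}
q_\mu^+(\mu) &= \mathcal{N}\left(\mu \mid \mu_N, \lambda_N^{-1}\right) \\
\mu_N &= \frac{\lambda_0\mu_0 + N\bar{x}}{\lambda_0 + N} \\
\lambda_N &= (\lambda_0 + N)\operatorname{E}_{\tau}[\tau] \\
\bar{x} &= \frac{1}{N}\sum_{n=1}^N x_n
\end{aligned}
$$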
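A quick numerical check of the completing-the-square step, as a sketch with made-up data: `update_q_mu` is a hypothetical helper, and `E_tau` stands in for whatever the current value of $\operatorname{E}_{\tau}[\tau]$ is. On a grid of $\mu$ values, the unnormalized $\ln q_\mu^+(\mu)$ should differ from the Gaussian log-kernel $-\frac{\lambda_N}{2}(\mu-\mu_N)^2$ only by a constant.

```python
import numpy as np

def update_q_mu(x, mu_0, lambda_0, E_tau):
    """Parameters (mu_N, lambda_N) of q(mu) = N(mu | mu_N, 1/lambda_N), given E[tau]."""
    N = x.size
    mu_N = (lambda_0 * mu_0 + x.sum()) / (lambda_0 + N)
    lambda_N = (lambda_0 + N) * E_tau
    return mu_N, lambda_N

rng = np.random.default_rng(1)
x = rng.normal(3.0, 0.5, size=50)
mu_0, lambda_0, E_tau = 0.0, 2.0, 4.0      # illustrative values
mu_N, lambda_N = update_q_mu(x, mu_0, lambda_0, E_tau)

grid = np.linspace(mu_N - 1, mu_N + 1, 201)
# Unnormalized ln q(mu): -E[tau]/2 * { sum_n (x_n - mu)^2 + lambda_0 (mu - mu_0)^2 }
log_q = -E_tau / 2 * (((x[:, None] - grid) ** 2).sum(axis=0) + lambda_0 * (grid - mu_0) ** 2)
# Difference from the Gaussian log-kernel should be constant across the grid.
diff = log_q - (-lambda_N / 2 * (grid - mu_N) ** 2)
print(mu_N, lambda_N, diff.max() - diff.min())
```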

Derivation of q($\tau$)

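The derivation of $q_\tau^+(\tau)$ is similar to the one above, although we omit some of the details for brevity.

$$
\begin{aligned}
\ln q_\tau^+(\tau) &= \operatorname{E}_{\mu}\left[\ln p(\mathbf{X} \mid \mu, \tau) + \ln p(\mu \mid \tau)\right] + \ln p(\tau) + \text{constant} \\
&= (a_0 - 1)\ln\tau - b_0\tau + \frac{1}{2}\ln\tau + \frac{N}{2}\ln\tau - \frac{\tau}{2}\operatorname{E}_{\mu}\left[\sum_{n=1}^N (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right] + \text{constant}
\end{aligned}
$$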

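Exponentiating both sides, we can see that $q_\tau^+(\tau)$ is a gamma distribution, since the right-hand side has the form $(a_N - 1)\ln\tau - b_N\tau + \text{constant}$. Specifically:

$$
\begin{aligned}
q_\tau^+(\tau) &= \operatorname{Gamma}(\tau \mid a_N, b_N) \\
a_N &= a_0 + \frac{N+1}{2} \\
b_N &= b_0 + \frac{1}{2}\operatorname{E}_{\mu}\left[\sum_{n=1}^N (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right]
\end{aligned}
$$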
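The "exponentiate and read off a gamma" step can be checked numerically. In this sketch `E_quad` is an assumed, already-computed value of $\operatorname{E}_{\mu}\left[\sum_n (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right]$, and `update_q_tau` is a hypothetical helper; the unnormalized log-density built term by term should differ from the gamma log-density with shape $a_N$ and rate $b_N$ by a constant.

```python
import math
import numpy as np

def update_q_tau(N, a_0, b_0, E_quad):
    """Parameters (a_N, b_N) of q(tau) = Gamma(tau | a_N, b_N), given E_quad."""
    a_N = a_0 + (N + 1) / 2
    b_N = b_0 + 0.5 * E_quad
    return a_N, b_N

N, a_0, b_0, E_quad = 20, 2.0, 1.0, 35.0   # illustrative values only
a_N, b_N = update_q_tau(N, a_0, b_0, E_quad)

tau = np.linspace(0.1, 3.0, 100)
# Gamma prior term + ln p(mu | tau) term (1/2 ln tau) + likelihood term (N/2 ln tau) - tau/2 * E_quad
log_q = (a_0 - 1) * np.log(tau) - b_0 * tau + 0.5 * np.log(tau) + N / 2 * np.log(tau) - tau / 2 * E_quad
# Normalized gamma log-density with shape a_N and rate b_N
log_gamma = a_N * math.log(b_N) + (a_N - 1) * np.log(tau) - b_N * tau - math.lgamma(a_N)
diff = log_q - log_gamma
print(a_N, b_N, diff.max() - diff.min())  # the difference does not depend on tau
```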

Algorithm for computing the parameters

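Let us recap the conclusions from the previous sections:

$$
\begin{aligned}
q_\mu^+(\mu) &= \mathcal{N}\left(\mu \mid \mu_N, \lambda_N^{-1}\right) \\
\mu_N &= \frac{\lambda_0\mu_0 + N\bar{x}}{\lambda_0 + N} \\
\lambda_N &= (\lambda_0 + N)\operatorname{E}_{\tau}[\tau] \\
\bar{x} &= \frac{1}{N}\sum_{n=1}^N x_n
\end{aligned}
$$

and

$$
\begin{aligned}
q_\tau^+(\tau) &= \operatorname{Gamma}(\tau \mid a_N, b_N) \\
a_N &= a_0 + \frac{N+1}{2} \\
b_N &= b_0 + \frac{1}{2}\operatorname{E}_{\mu}\left[\sum_{n=1}^N (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right]
\end{aligned}
$$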

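In each case, the parameters for the distribution over one variable depend on expectations taken with respect to the other variable. We can expand these expectations using the standard formulas for the moments of the Gaussian and gamma distributions:

$$
\begin{aligned}
\operatorname{E}[\tau \mid a_N, b_N] &= \frac{a_N}{b_N} \\
\operatorname{E}\left[\mu \mid \mu_N, \lambda_N^{-1}\right] &= \mu_N \\
\operatorname{E}\left[X^2\right] &= \operatorname{Var}(X) + (\operatorname{E}[X])^2 \\
\operatorname{E}\left[\mu^2 \mid \mu_N, \lambda_N^{-1}\right] &= \lambda_N^{-1} + \mu_N^2
\end{aligned}
$$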
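A Monte Carlo sanity check of these moment formulas, using illustrative parameter values. Keep in mind that NumPy's `gamma` is parameterized by a scale ($1/b_N$) and `normal` by a standard deviation ($\lambda_N^{-1/2}$).

```python
import numpy as np

rng = np.random.default_rng(2)
a_N, b_N = 3.0, 1.5          # gamma shape and rate (illustrative)
mu_N, lambda_N = 0.8, 4.0    # Gaussian mean and precision (illustrative)

tau = rng.gamma(shape=a_N, scale=1.0 / b_N, size=1_000_000)
mu = rng.normal(mu_N, 1.0 / np.sqrt(lambda_N), size=1_000_000)

print(tau.mean(), a_N / b_N)                       # E[tau]  = a_N / b_N
print(mu.mean(), mu_N)                             # E[mu]   = mu_N
print((mu ** 2).mean(), 1 / lambda_N + mu_N ** 2)  # E[mu^2] = 1/lambda_N + mu_N^2
```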

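Applying these formulas to the above equations is trivial in most cases, but the equation for $b_N$ takes more work:

$$
\begin{aligned}
b_N &= b_0 + \frac{1}{2}\operatorname{E}_{\mu}\left[\sum_{n=1}^N (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right] \\
&= b_0 + \frac{1}{2}\operatorname{E}_{\mu}\left[(\lambda_0+N)\mu^2 - 2\left(\lambda_0\mu_0 + \sum_{n=1}^N x_n\right)\mu + \left(\sum_{n=1}^N x_n^2\right) + \lambda_0\mu_0^2\right] \\
&= b_0 + \frac{1}{2}\left[(\lambda_0+N)\operatorname{E}_{\mu}[\mu^2] - 2\left(\lambda_0\mu_0 + \sum_{n=1}^N x_n\right)\operatorname{E}_{\mu}[\mu] + \left(\sum_{n=1}^N x_n^2\right) + \lambda_0\mu_0^2\right] \\
&= b_0 + \frac{1}{2}\left[(\lambda_0+N)\left(\lambda_N^{-1} + \mu_N^2\right) - 2\left(\lambda_0\mu_0 + \sum_{n=1}^N x_n\right)\mu_N + \left(\sum_{n=1}^N x_n^2\right) + \lambda_0\mu_0^2\right]
\end{aligned}
$$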
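To check the expanded form of $b_N$, this sketch compares the closed-form expression with a Monte Carlo estimate of $b_0 + \frac{1}{2}\operatorname{E}_{\mu}\left[\sum_n (x_n-\mu)^2 + \lambda_0(\mu-\mu_0)^2\right]$, drawing $\mu$ from an assumed current $q(\mu) = \mathcal{N}(\mu_N, \lambda_N^{-1})$. Data and parameter values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.normal(2.0, 1.0, size=30)
mu_0, lambda_0, b_0 = 0.0, 1.5, 1.0
mu_N, lambda_N = 1.9, 25.0      # current parameters of q(mu), illustrative
N = x.size

# Closed form: expand the square, then substitute E[mu] = mu_N and E[mu^2] = 1/lambda_N + mu_N^2.
b_closed = b_0 + 0.5 * ((lambda_0 + N) * (1 / lambda_N + mu_N ** 2)
                        - 2 * (lambda_0 * mu_0 + x.sum()) * mu_N
                        + (x ** 2).sum() + lambda_0 * mu_0 ** 2)

# Monte Carlo: average the quadratic over draws of mu from q(mu).
mu = rng.normal(mu_N, 1 / np.sqrt(lambda_N), size=100_000)
quad = ((x[:, None] - mu) ** 2).sum(axis=0) + lambda_0 * (mu - mu_0) ** 2
b_mc = b_0 + 0.5 * quad.mean()
print(b_closed, b_mc)
```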

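We can then write the parameter equations as follows, without any expectations:

$$
\begin{aligned}
\mu_N &= \frac{\lambda_0\mu_0 + N\bar{x}}{\lambda_0 + N} \\
\lambda_N &= (\lambda_0 + N)\frac{a_N}{b_N} \\
\bar{x} &= \frac{1}{N}\sum_{n=1}^N x_n \\
a_N &= a_0 + \frac{N+1}{2} \\
b_N &= b_0 + \frac{1}{2}\left[(\lambda_0+N)\left(\lambda_N^{-1} + \mu_N^2\right) - 2\left(\lambda_0\mu_0 + \sum_{n=1}^N x_n\right)\mu_N + \left(\sum_{n=1}^N x_n^2\right) + \lambda_0\mu_0^2\right]
\end{aligned}
$$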

Note that there are circular dependencies among the formulas for $\mu_N$, $\lambda_N$ and $b_N$. This naturally suggests an EM-like algorithm:

  • Compute $\sum_{n=1}^N x_n$ and $\sum_{n=1}^N x_n^2.$ Use these values to compute $\mu_N$ and $a_N.$
  • Initialize $\lambda_N$ to some arbitrary value.
  • Use the current value of $\lambda_N,$ along with the known values of the other parameters, to compute $b_N$.
  • Use the current value of $b_N,$ along with the known values of the other parameters, to compute $\lambda_N$.
  • Repeat the last two steps until convergence (i.e. until neither value has changed more than some small amount).
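The steps in the list above can be sketched as a standalone function with an explicit convergence test. This is an illustration under stated assumptions, not the notebook's own code: `vb_gaussian` is a hypothetical name, and the starting value `lambda_init`, the relative tolerance `tol` and `max_iter` are arbitrary choices.

```python
import numpy as np

def vb_gaussian(x, mu_0, lambda_0, a_0, b_0, lambda_init=1.0, tol=1e-10, max_iter=100):
    """Coordinate-ascent updates for q(mu) q(tau); returns (mu_N, lambda_N, a_N, b_N, n_iter)."""
    N = x.size
    sum_x, sum_x2 = x.sum(), (x ** 2).sum()
    # Step 1: mu_N and a_N depend only on the data and the hyperparameters.
    mu_N = (lambda_0 * mu_0 + sum_x) / (lambda_0 + N)
    a_N = a_0 + (N + 1) / 2
    # Step 2: arbitrary starting value for lambda_N.
    lambda_N = lambda_init
    b_N = None
    for it in range(1, max_iter + 1):
        # Step 3: b_N from the current lambda_N.
        b_N = b_0 + 0.5 * ((lambda_0 + N) * (1 / lambda_N + mu_N ** 2)
                           - 2 * (lambda_0 * mu_0 + sum_x) * mu_N
                           + sum_x2 + lambda_0 * mu_0 ** 2)
        # Step 4: lambda_N from the current b_N.
        new_lambda = (lambda_0 + N) * a_N / b_N
        # Step 5: stop once lambda_N changes by less than a relative tolerance.
        converged = abs(new_lambda - lambda_N) <= tol * abs(new_lambda)
        lambda_N = new_lambda
        if converged:
            return mu_N, lambda_N, a_N, b_N, it
    return mu_N, lambda_N, a_N, b_N, max_iter

rng = np.random.default_rng(0)
x = rng.normal(5.0, 2.0, size=10_000)   # true precision is 1/2^2 = 0.25
mu_N, lambda_N, a_N, b_N, n_iter = vb_gaussian(x, mu_0=0.0, lambda_0=1e-3, a_0=1e-3, b_0=1e-3)
print(mu_N, a_N / b_N, n_iter)
```

With broad priors like these, $\mu_N$ should land near the sample mean and $a_N/b_N$ near the inverse sample variance. The fixed-point map for $\lambda_N$ contracts by roughly a factor of $1/(2a_N)$ per step, so only a handful of iterations are needed.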
```python
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from numpy import random
import seaborn as sns
sns.set(style="darkgrid")
sns.set_context("notebook")
```
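```python
true_mu = 130
true_tau = 0.01   # precision, so the variance is 1/true_tau = 100
N = 1000000
# numpy's normal() expects a standard deviation, i.e. 1/sqrt(tau)
samples = random.normal(true_mu, 1/np.sqrt(true_tau), N)
_ = plt.hist(samples)
```

(Figure: histogram of the generated samples.)

```python
from scipy.stats import norm
```

```python
# Prior hyperparameters
a_0 = 100
b_0 = 20

lambda_0 = 100
mu_0 = -100

# Sufficient statistics of the data
x_ = np.sum(samples)        # sum of x_n
x__ = np.sum(samples**2)    # sum of x_n^2

def VB(samples, lambda_N):
    """One pass of the update loop: given lambda_N, compute b_N and a new lambda_N."""
    N = samples.shape[0]
    mu_N = (lambda_0*mu_0 + x_)/(lambda_0 + N)
    a_N = a_0 + (N + 1)/2.0
    b_N = b_0 + (1.0/2)*((lambda_0 + N)*(1.0/lambda_N + mu_N**2)
                         - 2.0*(lambda_0*mu_0 + x_)*mu_N + x__ + lambda_0*mu_0**2)
    lambda_N_new = (lambda_0 + N)*a_N/b_N
    return lambda_N_new, mu_N, a_N, b_N

# Normalized histogram of the data, with the true density in blue
xs = np.linspace(true_mu - 50, true_mu + 50, 500)
plt.hist(samples, bins=50, density=True)
plt.plot(xs, norm.pdf(xs, true_mu, 1.0/np.sqrt(true_tau)), 'b', linewidth=2)

lamb = 0.001   # arbitrary initial value of lambda_N
for i in range(5):
    lamb, mu_N, a_N, b_N = VB(samples, lamb)
    print("Tau", a_N/b_N)   # E[tau] under q(tau) is a_N/b_N
    # Fitted data density N(mu_N, 1/E[tau]); norm.pdf takes a standard deviation
    plt.plot(xs, norm.pdf(xs, mu_N, np.sqrt(b_N/a_N)), linewidth=2)
```

Each pass prints the current estimate $\operatorname{E}[\tau] = a_N/b_N$. After the first update, $\lambda_N$ is of order $(\lambda_0+N)\operatorname{E}[\tau]$, so the $1/\lambda_N$ term in $b_N$ becomes negligible and the estimate changes very little after the second iteration. The converged value should sit somewhat below `true_tau`, because the prior mean `mu_0 = -100` is far from the data and contributes the term $\lambda_0(\mu_N-\mu_0)^2$ to $b_N$.

(Figure: normalized histogram of the samples, the true density in blue, and the fitted density after each iteration.)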
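Because the true posterior in this model is the Gaussian-gamma distribution mentioned in the factorized-approximation section, the VB result can be compared against it. This sketch uses the standard conjugate update for the exact posterior parameters $(\mu_n, \lambda_n, a_n, b_n)$; the helper names, the data and the hyperparameters are illustrative assumptions. With these settings one should see the same posterior mean of $\mu$ and, for this model, the same $\operatorname{E}[\tau]$, while the VB variance of $\mu$, $1/\lambda_N$, comes out smaller than the exact marginal variance $b_n / (\lambda_n (a_n - 1))$ of the Student-t marginal.

```python
import numpy as np

def exact_posterior(x, mu_0, lambda_0, a_0, b_0):
    """Exact Gaussian-gamma posterior parameters from the standard conjugate update."""
    N, xbar = x.size, x.mean()
    mu_n = (lambda_0 * mu_0 + N * xbar) / (lambda_0 + N)
    lambda_n = lambda_0 + N
    a_n = a_0 + N / 2
    b_n = (b_0 + 0.5 * ((x - xbar) ** 2).sum()
           + lambda_0 * N * (xbar - mu_0) ** 2 / (2 * (lambda_0 + N)))
    return mu_n, lambda_n, a_n, b_n

def vb_posterior(x, mu_0, lambda_0, a_0, b_0, n_iter=50):
    """Factorized VB parameters, iterating the b_N / lambda_N updates a fixed number of times."""
    N, sum_x, sum_x2 = x.size, x.sum(), (x ** 2).sum()
    mu_N = (lambda_0 * mu_0 + sum_x) / (lambda_0 + N)
    a_N = a_0 + (N + 1) / 2
    lambda_N = 1.0  # arbitrary starting value
    for _ in range(n_iter):
        b_N = b_0 + 0.5 * ((lambda_0 + N) * (1 / lambda_N + mu_N ** 2)
                           - 2 * (lambda_0 * mu_0 + sum_x) * mu_N
                           + sum_x2 + lambda_0 * mu_0 ** 2)
        lambda_N = (lambda_0 + N) * a_N / b_N
    return mu_N, lambda_N, a_N, b_N

rng = np.random.default_rng(4)
x = rng.normal(1.0, 2.0, size=20)
prior = dict(mu_0=0.0, lambda_0=1.0, a_0=1.0, b_0=1.0)
mu_n, lambda_n, a_n, b_n = exact_posterior(x, **prior)
mu_N, lambda_N, a_N, b_N = vb_posterior(x, **prior)

print("E[mu]:   exact", mu_n, " VB", mu_N)
print("E[tau]:  exact", a_n / b_n, " VB", a_N / b_N)
# The exact marginal of mu is Student-t with variance b_n / (lambda_n (a_n - 1)); VB uses 1/lambda_N.
print("Var[mu]: exact", b_n / (lambda_n * (a_n - 1)), " VB", 1 / lambda_N)
```

The smaller VB variance is the usual price of the factorization: ignoring the dependence between $\mu$ and $\tau$ makes $q(\mu)$ somewhat overconfident.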