Recall that the posterior distribution $p(z \mid x) = \frac{p(x, z)}{p(x)}$ is the distribution of the latent variables given the observed data, where $p(x) = \int p(x, z)\, dz$ is the marginal distribution of the observed data. In general, when the latent variables are high-dimensional, the posterior becomes intractable to compute. Specifically, we face the following problems:
- computing the posterior distribution $p(z \mid x)$ is intractable
- computing the evidence/likelihood $p(x)$ is intractable
- computing marginal distributions $p(z)$ is intractable
- sampling from the posterior distribution $p(z \mid x)$ is intractable
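To see why, consider (as a hypothetical illustration, not an example from these notes) a model with $N$ coupled discrete latent variables, each taking $K$ values. The marginal likelihood is then a sum over every joint configuration:
$$p(x) = \sum_{z_1=1}^{K} \cdots \sum_{z_N=1}^{K} p(x, z_1, \dots, z_N),$$
which has $K^N$ terms and is infeasible to evaluate exactly even for moderate $N$; the continuous analogue is a high-dimensional integral with no closed form.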
Variational Inference is an approximate inference method that lets us approximate the posterior distribution $p(z \mid x)$ by a simpler distribution $q(z)$. It works as follows:
- Choose a tractable distribution $q(z) \in \mathcal{Q}$ from a feasible set $\mathcal{Q}$ to approximate the posterior distribution $p(z \mid x)$.
  - a Gaussian $q(z) = \mathcal{N}(z \mid \mu, \Sigma)$ is sometimes a good choice (see the sketch after this list)
- Encode some notion of difference between $p(z \mid x)$ and $q(z)$ that can be efficiently estimated.
  - KL divergence is a good choice
- Minimize this difference.
  - usually by iterative optimization
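As a minimal sketch of the first step (with made-up dimensions and parameter names, purely illustrative), a diagonal-Gaussian variational family can be parameterized and sampled like this:

```python
import numpy as np

rng = np.random.default_rng(0)

# Diagonal-Gaussian variational family q(z) = N(z | mu, diag(sigma^2))
# over a 3-dimensional latent z; mu and log_sigma are the free parameters.
d = 3
mu = np.zeros(d)
log_sigma = np.zeros(d)  # optimize log(sigma) so that sigma stays positive

def sample_q(num_samples):
    """Draw z ~ q(z) via z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal((num_samples, d))
    return mu + np.exp(log_sigma) * eps

def log_q(z):
    """Evaluate log q(z) for the diagonal Gaussian."""
    sigma2 = np.exp(2 * log_sigma)
    return -0.5 * np.sum((z - mu) ** 2 / sigma2 + np.log(2 * np.pi * sigma2), axis=-1)

z = sample_q(5)
print(log_q(z))  # log-density of five samples under q
```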
KL Divergence
The KL (Kullback-Leibler) divergence is a measure of the difference between two probability distributions $p$ and $q$. It is defined as:
$$\mathrm{KL}(q(z) \,\|\, p(z \mid x)) = \int q(z) \log \frac{q(z)}{p(z \mid x)}\, dz = \mathbb{E}_{z \sim q} \log \frac{q(z)}{p(z \mid x)}$$
Some properties of KL divergence:
- $\mathrm{KL}(q \,\|\, p) \ge 0$
- $\mathrm{KL}(q \,\|\, p) = 0$ if and only if $q = p$
- $\mathrm{KL}(q \,\|\, p) \ne \mathrm{KL}(p \,\|\, q)$ in general
- KL divergence is not symmetric, so it is not a metric (see the example below)
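To make the asymmetry concrete, here is a small sketch (the distributions are chosen arbitrarily for illustration) that evaluates both directions of the KL divergence for two discrete distributions:

```python
import numpy as np

def kl_divergence(q, p):
    """KL(q || p) = sum_x q(x) * log(q(x) / p(x)) for discrete distributions with full support."""
    q, p = np.asarray(q, dtype=float), np.asarray(p, dtype=float)
    return np.sum(q * np.log(q / p))

q = np.array([0.5, 0.4, 0.1])
p = np.array([0.3, 0.3, 0.4])

print(kl_divergence(q, p))  # ~0.23, nonnegative
print(kl_divergence(p, q))  # ~0.31, a different value: KL is not symmetric
print(kl_divergence(q, q))  # 0.0, since KL(q || q) = 0
```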
Information Projection
I-projection: $q^* = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}(q \,\|\, p)$, where $\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{x \sim q} \log \frac{q(x)}{p(x)}$. It follows that:
- $p \approx q \implies \mathrm{KL}(q \,\|\, p) \approx 0$
- I-projection underestimates the support of $p$ and does not, in general, yield the correct moments.
- $\mathrm{KL}(q \,\|\, p)$ penalizes $q$ for having mass where $p$ has none.
Moment Projection
M-projection: $q^* = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}(p \,\|\, q)$, where $\mathrm{KL}(p \,\|\, q) = \mathbb{E}_{x \sim p} \log \frac{p(x)}{q(x)}$. It follows that:
- $p \approx q \implies \mathrm{KL}(p \,\|\, q) \approx 0$
- $\mathrm{KL}(p \,\|\, q)$ penalizes $q$ for missing mass where $p$ has some, so $q$ tends to cover all of $p$'s support.
- M-projection yields a distribution $q(x)$ with the correct mean and covariance (moment matching); the sketch below illustrates the contrast with the I-projection.
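Here is a rough numerical sketch of this contrast (the bimodal target, the grid, and the crude grid search are all illustrative choices of mine): projecting a mixture of two Gaussians onto a single Gaussian, the M-projection matches the moments and spreads over both modes, while the I-projection settles on one mode.

```python
import numpy as np
from scipy.stats import norm

# Target p: a bimodal mixture of two Gaussians, discretized on a grid.
x = np.linspace(-10, 10, 2001)
dx = x[1] - x[0]
p = 0.5 * norm.pdf(x, loc=-3, scale=1) + 0.5 * norm.pdf(x, loc=3, scale=1)
p /= p.sum() * dx  # renormalize on the grid

def kl(a, b):
    """Discretized KL(a || b); terms with a(x) = 0 contribute 0."""
    mask = a > 0
    return np.sum(a[mask] * np.log(a[mask] / b[mask])) * dx

# M-projection onto a single Gaussian: match the mean and variance of p.
mu_m = np.sum(x * p) * dx
var_m = np.sum((x - mu_m) ** 2 * p) * dx

# I-projection onto a single Gaussian: crude grid search minimizing KL(q || p).
best = None
for mu in np.linspace(-5, 5, 101):
    for sigma in np.linspace(0.5, 5, 46):
        q = norm.pdf(x, loc=mu, scale=sigma)
        score = kl(q, p)
        if best is None or score < best[0]:
            best = (score, mu, sigma)

print(f"M-projection: mu={mu_m:.2f}, sigma={np.sqrt(var_m):.2f}")  # broad, covers both modes
print(f"I-projection: mu={best[1]:.2f}, sigma={best[2]:.2f}")      # narrow, sits on one mode
```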
ELBO: Evidence Lower Bound
Directly evaluating this KL divergence is itself intractable: the integral over $z$ involves the posterior $p(z \mid x)$, whose normalizer $p(x)$ we cannot compute. Instead, we maximize the ELBO, which is equivalent to minimizing the KL divergence.
$$\mathrm{KL}(q(z) \,\|\, p(z \mid x)) = \mathbb{E}_{z \sim q} \log \frac{q(z)}{p(z \mid x)} = \mathbb{E}_{z \sim q} \left[ \log \frac{q(z)\, p(x)}{p(x, z)} \right] = \mathbb{E}_{z \sim q} \log \frac{q(z)}{p(x, z)} + \log p(x)$$
We denote the ELBO by $\mathcal{L}(\phi) = \mathbb{E}_{z \sim q} \log \frac{p(x, z)}{q(z)}$ (with $\phi$ the parameters of $q$), so that $\mathrm{KL}(q(z) \,\|\, p(z \mid x)) = -\mathcal{L}(\phi) + \log p(x)$.
Since $\mathrm{KL}(q(z) \,\|\, p(z \mid x)) \ge 0$, we have $\mathcal{L}(\phi) \le \log p(x)$, i.e., the ELBO is a lower bound on the log evidence.
Since maximizing the ELBO is an ordinary optimization problem, we can solve it with gradient ascent or SGD. Moreover, the ELBO is an expectation over $z \sim q$, so it (and its gradient) can be estimated by sampling from $q$.
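As a minimal sketch of this idea (the toy model, variable names, and hyperparameters below are my own choices, not from the notes), one can maximize a Monte Carlo estimate of the ELBO using reparameterized samples and automatic differentiation:

```python
import torch

# Toy model: p(z) = N(0, 1), p(x | z) = N(x; z, 1), with one observation x_obs.
# The exact posterior is N(x_obs / 2, 1/2), so the result is easy to check.
x_obs = torch.tensor(2.0)

# Variational family q(z) = N(mu, sigma^2); optimize log_sigma so sigma stays positive.
mu = torch.tensor(0.0, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

for step in range(2000):
    opt.zero_grad()
    q = torch.distributions.Normal(mu, log_sigma.exp())
    z = q.rsample((64,))  # reparameterized samples, so gradients flow through mu, sigma
    log_prior = torch.distributions.Normal(0.0, 1.0).log_prob(z)
    log_lik = torch.distributions.Normal(z, 1.0).log_prob(x_obs)
    elbo = (log_prior + log_lik - q.log_prob(z)).mean()  # Monte Carlo estimate of L(phi)
    (-elbo).backward()  # minimize the negative ELBO
    opt.step()

print(mu.item(), log_sigma.exp().item())  # should approach mu ~ 1.0, sigma ~ 0.71
```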
Interpretation
We define the entropy $H(p) = -\mathbb{E}_{x \sim p} \log p(x)$ to measure the uncertainty of a distribution $p$.
Consider the maximum-entropy problem: maximize $H(p)$ subject to the moment constraints $\mathbb{E}_{x \sim p}[f_i(x)] = t_i$ for $i = 1, 2, \dots, k$.
Exponential Family yields Maximum Entropy
We have the theorem that the exponential family of distributions maximizes the entropy over all distributions satisfying $\mathbb{E}_{x \sim p}[f_i(x)] = t_i$ for $i = 1, 2, \dots, k$.
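A quick derivation sketch (standard, but written in my own notation): form the Lagrangian of the maximum-entropy problem for a discrete $x$, with multipliers $\lambda_i$ for the moment constraints and $\nu$ for normalization, and set its derivative with respect to $p(x)$ to zero.
$$\mathcal{L}(p, \lambda, \nu) = -\sum_x p(x) \log p(x) + \sum_{i=1}^{k} \lambda_i \Big( \sum_x p(x) f_i(x) - t_i \Big) + \nu \Big( \sum_x p(x) - 1 \Big)$$
$$\frac{\partial \mathcal{L}}{\partial p(x)} = -\log p(x) - 1 + \sum_{i=1}^{k} \lambda_i f_i(x) + \nu = 0 \;\Longrightarrow\; p(x) \propto \exp\Big\{ \sum_{i=1}^{k} \lambda_i f_i(x) \Big\},$$
which is exactly the exponential-family form with natural parameters $\lambda_i$ and sufficient statistics $f_i(x)$.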
That is, if $\mathcal{Q}$ is an exponential family, then the M-projection $q^*(x)$ matches the expected sufficient statistics under $p(x)$, i.e., $\mathbb{E}_{q^*}[f_i(x)] = \mathbb{E}_{p}[f_i(x)]$. But computing these expectations requires $p(x)$ itself, so the M-projection is generally intractable. For this reason, most variational inference methods use the I-projection.
Example on MRF
We have the following MRF: $p(x \mid \theta) = \exp\left\{ \sum_{c \in C} \phi_c(x_c) - \log Z(\theta) \right\}$
We use the I-projection to approximate $p(x \mid \theta)$, whose normalizer $Z(\theta)$ is intractable, by a simpler distribution $q(x)$. This gives the following optimization problem:
$$q^* = \arg\min_{q \in \mathcal{Q}} \mathrm{KL}(q \,\|\, p) = \arg\min_{q \in \mathcal{Q}} \mathbb{E}_{x \sim q} \log \frac{q(x)}{p(x \mid \theta)} = \arg\min_{q \in \mathcal{Q}} \mathbb{E}_{x \sim q} \Big[ \log q(x) - \sum_{c \in C} \phi_c(x_c) + \log Z(\theta) \Big] = \arg\max_{q \in \mathcal{Q}} H(q) + \sum_{c \in C} \mathbb{E}_{q}[\phi_c(x_c)]$$
Note that $\log Z(\theta)$ does not depend on $q$, so it drops out of the optimization.
Ideally, if $p \in \mathcal{Q}$, then $q^* = p$. But this does not always happen, so we use the mean-field approach.
Mean-Field Approach
We first assume $q(x) = \prod_{i \in V} q_i(x_i)$; that is, the set $\mathcal{Q}$ consists of the distributions that fully factorize over the variables.
Using the I-projection objective above, we have $q^* = \arg\max_{q \in \mathcal{Q}} H(q) + \sum_{c \in C} \sum_{x_c} q(x_c)\, \phi_c(x_c)$
Since
$$H(q) = -\mathbb{E}_{x \sim q} \log q(x) = -\sum_x q(x) \Big[ \sum_i \log q_i(x_i) \Big] = -\sum_i \sum_{x_i} \log q_i(x_i) \sum_{x \setminus x_i} q(x) = -\sum_i \sum_{x_i} q_i(x_i) \log q_i(x_i) = \sum_i H(q_i),$$
we have
$$q^* = \arg\max_{q \in \mathcal{Q}} \sum_i H(q_i) + \sum_{c \in C} \sum_{x_c} q(x_c)\, \phi_c(x_c) \quad \text{subject to} \quad \sum_{x_i} q_i(x_i) = 1 \text{ for each } i \in V.$$
This is still not easy to solve jointly, so we use coordinate ascent:
- Initialize $\{q_i(x_i)\}_{i \in V}$
- Iterate over $i \in V$:
  - Greedily maximize the objective over $q_i(x_i)$; writing the update for pairwise potentials, this gives $q_i(x_i) \propto \exp\big\{ \sum_{j \in N(i)} \sum_{x_j} q_j(x_j)\, \phi_{ij}(x_i, x_j) \big\}$, which can be derived by adding a Lagrange multiplier for the normalization constraint and setting the derivative to zero (see the sketch below).
- Repeat until convergence
Coordinate ascent is guaranteed to converge, but it may only converge to a local optimum of the mean-field objective.
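For concreteness, here is a small illustrative sketch of naive mean-field coordinate ascent on a pairwise binary MRF over a 4-node chain; the graph, potentials, and variable names are my own choices, not from the notes.

```python
import numpy as np

# Pairwise binary MRF on the chain 0 - 1 - 2 - 3 with potentials
# phi_ij(x_i, x_j) = coupling * [x_i == x_j] (neighbors prefer to agree),
# plus a unary potential pulling node 0 toward state 1.
n, k = 4, 2
edges = [(0, 1), (1, 2), (2, 3)]
coupling = 1.0
pair = coupling * np.eye(k)      # phi_ij as a k x k table
unary = np.zeros((n, k))
unary[0] = [0.0, 2.0]            # node 0 prefers state 1

neighbors = {i: [] for i in range(n)}
for i, j in edges:
    neighbors[i].append(j)
    neighbors[j].append(i)

# Mean field: q(x) = prod_i q_i(x_i); initialize every q_i to uniform.
q = np.full((n, k), 1.0 / k)

for sweep in range(50):
    for i in range(n):
        # q_i(x_i) ∝ exp{ phi_i(x_i) + sum_{j in N(i)} sum_{x_j} q_j(x_j) phi_ij(x_i, x_j) }
        logits = unary[i].copy()
        for j in neighbors[i]:
            logits += pair @ q[j]
        logits -= logits.max()   # subtract max for numerical stability
        q[i] = np.exp(logits) / np.exp(logits).sum()

print(np.round(q, 3))  # the preference for state 1 propagates along the chain
```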