Book Notes: Pattern Recognition and Machine Learning -- Ch10 Variational Inference

For the pdf slides, click here

Variational Inference

Introduction to the variational inference method

Definitions

  • Variational inference is also called variational Bayes; thus
    • all parameters are viewed as random variables, and
    • they will have prior distributions.
  • We denote the set of all latent variables and parameters by $Z$
    • Note: the parameter vector $\theta$ no longer appears, because it is now a part of $Z$
  • Goal: find approximations for
    • the posterior distribution $p(Z \mid X)$, and
    • the marginal likelihood $p(X)$, also called the model evidence

Model evidence equals lower bound plus KL divergence

  • Goal: We want to find a distribution $q(Z)$ that approximates the posterior distribution $p(Z \mid X)$. In other words, we want to minimize the KL divergence $\mathrm{KL}(q \,\|\, p)$.

  • Note the decomposition of the marginal likelihood: $$\log p(X) = \mathcal{L}(q) + \mathrm{KL}(q \,\|\, p)$$ where $$\mathcal{L}(q) = \int q(Z) \log\left\{\frac{p(X, Z)}{q(Z)}\right\} dZ, \qquad \mathrm{KL}(q \,\|\, p) = -\int q(Z) \log\left\{\frac{p(Z \mid X)}{q(Z)}\right\} dZ$$

  • Thus, maximizing the lower bound (also called the ELBO) $\mathcal{L}(q)$ is equivalent to minimizing the KL divergence $\mathrm{KL}(q \,\|\, p)$.
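
As a quick numerical illustration of this decomposition (not from the book; the joint probabilities below are made up), we can verify $\log p(X) = \mathcal{L}(q) + \mathrm{KL}(q \,\|\, p)$ for a toy discrete latent variable:

```python
# Toy check of log p(X) = L(q) + KL(q || p) with a discrete latent Z in {0, 1};
# the joint probabilities and q are arbitrary illustrative numbers.
import numpy as np

p_joint = np.array([0.10, 0.30])        # p(X = x_obs, Z = z) for z = 0, 1
p_X = p_joint.sum()                     # model evidence p(X)
p_post = p_joint / p_X                  # posterior p(Z | X)

q = np.array([0.6, 0.4])                # a variational distribution q(Z)

elbo = np.sum(q * np.log(p_joint / q))  # L(q) = sum_Z q(Z) log{p(X, Z) / q(Z)}
kl = np.sum(q * np.log(q / p_post))     # KL(q || p(Z | X))

print(np.log(p_X), elbo + kl)           # both print the same number
```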

Mean field family

  • Goal: restrict the family of distributions $q(Z)$ so that it comprises only tractable distributions, while allowing the family to be sufficiently flexible that it can approximate the posterior distribution well

  • Mean field family: partition the elements of $Z$ into disjoint groups denoted by $Z_j$, for $j = 1, \dots, M$, and assume $q$ factorizes wrt these groups: $$q(Z) = \prod_{j=1}^{M} q_j(Z_j)$$
    • Note: we place no restriction on the functional forms of the individual factors $q_j(Z_j)$

Solution for mean field families: derivation

  • We will optimize wrt each $q_j(Z_j)$ in turn.

  • For $q_j$, the lower bound (to be maximized) can be decomposed as $$\begin{aligned} \mathcal{L}(q) &= \int \prod_k q_k \Big\{ \log p(X, Z) - \sum_k \log q_k \Big\}\, dZ \\ &= \int q_j \underbrace{\Big\{ \int \log p(X, Z) \prod_{k \neq j} q_k \, dZ_k \Big\}}_{\mathbb{E}_{k \neq j}[\log p(X, Z)]} dZ_j - \int q_j \log q_j \, dZ_j + \text{const} \\ &= -\mathrm{KL}\big(q_j \,\|\, \tilde{p}(X, Z_j)\big) + \text{const} \end{aligned}$$

    • Here the new distribution $\tilde{p}(X, Z_j)$ is defined as $$\log \tilde{p}(X, Z_j) = \mathbb{E}_{k \neq j}[\log p(X, Z)] + \text{const}$$

Solution for mean field families

  • A general expression for the optimal solution $q_j^\star(Z_j)$ is $$\log q_j^\star(Z_j) = \mathbb{E}_{k \neq j}[\log p(X, Z)] + \text{const}$$

    • We can only use this solution in an iterative manner, because the expectations must be computed wrt the other factors $q_k(Z_k)$ for $k \neq j$ (see the sketch below).
    • Convergence is guaranteed because the bound is convex wrt each factor $q_j$.
    • On the right-hand side we only need to retain those terms that have some functional dependence on $Z_j$.
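
The iterative scheme can be written as a small generic loop; this is just a sketch of my own (the function and parameter names are not from the book), with each update function implementing one $\log q_j^\star$ equation:

```python
# A generic coordinate-ascent VI loop: cycle through the factor updates until
# the variational parameters stop changing. Each update function maps the
# current parameter dict to a new one, holding the other factors fixed.
import numpy as np

def coordinate_ascent_vi(update_fns, params, max_iter=100, tol=1e-10):
    for _ in range(max_iter):
        old = {name: np.copy(value) for name, value in params.items()}
        for update in update_fns:
            params = update(params)
        if all(np.allclose(old[name], params[name], atol=tol) for name in params):
            break                       # converged: no factor changed appreciably
    return params
```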

Example: approximate a bivariate Gaussian using two independent distributions

  • Target distribution: a bivariate Gaussian $$p(z) = \mathcal{N}(z \mid \mu, \Lambda^{-1}), \qquad \mu = \begin{pmatrix} \mu_1 \\ \mu_2 \end{pmatrix}, \quad \Lambda = \begin{pmatrix} \Lambda_{11} & \Lambda_{12} \\ \Lambda_{12} & \Lambda_{22} \end{pmatrix}$$

  • We use a factorized form to approximate $p(z)$: $$q(z) = q_1(z_1)\, q_2(z_2)$$

  • Note: we do not assume any functional forms for $q_1$ and $q_2$

VI solution to the bivariate Gaussian problem

$$\begin{aligned} \log q_1^\star(z_1) &= \mathbb{E}_{z_2}[\log p(z)] + \text{const} \\ &= \mathbb{E}_{z_2}\Big[ -\tfrac{1}{2}(z_1 - \mu_1)^2 \Lambda_{11} - (z_1 - \mu_1)\Lambda_{12}(z_2 - \mu_2) \Big] + \text{const} \\ &= -\tfrac{1}{2} z_1^2 \Lambda_{11} + z_1 \mu_1 \Lambda_{11} - z_1 \Lambda_{12}\big(\mathbb{E}[z_2] - \mu_2\big) + \text{const} \end{aligned}$$

  • Thus we identify a normal distribution, with mean depending on $\mathbb{E}[z_2]$: $$q^\star(z_1) = \mathcal{N}(z_1 \mid m_1, \Lambda_{11}^{-1}), \qquad m_1 = \mu_1 - \Lambda_{11}^{-1}\Lambda_{12}\big(\mathbb{E}[z_2] - \mu_2\big)$$

  • By symmetry, $q^\star(z_2)$ is also normal; its mean depends on $\mathbb{E}[z_1]$: $$q^\star(z_2) = \mathcal{N}(z_2 \mid m_2, \Lambda_{22}^{-1}), \qquad m_2 = \mu_2 - \Lambda_{22}^{-1}\Lambda_{12}\big(\mathbb{E}[z_1] - \mu_1\big)$$

  • We treat the above variational solutions as re-estimation equations and cycle through the variables in turn, updating them until some convergence criterion is satisfied
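
A minimal sketch of this cycling for a concrete bivariate Gaussian (the particular $\mu$ and $\Lambda$ values are made up for illustration):

```python
# Re-estimation for the factorized approximation of a bivariate Gaussian:
# m1 = mu1 - Lam11^{-1} Lam12 (E[z2] - mu2), and symmetrically for m2.
import numpy as np

mu = np.array([1.0, -1.0])              # true mean
Lam = np.array([[2.0, 0.8],
                [0.8, 1.5]])            # true precision matrix

E_z1, E_z2 = 0.0, 0.0                   # initial guesses for E[z1], E[z2]
for _ in range(50):
    E_z1 = mu[0] - (Lam[0, 1] / Lam[0, 0]) * (E_z2 - mu[1])   # update q(z1)
    E_z2 = mu[1] - (Lam[0, 1] / Lam[1, 1]) * (E_z1 - mu[0])   # update q(z2)

print(E_z1, E_z2)                       # converges to the true mean (1.0, -1.0)
# Note: the factor variances 1/Lam[0,0] and 1/Lam[1,1] under-estimate the true
# marginal variances, which are the diagonal entries of inv(Lam).
```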

Visualize VI solution to bivariate Gaussian

  • Variational inference minimizes $\mathrm{KL}(q \,\|\, p)$: the mean of the approximation is correct, but the variance (along the orthogonal direction) is significantly under-estimated

  • Expectation propagation minimizes $\mathrm{KL}(p \,\|\, q)$: the solution equals the product of the marginals of $p$


Figure 1: Left: variational inference. Right: expectation propagation

Another example to compare $\mathrm{KL}(q \,\|\, p)$ and $\mathrm{KL}(p \,\|\, q)$

  • To approximate a mixture of two Gaussians $p$ (blue contours)
  • Use a single Gaussian $q$ (red contours) to approximate $p$
    • By minimizing $\mathrm{KL}(p \,\|\, q)$: figure (a)
    • By minimizing $\mathrm{KL}(q \,\|\, p)$: figures (b) and (c) show two local minima

  • For a multimodal distribution
    • a variational solution will tend to find one of the modes,
    • but an expectation propagation solution averages across all of the modes, which can lead to a poor predictive distribution (because the average of two good parameter values is typically not itself a good parameter value)

Example: univariate Gaussian


  • Suppose the data $\mathcal{D} = \{x_1, \dots, x_N\}$ are iid from a normal distribution: $x_i \sim \mathcal{N}(\mu, \tau^{-1})$

  • The prior distributions are $$\mu \mid \tau \sim \mathcal{N}\big(\mu_0, (\lambda_0 \tau)^{-1}\big), \qquad \tau \sim \mathrm{Gam}(a_0, b_0)$$

  • Factorized variational approximation: $$q(\mu, \tau) = q(\mu)\, q(\tau)$$

Variational solution for μ

$$\begin{aligned} \log q^\star(\mu) &= \mathbb{E}_\tau[\log p(\mathcal{D} \mid \mu, \tau) + \log p(\mu \mid \tau)] + \text{const} \\ &= -\frac{\mathbb{E}[\tau]}{2}\Big\{ \lambda_0(\mu - \mu_0)^2 + \sum_{i=1}^N (x_i - \mu)^2 \Big\} + \text{const} \end{aligned}$$

Thus, the variational solution for $\mu$ is $$q^\star(\mu) = \mathcal{N}(\mu \mid \mu_N, \lambda_N^{-1}), \qquad \mu_N = \frac{\lambda_0 \mu_0 + N\bar{x}}{\lambda_0 + N}, \quad \lambda_N = (\lambda_0 + N)\,\mathbb{E}[\tau]$$

Variational solution for τ

$$\begin{aligned} \log q^\star(\tau) &= \mathbb{E}_\mu[\log p(\mathcal{D} \mid \mu, \tau) + \log p(\mu \mid \tau) + \log p(\tau)] + \text{const} \\ &= (a_0 - 1)\log\tau - b_0\tau + \frac{N}{2}\log\tau - \frac{\tau}{2}\,\mathbb{E}_\mu\Big[\lambda_0(\mu - \mu_0)^2 + \sum_{i=1}^N (x_i - \mu)^2\Big] + \text{const} \end{aligned}$$

Thus, the variational solution for $\tau$ is $$q^\star(\tau) = \mathrm{Gam}(\tau \mid a_N, b_N), \qquad a_N = a_0 + \frac{N}{2}, \quad b_N = b_0 + \frac{1}{2}\,\mathbb{E}_\mu\Big[\lambda_0(\mu - \mu_0)^2 + \sum_{i=1}^N (x_i - \mu)^2\Big]$$
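
A minimal sketch of these two re-estimation equations on simulated data (the data-generating values and the prior hyperparameters below are arbitrary choices, not from the book):

```python
# Iterate the q(mu) and q(tau) updates for the univariate Gaussian example.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=200)   # simulated data
N, xbar = x.size, x.mean()

mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0         # prior hyperparameters
E_tau = a0 / b0                                 # initial guess for E[tau]

for _ in range(100):
    # q(mu) = N(mu_N, 1 / lambda_N)
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # moments of q(mu) needed below
    E_mu = mu_N
    E_mu2 = mu_N**2 + 1.0 / lam_N
    # q(tau) = Gam(a_N, b_N)
    a_N = a0 + N / 2
    b_N = b0 + 0.5 * (lam0 * (E_mu2 - 2 * E_mu * mu0 + mu0**2)
                      + np.sum(x**2) - 2 * E_mu * np.sum(x) + N * E_mu2)
    E_tau = a_N / b_N

print(mu_N, b_N / a_N)   # approx. posterior mean of mu; 1/E[tau] estimates the variance
```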

Visualization of VI solution to univariate normal

Model selection

Model selection (comparison) under variational inference

  • In addition to making inference about the latent variables and parameters $Z$, we may also want to compare a set of candidate models, indexed by $m$

  • We then consider the factorization $q(Z, m) = q(Z \mid m)\, q(m)$ to approximate the posterior $p(Z, m \mid X)$

  • We can maximize the lower bound $$\mathcal{L}_m = \sum_m \sum_Z q(Z \mid m)\, q(m) \log\left\{\frac{p(Z, X, m)}{q(Z \mid m)\, q(m)}\right\}$$ which is a lower bound on $\log p(X)$

  • The maximized $q(m)$ can then be used for model selection

Variational Mixture of Gaussians

Mixture of Gaussians

  • For each observation $x_n \in \mathbb{R}^D$, we have a corresponding latent variable $z_n$, a 1-of-$K$ binary group-indicator vector

  • Mixture of Gaussians joint likelihood, based on $N$ observations: $$p(Z \mid \pi) = \prod_{n=1}^N \prod_{k=1}^K \pi_k^{z_{nk}}, \qquad p(X \mid Z, \mu, \Lambda) = \prod_{n=1}^N \prod_{k=1}^K \mathcal{N}(x_n \mid \mu_k, \Lambda_k^{-1})^{z_{nk}}$$


Figure 2: Graph representation of mixture of Gaussians

Conjugate priors

  • Dirichlet prior for $\pi$: $$p(\pi) = \mathrm{Dir}(\pi \mid \alpha_0) \propto \prod_{k=1}^K \pi_k^{\alpha_0 - 1}$$

  • Independent Gaussian-Wishart priors for $\mu, \Lambda$: $$p(\mu, \Lambda) = \prod_{k=1}^K p(\mu_k \mid \Lambda_k)\, p(\Lambda_k) = \prod_{k=1}^K \mathcal{N}\big(\mu_k \mid m_0, (\beta_0 \Lambda_k)^{-1}\big)\, \mathcal{W}(\Lambda_k \mid W_0, \nu_0)$$

    • Usually, the prior mean $m_0 = 0$

Variational distribution

  • Joint distribution: $$p(X, Z, \pi, \mu, \Lambda) = p(X \mid Z, \mu, \Lambda)\, p(Z \mid \pi)\, p(\pi)\, p(\mu \mid \Lambda)\, p(\Lambda)$$

  • The variational distribution factorizes between the latent variables and the parameters: $$q(Z, \pi, \mu, \Lambda) = q(Z)\, q(\pi, \mu, \Lambda) = q(Z)\, q(\pi) \prod_{k=1}^K q(\mu_k, \Lambda_k)$$

Variational solution for Z

  • Optimized factor: $$\begin{aligned} \log q^\star(Z) &= \mathbb{E}_{\pi,\mu,\Lambda}[\log p(X, Z, \pi, \mu, \Lambda)] + \text{const} \\ &= \mathbb{E}_\pi[\log p(Z \mid \pi)] + \mathbb{E}_{\mu,\Lambda}[\log p(X \mid Z, \mu, \Lambda)] + \text{const} \\ &= \sum_{n=1}^N \sum_{k=1}^K z_{nk} \log \rho_{nk} + \text{const} \end{aligned}$$ where $$\log \rho_{nk} = \mathbb{E}[\log \pi_k] + \frac{1}{2}\mathbb{E}[\log |\Lambda_k|] - \frac{D}{2}\log(2\pi) - \frac{1}{2}\mathbb{E}_{\mu_k,\Lambda_k}\big[(x_n - \mu_k)^\top \Lambda_k (x_n - \mu_k)\big]$$

  • Thus, the factor $q^\star(Z)$ takes the same functional form as the prior $p(Z \mid \pi)$: $$q^\star(Z) = \prod_{n=1}^N \prod_{k=1}^K r_{nk}^{z_{nk}}, \qquad r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}$$

    • Under $q^\star(Z)$, the posterior mean (i.e., the responsibility) is $\mathbb{E}[z_{nk}] = r_{nk}$
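
In code, the responsibilities are conveniently computed from $\log \rho_{nk}$ with a numerically stable normalization; a minimal sketch (the `log_rho` values below are placeholders, not computed from real expectations):

```python
# r_{nk} = rho_{nk} / sum_j rho_{nj}, computed stably in the log domain.
import numpy as np

def responsibilities(log_rho):
    log_rho = log_rho - log_rho.max(axis=1, keepdims=True)  # subtract row max
    rho = np.exp(log_rho)
    return rho / rho.sum(axis=1, keepdims=True)

log_rho = np.array([[-1.0, -2.0],
                    [-0.5, -0.4],
                    [-3.0, -0.1]])       # N = 3 points, K = 2 components
print(responsibilities(log_rho))         # each row sums to 1
```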

Define three statistics wrt the responsibilities

  • For each group $k = 1, \dots, K$, denote $$N_k = \sum_{n=1}^N r_{nk}, \qquad \bar{x}_k = \frac{1}{N_k}\sum_{n=1}^N r_{nk}\, x_n, \qquad S_k = \frac{1}{N_k}\sum_{n=1}^N r_{nk}\,(x_n - \bar{x}_k)(x_n - \bar{x}_k)^\top$$
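
A minimal sketch of these statistics in code, given a responsibility matrix `r` of shape (N, K) and data `X` of shape (N, D) (the function and variable names are my own):

```python
# N_k, xbar_k, and S_k as responsibility-weighted counts, means, and covariances.
import numpy as np

def mixture_statistics(r, X):
    Nk = r.sum(axis=0)                                   # shape (K,)
    xbar = (r.T @ X) / Nk[:, None]                       # shape (K, D)
    K, D = Nk.size, X.shape[1]
    S = np.zeros((K, D, D))
    for k in range(K):
        diff = X - xbar[k]                               # shape (N, D)
        S[k] = (r[:, k, None] * diff).T @ diff / Nk[k]   # weighted covariance
    return Nk, xbar, S
```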

Variational solution for π

  • Optimized factor: $$\log q^\star(\pi) = \log p(\pi) + \mathbb{E}_Z[\log p(Z \mid \pi)] + \text{const} = (\alpha_0 - 1)\sum_{k=1}^K \log \pi_k + \sum_{k=1}^K \sum_{n=1}^N r_{nk} \log \pi_k + \text{const}$$

  • Thus, $q^\star(\pi)$ is a Dirichlet distribution: $$q^\star(\pi) = \mathrm{Dir}(\pi \mid \alpha), \qquad \alpha_k = \alpha_0 + N_k$$

Variational solution for μk,Λk

  • Optimized factor for $(\mu_k, \Lambda_k)$: $$\log q^\star(\mu_k, \Lambda_k) = \mathbb{E}_Z\Big[\sum_{n=1}^N z_{nk} \log \mathcal{N}(x_n \mid \mu_k, \Lambda_k^{-1})\Big] + \log p(\mu_k \mid \Lambda_k) + \log p(\Lambda_k) + \text{const}$$

  • Thus, $q^\star(\mu_k, \Lambda_k)$ is Gaussian-Wishart: $$q^\star(\mu_k \mid \Lambda_k) = \mathcal{N}\big(\mu_k \mid m_k, (\beta_k \Lambda_k)^{-1}\big), \qquad q^\star(\Lambda_k) = \mathcal{W}(\Lambda_k \mid W_k, \nu_k)$$

  • The parameters are updated by the data: $$\beta_k = \beta_0 + N_k, \qquad m_k = \frac{1}{\beta_k}\big(\beta_0 m_0 + N_k \bar{x}_k\big), \qquad \nu_k = \nu_0 + N_k$$ $$W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}(\bar{x}_k - m_0)(\bar{x}_k - m_0)^\top$$
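
A minimal sketch of these updates (together with $\alpha_k = \alpha_0 + N_k$ from the previous slide), written to consume the $N_k$, $\bar{x}_k$, $S_k$ statistics defined earlier; the function name and argument order are my own:

```python
# Variational M-step-like updates for the Gaussian mixture:
# alpha_k, beta_k, m_k, nu_k, and W_k from the weighted statistics.
import numpy as np

def update_gaussian_wishart(Nk, xbar, S, alpha0, beta0, m0, W0_inv, nu0):
    alpha = alpha0 + Nk                                  # Dirichlet parameters
    beta = beta0 + Nk
    m = (beta0 * m0 + Nk[:, None] * xbar) / beta[:, None]
    nu = nu0 + Nk
    K, D = xbar.shape
    W = np.zeros((K, D, D))
    for k in range(K):
        diff = (xbar[k] - m0)[:, None]                   # shape (D, 1)
        W_inv = (W0_inv + Nk[k] * S[k]
                 + (beta0 * Nk[k] / (beta0 + Nk[k])) * (diff @ diff.T))
        W[k] = np.linalg.inv(W_inv)                      # store W_k itself
    return alpha, beta, m, nu, W
```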

Similarity between VI and EM solutions

  • Optimization of the variational posterior distribution involves cycling between two stages analogous to the E and M steps of the maximum likelihood EM algorithm

    • Finding $q(Z)$: analogous to the E step; both need to compute the responsibilities
    • Finding $q(\pi, \mu, \Lambda)$: analogous to the M step
  • The VI solution (Bayesian approach) has little computational overhead compared with the EM solution (maximum likelihood approach). The dominant computational costs for VI are

    • Evaluation of the responsibilities
    • Evaluation and inversion of the weighted data covariance matrices

Advantages of the VI solution over the EM solution:

  • Since our priors are conjugate, the variational posterior distributions have the same functional form as the priors
  1. No singularities arise of the kind found in maximum likelihood, where a Gaussian component “collapses” onto a specific data point

    • This is actually an advantage of Bayesian solutions (with priors) over frequentist ones
  2. No overfitting occurs if we choose a large number of components $K$. This is helpful in determining the optimal number of components without performing cross validation

    • For $\alpha_0 < 1$, the prior favors solutions where some of the mixing coefficients $\pi_k$ are zero, so the fit can end up with fewer than $K$ components having nonzero mixing coefficients (see the illustration below)
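
As a practical aside (not from the book), scikit-learn's `BayesianGaussianMixture` implements a variational mixture of Gaussians; the sketch below, with made-up data, illustrates how a small Dirichlet concentration and a deliberately large $K$ leave most components with negligible weight:

```python
# Fit a variational GMM with K = 10 components to data that has only two
# clusters; with a small concentration (alpha_0 < 1), most weights shrink.
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-3.0, 1.0, size=(200, 2)),
               rng.normal(+3.0, 1.0, size=(200, 2))])   # two true clusters

vb = BayesianGaussianMixture(
    n_components=10,                                     # intentionally too many
    weight_concentration_prior_type="dirichlet_distribution",
    weight_concentration_prior=1e-2,                     # small alpha_0
    covariance_type="full",
    max_iter=500,
    random_state=0,
).fit(X)

print(np.round(vb.weights_, 3))    # only ~2 components keep sizable weight
```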

Computing variational lower bound

  • To test for convergence, it is useful to monitor the bound during the re-estimation.
  • At each step of the iterative re-estimation, the value of the lower bound should not decrease: $$\begin{aligned} \mathcal{L} &= \sum_Z \iiint q(Z, \pi, \mu, \Lambda) \log\left\{\frac{p(X, Z, \pi, \mu, \Lambda)}{q(Z, \pi, \mu, \Lambda)}\right\} d\pi \, d\mu \, d\Lambda \\ &= \mathbb{E}[\log p(X, Z, \pi, \mu, \Lambda)] - \mathbb{E}[\log q(Z, \pi, \mu, \Lambda)] \\ &= \mathbb{E}[\log p(X \mid Z, \mu, \Lambda)] + \mathbb{E}[\log p(Z \mid \pi)] + \mathbb{E}[\log p(\pi)] + \mathbb{E}[\log p(\mu, \Lambda)] \\ &\quad - \mathbb{E}[\log q(Z)] - \mathbb{E}[\log q(\pi)] - \mathbb{E}[\log q(\mu, \Lambda)] \end{aligned}$$

Label switching problem

  • The maximum-likelihood EM solution does not suffer from the label switching problem, because the initialization leads to just one of the equivalent solutions

  • In a Bayesian setting, the label switching problem can be an issue, because the marginal posterior is multi-modal.

  • Recall that for multi-modal posteriors, variational inference usually approximates the distribution in the neighborhood of one of the modes and ignores the others

Induced factorizations

  • Induced factorizations: the additional factorizations that are a consequence of the interaction between

    • the assumed factorization, and
    • the conditional independence properties of the true distribution
  • For example, suppose we have three variable groups $A, B, C$

    • We assume the following factorization: $$q(A, B, C) = q(A, B)\, q(C)$$
    • If $A$ and $B$ are conditionally independent given $X$ and $C$, $$A \perp\!\!\!\perp B \mid X, C \quad\Longleftrightarrow\quad p(A, B \mid X, C) = p(A \mid X, C)\, p(B \mid X, C),$$ then we have the induced factorization $q(A, B) = q(A)\, q(B)$, because $$\log q^\star(A, B) = \mathbb{E}_C[\log p(A, B \mid X, C)] + \text{const} = \mathbb{E}_C[\log p(A \mid X, C)] + \mathbb{E}_C[\log p(B \mid X, C)] + \text{const}$$

Variational Linear Regression

Bayesian linear regression

  • Here, I use a notation system commonly used in statistics textbooks, so it differs from the one used in the book.

  • Likelihood function: $$p(y \mid \beta) = \prod_{n=1}^N \mathcal{N}(y_n \mid x_n^\top \beta, \phi^{-1})$$

    • $\phi = 1/\sigma^2$ is the precision parameter. We assume that it is known.
    • $\beta \in \mathbb{R}^p$ includes the intercept
  • Prior distributions (Normal and Gamma): $$p(\beta \mid \kappa) = \mathcal{N}(\beta \mid 0, \kappa^{-1} I), \qquad p(\kappa) = \mathrm{Gam}(\kappa \mid a_0, b_0)$$

Variational solution for κ

  • Variational posterior factorization: $$q(\beta, \kappa) = q(\beta)\, q(\kappa)$$

  • Variational solution for $\kappa$: $$\log q^\star(\kappa) = \log p(\kappa) + \mathbb{E}_\beta[\log p(\beta \mid \kappa)] + \text{const} = (a_0 - 1)\log\kappa - b_0\kappa + \frac{p}{2}\log\kappa - \frac{\kappa}{2}\,\mathbb{E}[\beta^\top \beta] + \text{const}$$

  • The variational posterior is a Gamma: $$q^\star(\kappa) = \mathrm{Gam}(\kappa \mid a_N, b_N), \qquad a_N = a_0 + \frac{p}{2}, \quad b_N = b_0 + \frac{\mathbb{E}[\beta^\top\beta]}{2}$$

Variational solution for β

  • Variational solution for $\beta$: $$\begin{aligned} \log q^\star(\beta) &= \log p(y \mid \beta) + \mathbb{E}_\kappa[\log p(\beta \mid \kappa)] + \text{const} \\ &= -\frac{\phi}{2}\,\|y - X\beta\|^2 - \frac{\mathbb{E}[\kappa]}{2}\,\beta^\top\beta + \text{const} \\ &= -\frac{1}{2}\,\beta^\top\big(\mathbb{E}[\kappa]\, I + \phi\, X^\top X\big)\beta + \phi\, \beta^\top X^\top y + \text{const} \end{aligned}$$

  • The variational posterior is a Normal: $$q^\star(\beta) = \mathcal{N}(\beta \mid m_N, S_N), \qquad S_N = \big(\mathbb{E}[\kappa]\, I + \phi\, X^\top X\big)^{-1}, \quad m_N = \phi\, S_N X^\top y$$

Iteratively re-estimate the variational solutions

  • Required moments of the variational posteriors: $$\mathbb{E}[\kappa] = \frac{a_N}{b_N}, \qquad \mathbb{E}[\beta^\top\beta] = m_N^\top m_N + \mathrm{tr}(S_N)$$

  • The lower bound on $\log p(y)$ can be used to monitor convergence, and also for model selection: $$\begin{aligned} \mathcal{L} &= \mathbb{E}[\log p(\beta, \kappa, y)] - \mathbb{E}[\log q(\beta, \kappa)] \\ &= \mathbb{E}_\beta[\log p(y \mid \beta)] + \mathbb{E}_{\beta,\kappa}[\log p(\beta \mid \kappa)] + \mathbb{E}_\kappa[\log p(\kappa)] - \mathbb{E}_\beta[\log q(\beta)] - \mathbb{E}_\kappa[\log q(\kappa)] \end{aligned}$$
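
A minimal sketch of this re-estimation cycle on simulated data (the design matrix, true coefficients, and hyperparameters below are arbitrary illustrative choices):

```python
# Variational linear regression: alternate the q(beta) and q(kappa) updates.
import numpy as np

rng = np.random.default_rng(0)
N, p, phi = 100, 3, 4.0                        # phi = 1 / sigma^2, assumed known
X = np.column_stack([np.ones(N), rng.normal(size=(N, p - 1))])   # with intercept
beta_true = np.array([1.0, 2.0, -1.0])
y = X @ beta_true + rng.normal(scale=1 / np.sqrt(phi), size=N)

a0, b0 = 1e-2, 1e-2                            # Gamma prior on kappa
E_kappa = a0 / b0

for _ in range(100):
    # q(beta) = N(m_N, S_N)
    S_N = np.linalg.inv(E_kappa * np.eye(p) + phi * X.T @ X)
    m_N = phi * S_N @ X.T @ y
    # q(kappa) = Gam(a_N, b_N)
    a_N = a0 + p / 2
    b_N = b0 + 0.5 * (m_N @ m_N + np.trace(S_N))
    E_kappa = a_N / b_N

print(np.round(m_N, 3))                        # close to beta_true
```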

TO BE CONTINUED

Exponential Family Distributions

Local Variational Methods

Variational Logistic Regression

Expectation Propagation

References

  • Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.