
Probabilistic Graphical Models

We introduce the concept of probabilistic graphical models (PGMs): probabilistic models that represent the conditional dependence structure among random variables. Some of the most common PGMs are Markov Random Fields (MRFs) and Bayesian Networks (BNs).

We would like to explore inference in PGMs. Let:

  • $x_E$ = the observed evidence
  • $x_F$ = the unobserved variables we want to infer
  • $x_R = x - \{x_E, x_F\}$ = the remaining variables, extraneous to the query

Then we have $p(x_F \mid x_E) = \frac{p(x_F, x_E)}{p(x_E)} = \frac{p(x_F, x_E)}{\sum_{x_F} p(x_F, x_E)}$, and $p(x_F, x_E) = \sum_{x_R} p(x_F, x_E, x_R)$ by marginalization.
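
As a sanity check, here is a minimal sketch of this query on a toy joint distribution: three binary variables stored as a NumPy array, where the table entries and the roles assigned to each axis are invented purely for illustration.

```python
import numpy as np

# Toy joint p(x_E, x_F, x_R) over three binary variables, axes ordered (E, F, R).
# The numbers are random; only the marginalization pattern matters here.
rng = np.random.default_rng(0)
joint = rng.random((2, 2, 2))
joint /= joint.sum()                      # normalize to a valid distribution

# p(x_F, x_E) = sum_{x_R} p(x_F, x_E, x_R): sum out the extraneous variable.
p_EF = joint.sum(axis=2)                  # shape (2, 2), axes (E, F)

# p(x_F | x_E = 1) = p(x_F, x_E = 1) / sum_{x_F} p(x_F, x_E = 1)
x_E = 1
p_F_given_E = p_EF[x_E] / p_EF[x_E].sum()
print(p_F_given_E)                        # a distribution over x_F, sums to 1
```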

We also have the chain-rule factorization $p(x_{1:T}) = \prod_{t=1}^T p(x_t \mid x_{t-1}, \dots, x_1)$, with the convention $p(x_1 \mid x_0) = p(x_1)$.
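
A small sketch of evaluating this factorization, assuming a first-order Markov chain so that each conditional depends only on the previous variable; the initial and transition probabilities below are made up for illustration.

```python
import numpy as np

# Assumed first-order Markov chain: p(x_t | x_{t-1}, ..., x_1) = p(x_t | x_{t-1}).
p_init = np.array([0.6, 0.4])             # p(x_1)
p_trans = np.array([[0.7, 0.3],           # p(x_t = j | x_{t-1} = i)
                    [0.2, 0.8]])

def chain_prob(xs):
    """p(x_{1:T}) = p(x_1) * prod_t p(x_t | x_{t-1})."""
    prob = p_init[xs[0]]
    for prev, cur in zip(xs[:-1], xs[1:]):
        prob *= p_trans[prev, cur]
    return prob

print(chain_prob([0, 0, 1]))              # 0.6 * 0.7 * 0.3
```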

Variable Elimination

There is more than one way to carry out this marginalization, and the order in which we sum out variables affects the computational cost. One such tool is called variable elimination (VE):

  • A simple and general exact inference algorithm in any PGM (e.g. MRFs, BNs, etc.)
  • Dynamic programming avoids enumerating all variable assignments

VE applied to Trees

Recall that for a tree-structured graph $G = (V, E)$, the joint distribution factorizes as $p(x_{1:n}) = \prod_{i\in V}\psi_i(x_i)\prod_{(i,j)\in E}\psi_{ij}(x_i, x_j)$.

We define the message sent from $j$ to $i \in N(j)$ as $m_{j\to i}(x_i)=\sum_{x_j}\psi_j(x_j)\psi_{ij}(x_i, x_j)\prod_{k\in N(j)\setminus\{i\}}m_{k\to j}(x_j)$. If $x_j$ is observed, the sum is dropped and the message is evaluated at the observed value: $m_{j\to i}(x_i) = \psi_j(x_j)\psi_{ij}(x_i, x_j)\prod_{k\in N(j)\setminus\{i\}}m_{k\to j}(x_j)$.

Then we define the belief as $b(x_i) \propto \psi_i(x_i)\prod_{j\in N(i)}m_{j\to i}(x_i)$.

Once normalized, beliefs are the marginals we want to compute.
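
For concreteness, here is a minimal sketch of computing a single message $m_{j\to i}$ on one edge with binary variables; the unary potential, pairwise potential, and incoming messages are all invented for illustration.

```python
import numpy as np

# One edge (i, j) with binary variables; potentials and incoming messages are invented.
psi_j = np.array([0.5, 1.5])                    # psi_j(x_j)
psi_ij = np.array([[1.0, 0.2],                  # psi_ij(x_i, x_j)
                   [0.2, 1.0]])
incoming = [np.array([0.9, 1.1])]               # messages m_{k->j}, k in N(j)\{i}

# m_{j->i}(x_i) = sum_{x_j} psi_j(x_j) psi_ij(x_i, x_j) prod_k m_{k->j}(x_j)
prod_in = np.prod(incoming, axis=0)
m_j_to_i = psi_ij @ (psi_j * prod_in)
print(m_j_to_i)
```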

Belief Propagation

We define belief propagation (BP) as message passing between neighboring vertices of the graph. On a tree, the algorithm is as follows:

  1. Choose a root $r$ arbitrarily
  2. Pass messages from the leaves to $r$
  3. Pass messages from $r$ back to the leaves
  4. These two passes are sufficient on trees
  5. Compute the beliefs $b(x_i)$

Concretely, the beliefs in step 5 are computed in two stages (a small sketch on a chain follows this list):

  1. Compute unnormalized beliefs $\tilde{b}(x_i) = \psi_i(x_i)\prod_{j\in N(i)}m_{j\to i}(x_i)$
  2. Normalize: $b(x_i) = \frac{\tilde{b}(x_i)}{\sum_{x_i}\tilde{b}(x_i)}$
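
A minimal sketch of the two-pass schedule on a 3-node chain $0 - 1 - 2$, taking node 1 as the root; all potentials are invented for illustration.

```python
import numpy as np

k = 2                                                                  # states per variable
psi = [np.array([1.0, 2.0]), np.array([1.0, 1.0]), np.array([3.0, 1.0])]  # unary potentials
psi_edge = {(0, 1): np.array([[2.0, 1.0], [1.0, 2.0]]),                    # pairwise potentials
            (1, 2): np.array([[1.0, 0.5], [0.5, 1.0]])}

def pairwise(i, j):
    """psi_ij indexed as [x_i, x_j] regardless of how the edge is stored."""
    return psi_edge[(i, j)] if (i, j) in psi_edge else psi_edge[(j, i)].T

neighbors = {0: [1], 1: [0, 2], 2: [1]}
messages = {}

def send(j, i):
    """m_{j->i}(x_i) = sum_{x_j} psi_j(x_j) psi_ij(x_i, x_j) prod_{k in N(j)\\{i}} m_{k->j}(x_j)."""
    incoming = np.ones(k)
    for nb in neighbors[j]:
        if nb != i:
            incoming *= messages[(nb, j)]
    messages[(j, i)] = pairwise(i, j) @ (psi[j] * incoming)

# Pass 1: leaves to the root (node 1).  Pass 2: root back to the leaves.
send(0, 1); send(2, 1)
send(1, 0); send(1, 2)

# Beliefs: b(x_i) ∝ psi_i(x_i) prod_{j in N(i)} m_{j->i}(x_i), then normalize.
for i in range(3):
    b = psi[i] * np.prod([messages[(j, i)] for j in neighbors[i]], axis=0)
    print(i, b / b.sum())
```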

VE applied to MRFs and BNs

Before presenting the algorithm for general MRFs and BNs, we define the following terms:

  • We introduce nonnegative factors, denoted $\phi$
  • Marginalizing over a variable $X$ introduces a new factor, denoted $\tau$

The sum-product form of variable elimination then computes $p(x_F \mid x_E) \propto \tau(x_F, x_E) = \sum_{x_R}\prod_{C\in \mathcal{F}}\phi_C(x_C)$, where $\mathcal{F}$ is the set of potentials or factors (a small factor-elimination sketch follows the list below).

  • For DAGs (BNs), $\mathcal{F}$ is the set of families of the form $\{i\} \cup \text{Parent}(i)$, $\forall i$
  • For MRFs, $\mathcal{F}$ is given by the set of all maximal cliques
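
A minimal sketch of one elimination step in this sum-product form: the invented binary factors below are multiplied together and a chosen variable is summed out, producing the new factor $\tau$.

```python
import numpy as np

# Factors stored as (variables, table); table axes follow the listed variable order.
# The tables are invented and all variables are binary.
factors = [
    (("A",), np.array([0.3, 0.7])),                      # phi(A)
    (("A", "B"), np.array([[0.9, 0.1], [0.4, 0.6]])),    # phi(A, B)
    (("B", "C"), np.array([[0.8, 0.2], [0.1, 0.9]])),    # phi(B, C)
]

def eliminate(factors, var):
    """Multiply all factors mentioning `var`, sum it out, and return the new factor list."""
    touching = [f for f in factors if var in f[0]]
    rest = [f for f in factors if var not in f[0]]
    all_vars = sorted({v for vs, _ in touching for v in vs})
    letters = {v: chr(ord("a") + i) for i, v in enumerate(all_vars)}
    inputs = ",".join("".join(letters[v] for v in vs) for vs, _ in touching)
    out_vars = [v for v in all_vars if v != var]
    out = "".join(letters[v] for v in out_vars)
    tau = np.einsum(f"{inputs}->{out}", *[t for _, t in touching])   # sums out `var`
    return rest + [(tuple(out_vars), tau)]

# Eliminate B: tau(A, C) = sum_B phi(A, B) phi(B, C)
for vs, t in eliminate(factors, "B"):
    print(vs, t)
```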

The complexity of the variable elimination algorithm is $O(m k^{N_{\text{max}}})$, where

  • $m$ is the number of initial factors (i.e. $m = |\mathcal{F}|$)
  • $k$ is the number of states each random variable takes (assumed to be equal here)
  • $N_i$ is the number of random variables inside each sum $\sum_i$
  • $N_{\text{max}} = \max_i N_i$ is the number of random variables inside the largest sum

Belief Propagation on MRFs

If the graph is not a tree but contains cycles, we keep passing messages until convergence; this is called Loopy Belief Propagation. (The resulting beliefs are only an approximation to the exact marginals.)

Loopy BP algorithm (it may not converge; a small sketch on a 3-node cycle follows the list):

  1. Initialize messages uniformly: $m_{i\to j}(x_j) = [1/k, \dots, 1/k]^T$
  2. Keep running BP updates until convergence: $m_{j\to i}(x_i) = \sum_{x_j}\psi_j(x_j)\psi_{ij}(x_i, x_j)\prod_{k\in N(j)\setminus\{i\}}m_{k\to j}(x_j)$, normalizing each message for numerical stability
  3. Compute beliefs $b(x_i) \propto \psi_i(x_i)\prod_{j\in N(i)}m_{j\to i}(x_i)$
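
A minimal sketch of loopy BP on a 3-node cycle with binary variables; the potentials are invented, a single pairwise potential is shared across edges for brevity, and the resulting beliefs are only approximate.

```python
import numpy as np

# Cycle 0-1-2-0 with binary variables; potentials are invented for illustration.
psi = [np.array([1.0, 2.0]), np.array([2.0, 1.0]), np.array([1.0, 1.0])]
edge_pot = np.array([[2.0, 1.0], [1.0, 2.0]])       # symmetric psi_ij(x_i, x_j), all edges
neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1]}

# 1. Initialize all messages uniformly: m_{i->j}(x_j) = [1/2, 1/2].
messages = {(i, j): np.full(2, 0.5) for i in neighbors for j in neighbors[i]}

# 2. Run BP updates until convergence (or an iteration cap), normalizing each message.
for _ in range(100):
    new = {}
    for (j, i) in messages:
        incoming = np.ones(2)
        for n in neighbors[j]:
            if n != i:
                incoming *= messages[(n, j)]
        m = edge_pot @ (psi[j] * incoming)
        new[(j, i)] = m / m.sum()
    delta = max(np.abs(new[e] - messages[e]).max() for e in messages)
    messages = new
    if delta < 1e-8:
        break

# 3. Beliefs b(x_i) ∝ psi_i(x_i) prod_{j in N(i)} m_{j->i}(x_i), normalized.
for i in neighbors:
    b = psi[i] * np.prod([messages[(j, i)] for j in neighbors[i]], axis=0)
    print(i, b / b.sum())
```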

MAP Inference via BP

For MAP inference, the BP updates replace the sum with a max (max-product): $m_{j \to i}(x_i) = \max_{x_j}\psi_j(x_j)\psi_{ij}(x_i, x_j)\prod_{k\in N(j)\setminus\{i\}}m_{k\to j}(x_j)$, and the resulting beliefs are max-marginals $b(x_i) \propto \psi_i(x_i)\prod_{j\in N(i)}m_{j\to i}(x_i)$. The MAP estimate is $\hat{x}_i = \arg\max_{x_i} b(x_i)$.
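
A minimal sketch of the max-product updates on the same invented 3-node chain used in the BP sketch above; ties in the argmax are ignored for simplicity.

```python
import numpy as np

# Same invented chain 0 - 1 - 2 as in the BP sketch; binary variables.
psi = [np.array([1.0, 2.0]), np.array([1.0, 1.0]), np.array([3.0, 1.0])]
psi_01 = np.array([[2.0, 1.0], [1.0, 2.0]])   # psi_{01}(x_0, x_1)
psi_12 = np.array([[1.0, 0.5], [0.5, 1.0]])   # psi_{12}(x_1, x_2)

# m_{j->i}(x_i) = max_{x_j} psi_j(x_j) psi_ij(x_i, x_j) prod_k m_{k->j}(x_j)
m_0_to_1 = (psi_01.T * psi[0]).max(axis=1)            # indexed by x_1, max over x_0
m_2_to_1 = (psi_12 * psi[2]).max(axis=1)              # indexed by x_1, max over x_2
m_1_to_0 = (psi_01 * (psi[1] * m_2_to_1)).max(axis=1)
m_1_to_2 = (psi_12.T * (psi[1] * m_0_to_1)).max(axis=1)

# Max-marginal beliefs and the coordinate-wise MAP estimate hat{x}_i = argmax b(x_i).
b = [psi[0] * m_1_to_0, psi[1] * m_0_to_1 * m_2_to_1, psi[2] * m_1_to_2]
print([int(np.argmax(bi)) for bi in b])
```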