Sum-Product Algorithm: MATLAB code

I coded up the sum-product algorithm in MATLAB and am making it available free of charge for non-commercial use. I wrote this code because I found the BNT toolbox inconvenient to use in my ITSBN work, even though BNT is a great tool for inference on graphical models in general. However, ITSBN can be significantly slow when implemented with BNT, so I thought it would be a good idea to develop my own code, over which I have full control. This page serves as both instructions and code documentation, including explanations of the variables and functions in the code.

How to use the sum-product toolbox?

Before using the toolbox, it is a good idea to understand factor graphs and the sum-product algorithm, as this knowledge is very helpful when modeling a system with a factor graph. I made a very rough tutorial [pdf] on the sum-product algorithm: how to use it and how it is implemented in MATLAB. However, it is very crude and is intended only as supplemental material to the original text.

In this example, two types of random variables are involved: 1) discrete and 2) continuous (multivariate Gaussian). In particular, this toolbox was created specifically for use with my ITSBN, whose observed nodes follow multivariate Gaussian distributions, whereas its hidden nodes are discrete and represent the class label of a superpixel.

The code consists of three parts:

Part I: The user inputs the structure of the network and its parameters

% add path to the sum-product algorithm toolbox

addpath('./sum_product_toolbox');

% #################################################################

% ################ USER-DEFINED Parameters ########################

% #################################################################

N = 6; % total number of nodes in the network

% Note that all the nodes mentioned in this section are variable nodes x

discrete_node_list = 1:3; % indices of the discrete nodes

continuous_node_list = 4:6; % indices of the Gaussian nodes

hidden_node_list = 1:3; % hidden node list

observed_node_list = 4:6; % observed node list

size_node_list = [2 3 3, 3 3 3]; % number of classes of a discrete node, or dimension D of a continuous node

Above, we define the attributes of each node in the network: whether it is discrete or continuous, whether it is hidden or observed, and its size.
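The node lists must agree with each other. Every node is either discrete or continuous, and either hidden or observed. size_node_list needs one entry per node. The MATLAB snippet below is a minimal sketch of such checks using the values above. The checks are a suggested addition; the toolbox is not known to perform them itself.

```matlab
% Node attribute lists copied from the example above.
N = 6;
discrete_node_list   = 1:3;
continuous_node_list = 4:6;
hidden_node_list     = 1:3;
observed_node_list   = 4:6;
size_node_list       = [2 3 3, 3 3 3];

% Every node must be either discrete or continuous, but not both.
assert(isempty(intersect(discrete_node_list, continuous_node_list)), ...
       'A node cannot be both discrete and continuous.');
assert(isequal(sort([discrete_node_list continuous_node_list]), 1:N), ...
       'Discrete and continuous lists must cover nodes 1..N exactly once.');

% Likewise, every node must be either hidden or observed.
assert(isequal(sort([hidden_node_list observed_node_list]), 1:N), ...
       'Hidden and observed lists must cover nodes 1..N exactly once.');

% One size entry per node.
assert(numel(size_node_list) == N, 'size_node_list must have N entries.');

fprintf('Discrete nodes: %s\n', mat2str(discrete_node_list));
fprintf('Class counts of the discrete nodes: %s\n', ...
        mat2str(size_node_list(discrete_node_list)));
```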

% parameters for factor node f

CPT = cell(N,1);

mu = cell(N,1);

Sigma = cell(N,1);

value = cell(N,1);

% ==== input CPT parameters =====

CPT{1,1} = [1 1]';

CPT{2,1} = [0.8 0.1;0.1 0.45;0.1 0.45];

CPT{3,1} = CPT{2,1};

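The CPT is a conditional probability table connecting a child node with its parent node. Therefore, its dimension is |#class of child| x |#class of parent|. The root node has no parent, so its CPT is a |#class| x 1 prior vector (here [1 1]', a uniform prior up to normalization). A CPT is defined for each individual node taking discrete class values.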
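Here is a quick sanity check of the CPT convention, where rows index the child's classes and columns index the parent's classes. Each column of a child CPT should sum to 1. Multiplying a CPT by the parent's distribution gives the child's prior marginal. The sketch assumes the root prior [1 1]' is meant as a uniform prior, and it normalizes that prior explicitly before use.

```matlab
% CPTs from the example: rows = child classes, columns = parent classes.
CPT = cell(6,1);
CPT{1,1} = [1 1]';                           % root prior (unnormalized)
CPT{2,1} = [0.8 0.1; 0.1 0.45; 0.1 0.45];
CPT{3,1} = CPT{2,1};

% Each column of a child CPT is p(child | parent = j), so it must sum to 1.
for n = 2:3
    col_sums = sum(CPT{n,1}, 1);
    assert(all(abs(col_sums - 1) < 1e-12), 'CPT{%d} columns must sum to 1.', n);
end

% Normalize the root prior, then push it through the CPT to get the
% prior marginal of node 2: p(x2) = sum_j p(x2 | x1 = j) p(x1 = j).
prior1 = CPT{1,1} / sum(CPT{1,1});
p_x2   = CPT{2,1} * prior1;
disp(p_x2')
```

With the uniform prior, p_x2 should come out as [0.45 0.275 0.275].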

% mu{i} = C x D

mu{4,1} = [0 0 0;7.5 7.5 7.5];

mu{5,1} = [0 0 0;5 5 5;10 10 10];

mu{6,1} = mu{5,1};

mu is a cell array of size N x 1, where N is the total number of nodes in the network. Each entry mu{n,1} is a C x D matrix, where C is the number of Gaussian components and D is the dimensionality of the data.
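In the example, the number of Gaussian components C of each observed node equals the number of classes of its discrete parent. Node 4 hangs off the 2-class node 1 and has two rows in mu. Nodes 5 and 6 hang off the 3-class nodes 2 and 3 and have three rows each. Assuming that convention holds in general, the dimensions of mu can be checked against size_node_list and connectivity (both from Part I) with a sketch like this:

```matlab
size_node_list = [2 3 3, 3 3 3];
connectivity   = [2 1; 3 1; 4 1; 5 2; 6 3];   % child - parent
mu = cell(6,1);
mu{4,1} = [0 0 0; 7.5 7.5 7.5];
mu{5,1} = [0 0 0; 5 5 5; 10 10 10];
mu{6,1} = mu{5,1};

for n = 4:6
    parent = connectivity(connectivity(:,1) == n, 2);   % parent of node n
    [C, D] = size(mu{n,1});
    % Assumed convention: one Gaussian component per class of the parent.
    assert(C == size_node_list(parent), ...
        'mu{%d} needs one row per class of parent node %d.', n, parent);
    % For a continuous node, size_node_list holds its dimension D.
    assert(D == size_node_list(n), ...
        'mu{%d} needs D = %d columns.', n, size_node_list(n));
    fprintf('node %d: C = %d components (parent %d), D = %d\n', n, C, parent, D);
end
```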

% Sigma{i} = D x D x C

S = zeros(3,3,2); S(:,:,1) = 10*eye(3); S(:,:,2) = 10*eye(3);

Sigma{4,1} = S;

S = zeros(3,3,3); S(:,:,1) = 10*eye(3); S(:,:,2) = 10*eye(3); S(:,:,3) = 10*eye(3);

Sigma{5,1} = S;

Sigma{6,1} = S;

Sigma is a cell array of size N x 1. Each entry Sigma{n,1} is a D x D x C array, where C is the number of Gaussian components. Note that mu and Sigma follow the same format as the Gaussian mixture model (GMM) object in MATLAB (gmdistribution).
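All components in the example share the same covariance 10*eye(3). The D x D x C array can therefore be built in one line with repmat, and each slice can be checked for being a valid covariance matrix with chol. This is a minimal sketch for node 5. The same mu/Sigma layout is what gmdistribution in the Statistics and Machine Learning Toolbox expects, but the snippet avoids that dependency.

```matlab
% Build the covariance array for node 5: repmat stacks the same D x D
% matrix C times along the third dimension (D = 3, C = 3 here).
D = 3; C = 3;
S_repmat = repmat(10*eye(D), [1 1 C]);

% The slice-by-slice construction used in Part I, for comparison.
S = zeros(3,3,3); S(:,:,1) = 10*eye(3); S(:,:,2) = 10*eye(3); S(:,:,3) = 10*eye(3);
assert(isequal(S_repmat, S));

% Every slice must be a valid covariance: symmetric and positive definite.
% chol returns p == 0 only when the matrix is positive definite.
for c = 1:C
    Sc = S_repmat(:,:,c);
    assert(isequal(Sc, Sc'), 'Sigma slice %d is not symmetric.', c);
    [~, p] = chol(Sc);
    assert(p == 0, 'Sigma slice %d is not positive definite.', c);
end
disp(size(S_repmat))
```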

% Value (evidence)

value{4,1} = [5 5 5];

value{5,1} = [0 0 0];

value{6,1} = [10 10 10];

Like mu and Sigma, value is a cell array of size N x 1. Each entry value{n,1} holds the instantiated (observed) value at node n, which can be a vector or a scalar depending on the data.
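The evidence at an observed Gaussian node becomes a message to its discrete parent. For each parent class c, the message entry is the likelihood N(value; mu(c,:), Sigma(:,:,c)). The sketch below computes this vector for node 4 with a hand-written density. The anonymous function gauss is a hypothetical helper, not part of the toolbox, and the toolbox's own implementation may differ in details such as normalization.

```matlab
% Multivariate Gaussian density for row vector v, mean row m, covariance S.
gauss = @(v, m, S) exp(-0.5 * ((v - m) / S) * (v - m)') / sqrt(det(2*pi*S));

% Parameters and evidence of node 4 from Part I (parent: node 1, 2 classes).
mu4    = [0 0 0; 7.5 7.5 7.5];
Sigma4 = repmat(10*eye(3), [1 1 2]);
value4 = [5 5 5];

% Message from observed node 4 to its parent node 1:
% lik4(c) = N(value4; mu4(c,:), Sigma4(:,:,c)) for each parent class c.
C = size(mu4, 1);
lik4 = zeros(C, 1);
for c = 1:C
    lik4(c) = gauss(value4, mu4(c,:), Sigma4(:,:,c));
end
disp(lik4')
disp((lik4 / sum(lik4))')   % normalized, for easier reading
```

Because [5 5 5] lies closer to the second mean [7.5 7.5 7.5], the second entry is larger, by a factor of exp(2.8125), or about 16.7.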

% Define the connectivity: child - parent

connectivity = [2 1
                3 1
                4 1
                5 2
                6 3];

connectivity, an (N-1) x 2 matrix, captures the user-defined structure of the network. Each row lists a child node in the first column and its parent in the second column. Note that the root node (node 1) is not listed as a child, so there is no [1 0] row in this matrix. If you add one, the code will give unexpected results, so be careful!
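The rules above can be checked mechanically before building the factor graph. This sketch is not part of the toolbox. It verifies the row count, rejects a root-as-child row such as [1 0], and checks that every other node has exactly one parent. It then converts connectivity into a parent vector, with 0 marking the root.

```matlab
N = 6;
connectivity = [2 1; 3 1; 4 1; 5 2; 6 3];   % child - parent

% A tree with N nodes has exactly N-1 child-parent edges.
assert(size(connectivity, 1) == N - 1, 'connectivity must be (N-1) x 2.');

% The root (node 1) must not appear as a child, e.g. no [1 0] row.
assert(~any(connectivity(:,1) == 1), 'Do not list the root node as a child.');

% Every non-root node must appear as a child exactly once.
assert(isequal(sort(connectivity(:,1))', 2:N), ...
       'Each of nodes 2..N needs exactly one parent row.');

% Convert to a parent vector: parent(n) is the parent of node n (0 for the root).
parent = zeros(1, N);
parent(connectivity(:,1)) = connectivity(:,2);
disp(parent)
```

For the example, parent is [0 1 1 1 2 3].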

% #################################################################

% #################################################################

% #################################################################

This completes the first part, in which the user prepares all the information and parameters for the network. In the next part, the code collects all the input information, builds a factor graph and a sum-product engine, and runs the sum-product algorithm. Note that users are not supposed to change or input anything beyond this point.

Part II: Creating the factor graphs from the user-defined information

%% #################################################################

% ####### Build the graphical model in MATLAB ##########

% #################################################################

script_build_factor_graph;

The script builds the MATLAB struct arrays x and f, which hold the variable nodes and the factor nodes, respectively.
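Converting the Bayesian network into a factor graph gives one factor per node. The root gets a prior factor p(x1), and every other node gets a conditional factor p(xn | x_parent). The sketch below builds such a factor list from connectivity and shows which factors each variable node touches. The field names vars and type are hypothetical and are not necessarily the fields that script_build_factor_graph creates.

```matlab
N = 6;
connectivity = [2 1; 3 1; 4 1; 5 2; 6 3];   % child - parent
discrete_node_list = 1:3;

% Hypothetical factor struct array: one factor per variable node.
f = struct('vars', cell(N, 1), 'type', cell(N, 1));
root = setdiff(1:N, connectivity(:,1));       % node(s) that are never a child
for n = 1:N
    if any(n == root)
        f(n).vars = n;                        % prior factor p(x_n)
    else
        p = connectivity(connectivity(:,1) == n, 2);
        f(n).vars = [n p];                    % conditional factor p(x_n | x_p)
    end
    if any(n == discrete_node_list)
        f(n).type = 'CPT';                    % discrete child: table factor
    else
        f(n).type = 'Gaussian';               % continuous child: Gaussian factor
    end
end

% Variable-to-factor adjacency: node n is attached to every factor listing it.
for n = 1:N
    attached = find(arrayfun(@(k) any(f(k).vars == n), 1:N));
    fprintf('x%d is attached to factors %s\n', n, mat2str(attached));
end
```

Here x1 ends up attached to factors [1 2 3 4], and x2 to factors [2 5].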

%% ############## Start SUM-Product Algorithm #####################

script_sum_product_algorithm

Here the code runs the sum-product algorithm on the defined network. After this step, all the messages in the network have been computed, and in the next step we can calculate the marginal posterior at each hidden node if we wish. In fact, users can choose not to run this script and instead perform the message passing manually, step by step.
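Because the example network is a tree, the sum-product messages can also be worked out by hand. The sketch below uses the example parameters from Part I. It computes the upward messages into node 1, then the downward messages back to nodes 2 and 3, and finally the marginal posteriors of the hidden nodes. It assumes the root prior [1 1]' is a uniform prior and normalizes it. The sum of the unnormalized root belief is then the evidence p(x4, x5, x6). The helper gauss is hypothetical. Whether fn_cal_marg's ll output corresponds to log p(e) is an assumption, not something stated here.

```matlab
% Example parameters from Part I (Sigma = 10*I for every component).
CPT2 = [0.8 0.1; 0.1 0.45; 0.1 0.45];  CPT3 = CPT2;
prior1 = [1 1]' / 2;                        % normalized root prior
mu = cell(6,1);
mu{4} = [0 0 0; 7.5 7.5 7.5];
mu{5} = [0 0 0; 5 5 5; 10 10 10];
mu{6} = mu{5};
value = cell(6,1);
value{4} = [5 5 5]; value{5} = [0 0 0]; value{6} = [10 10 10];
S = 10*eye(3);

% Likelihood messages from the observed Gaussian nodes to their parents.
gauss = @(v, m) exp(-0.5 * ((v - m) / S) * (v - m)') / sqrt(det(2*pi*S));
lik = cell(6,1);
for n = 4:6
    C = size(mu{n}, 1);
    lik{n} = zeros(C, 1);
    for c = 1:C
        lik{n}(c) = gauss(value{n}, mu{n}(c,:));
    end
end

% Upward messages into node 1: sum over the child's classes in each CPT.
m2_up = CPT2' * lik{5};      % from the x2 branch
m3_up = CPT3' * lik{6};      % from the x3 branch
b1 = prior1 .* lik{4} .* m2_up .* m3_up;
post1 = b1 / sum(b1);

% Downward messages from node 1 to nodes 2 and 3, combined with their evidence.
b2 = lik{5} .* (CPT2 * (prior1 .* lik{4} .* m3_up));
b3 = lik{6} .* (CPT3 * (prior1 .* lik{4} .* m2_up));
post2 = b2 / sum(b2);
post3 = b3 / sum(b3);

evidence = sum(b1);          % p(x4, x5, x6) under the normalized prior
fprintf('p(x1 | e) = %s\n', mat2str(post1', 4));
fprintf('p(x2 | e) = %s\n', mat2str(post2', 4));
fprintf('p(x3 | e) = %s\n', mat2str(post3', 4));
fprintf('log p(e)  = %.4f\n', log(evidence));
```

With this evidence, the most probable class is 2 for node 1, 1 for node 2, and 3 for node 3. Node 5 is observed at the first mean [0 0 0], and node 6 at the third mean [10 10 10].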

Part III: Inferring the posterior distributions

The final step is to infer the marginal posterior distribution at any hidden node of interest. Joint posterior distributions can also be computed, but only for two nodes that share a common factor node.

%% Calculate the marginal posterior

[true_marg_joint, marg_post, ll] = fn_cal_marg(f,x,1)

[true_marg_joint, marg_post, ll] = fn_cal_marg(f,x,2)

[true_marg_joint, marg_post, ll] = fn_cal_marg(f,x,3)

% #####################################

% Calculate the joint

joint_post = fn_cal_joint(f,x,1,3)

joint_post = fn_cal_joint(f,x,3,1)
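The joint posterior of two nodes that share a factor is the product of everything that reaches that factor from both sides, multiplied by the factor itself. The sketch below does this for nodes 1 and 3, which share the factor p(x3 | x1). Rows index the classes of x1 and columns index the classes of x3. That orientation is a choice made here. The layout that fn_cal_joint(f,x,1,3) returns is not documented here, and neither is whether fn_cal_joint(f,x,3,1) returns its transpose. The helper gauss is hypothetical.

```matlab
% Example parameters and evidence from Part I (Sigma = 10*I).
CPT2 = [0.8 0.1; 0.1 0.45; 0.1 0.45];  CPT3 = CPT2;
prior1 = [1 1]' / 2;                        % normalized root prior
S = 10*eye(3);
gauss = @(v, m) exp(-0.5 * ((v - m) / S) * (v - m)') / sqrt(det(2*pi*S));
lik4 = [gauss([5 5 5], [0 0 0]); gauss([5 5 5], [7.5 7.5 7.5])];
lik5 = [gauss([0 0 0], [0 0 0]); gauss([0 0 0], [5 5 5]); gauss([0 0 0], [10 10 10])];
lik6 = [gauss([10 10 10], [0 0 0]); gauss([10 10 10], [5 5 5]); gauss([10 10 10], [10 10 10])];

% Everything reaching factor p(x3 | x1) from the x1 side (2 x 1) ...
from_x1 = prior1 .* lik4 .* (CPT2' * lik5);
% ... and from the x3 side (3 x 1).
from_x3 = lik6;

% Joint belief at the shared factor: rows = x1 classes, columns = x3 classes.
% CPT3' has entry (i,j) = p(x3 = j | x1 = i).
J13 = (from_x1 * from_x3') .* CPT3';
J13 = J13 / sum(J13(:));
disp(J13)

% Summing out x3 recovers the marginal posterior of x1.
post1 = sum(J13, 2);
disp(post1')
```

Summing J13 over its columns gives the marginal posterior of node 1, which is a useful consistency check against the marginal computed on its own.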

To try it yourself, download the code [.rar] and run example1.m.

There are several types of messages involved in this code

All the messages are stored only in the factor nodes, and not in the variable nodes.

  • Messages transmitted from or absorbed into a hidden discrete node are column vectors
  • Messages transmitted from an observed discrete node are Kronecker delta functions, represented by a column vector of zeros with a one at the observed class label
  • Messages transmitted from an observed continuous node are Dirac delta functions, represented by the observed value itself
  • Messages absorbed into an observed continuous node are GMM objects
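The message representations listed above can be illustrated in a few lines. The sketch shows three cases. (a) is a Kronecker-delta message from an observed discrete node, and how multiplying it by a CPT turns it into a message for the parent. (b) is the Dirac-delta message, stored simply as the observed value. (c) is the GMM-style message absorbed into an observed continuous node, whose mixture weights come from the parent's message. The observed label k, the parent message w, and the helpers gauss and mix_density are hypothetical values and names chosen for illustration.

```matlab
% (a) Observed discrete node with K classes and observed label k:
% a Kronecker delta stored as a column of zeros with a single one.
K = 3; k = 2;
delta_msg = zeros(K, 1);
delta_msg(k) = 1;

% Summing p(child | parent) against the delta gives the message to the
% parent: row k of the CPT (rows = child classes), as a column vector.
CPT = [0.8 0.1; 0.1 0.45; 0.1 0.45];
to_parent = CPT' * delta_msg;
assert(isequal(to_parent, CPT(k,:)'));

% (b) Observed continuous node: the Dirac delta is represented by the
% observed value itself, e.g. the 1 x D row vector [5 5 5].
dirac_msg = [5 5 5];

% (c) Message absorbed into an observed continuous node: a Gaussian mixture
% whose weights come from the parent's message w and whose components are
% the node's Gaussians (mu is C x D, every Sigma is 10*I here).
w  = [0.3; 0.7];                     % hypothetical incoming parent message
mu = [0 0 0; 7.5 7.5 7.5];
S  = 10*eye(3);
gauss = @(v, m) exp(-0.5 * ((v - m) / S) * (v - m)') / sqrt(det(2*pi*S));
mix_density = @(v) sum(arrayfun(@(c) w(c) * gauss(v, mu(c,:)), 1:numel(w)));
fprintf('mixture density at the observed value: %.6g\n', mix_density(dirac_msg));
```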

The structure of variable nodes and factor nodes

The variable nodes and factor nodes are implemented as the struct arrays x and f, respectively. The fields of each type are described below:

  • fields of variable node x
  • fields of factor node f