
Batch Normalization

Batch normalization is another effective way to help a model generalize.

Assuming a mini-batch of m samples X1, X2, ..., Xm, the mean u and variance var of the mini-batch are

   u = (X1 + X2 + ... + Xm) / m

   var = sum((Xi - u)^2) / m

Every sample Xi = (xi1, xi2, ..., xik) has k dimensions, and u and var are calculated separately for every dimension.

Assume the u and var of the kth dimension are uk and vark.

Then normalize every xik

   xik' = (xik - uk) / sqrt(vark + epsilon)

   i.e.

   xik' = (xik - uk) / standard deviation

Where epsilon is a very small constant that avoids a zero denominator.

Here xik is normalized into a zero-mean, unit-variance distribution.

To restore the representation power of the network, a further transformation is applied to xik':

   yik' = gamma_k * xik' + beta_k

   

Imagine gamma_k = sqrt(vark) (the standard deviation) and beta_k = uk (the mean); then gamma_k * xik' + beta_k reverses the previous normalization and yik' = xik.

The gamma and beta here are actually new parameters for the network to learn. So the Batch Norm layer introduces 2 more parameters per dimension.
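Putting the pieces together, here is a minimal runnable sketch in NumPy (the batch_norm name, shapes, and epsilon value are just illustrative, not from any particular library):

import numpy as np

def batch_norm(X, gamma, beta, epsilon=1e-5):
    # X: [m, k] mini-batch; gamma, beta: [k] learnable scale and shift
    u = X.mean(axis=0)                         # per-dimension mean uk
    var = X.var(axis=0)                        # per-dimension variance vark
    X_norm = (X - u) / np.sqrt(var + epsilon)  # zero mean, unit variance
    return gamma * X_norm + beta               # yik' = gamma_k * xik' + beta_k

X = np.random.randn(32, 4)                     # m = 32 samples, k = 4 dimensions
y = batch_norm(X, gamma=np.ones(4), beta=np.zeros(4))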

# of parameters

Given a fully connected neural network, every neuron's output value is normalized. So the mean and variance of every neuron's outputs are calculated separately for a batch. If a layer has k neurons, then batch norm normalizes the k neurons separately and has k * 2 parameters: k gammas and k betas.
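For example, a quick check with PyTorch (assuming a layer of k = 64 neurons; this snippet is an illustration, not from the original notes):

import torch

bn = torch.nn.BatchNorm1d(64)                   # k = 64 neurons
print(bn.weight.shape, bn.bias.shape)           # gamma: [64], beta: [64]
print(sum(p.numel() for p in bn.parameters()))  # 128 = k * 2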

For a convolutional network, however, the normalization is applied at the channel level instead of the neuron level.

E.g. a conv layer outputs 12 channels and every channel outputs an M x N feature map. As every element in the M x N feature map is derived from the same filter (the same weights and bias scan across the image), all elements of the feature map share the same data distribution. Thus normalization only needs to be applied at the channel level:

  the mean is the average of all elements' outputs

  the var is the variance of all elements' outputs

  then the same mean and var are used to normalize all elements of a channel's output (i.e. a feature map)

Therefore, batch norm has #channels * 2 parameters.
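Continuing the 12-channel example, a PyTorch-based sketch (note that PyTorch uses a channels-first [B, C, H, W] layout; the sizes here are made up):

import torch

bn = torch.nn.BatchNorm2d(12)                   # 12 channels
x = torch.randn(8, 12, 28, 28)                  # [B, C, H, W] feature maps
y = bn(x)                                       # each channel normalized with its own mean/var
print(sum(p.numel() for p in bn.parameters()))  # 24 = #channels * 2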

Why useful

Why batch norm works so well is still somewhat of a mystery. The original paper says batch norm controls the "internal covariate shift", but that explanation has since been shown to be wrong.

In experiments where covariate shift was deliberately injected into a network after the batch norm layers, the results were still good, and better than a network without batch norm.

Another paper argues that batch norm smooths the objective function, which makes it easier for SGD to find a good solution.

Because batch norm normalizes a layer's output, it also helps avoid vanishing / exploding gradients (it always re-centers values to a zero-mean, unit-variance distribution).

Consider the sigmoid function: when its input is too far from 0, the gradient tends toward zero, and it becomes even smaller after propagating back through a few layers.
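A quick numeric illustration of that (made-up inputs):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

for x in [0.0, 2.0, 5.0, 10.0]:
    s = sigmoid(x)
    print(x, s * (1 - s))   # gradient: 0.25, ~0.105, ~0.0066, ~0.000045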

A more specific explanation of batch norm for fully connected networks and CNNs

First, the output of a convolutional layer is a rank-4 tensor [B, H, W, C], where B is the batch size, (H, W) is the feature map size, and C is the number of channels. An index (x, y) with 0 <= x < H and 0 <= y < W is a spatial location.

Usual batchnorm (for fully connected network)

# t is the incoming tensor of shape [B, H, W, C]

# mean and stddev are computed along axis 0 and have shape [H, W, C]

# so every element in [H, W, C] is normalized separately

mean = mean(t, axis=0)

stddev = stddev(t, axis=0)

for i in 0..B-1:

  out[i,:,:,:] = norm(t[i,:,:,:], mean, stddev)

Basically, it computes H*W*C means and H*W*C standard deviations across B elements. 

You may notice that elements at different spatial locations each get their own mean and variance, each gathered from only B values.
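A runnable NumPy version of the pseudo-code above might look like this (the sizes and epsilon are assumed for illustration):

import numpy as np

B, H, W, C = 8, 4, 4, 12
t = np.random.randn(B, H, W, C)

mean = t.mean(axis=0)                # shape [H, W, C]
stddev = t.std(axis=0)               # shape [H, W, C]
out = (t - mean) / (stddev + 1e-5)   # broadcasting normalizes each of the B samples

print(mean.shape)                    # (4, 4, 12): H*W*C statistics, each from only B values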

Batchnorm in convolutional layer

As mentioned, a convolutional layer has a special property: filter weights are shared across the input image. That's why it's reasonable to normalize a whole channel's output in the same way, so that each element of a feature map uses the mean and variance of B*H*W values (all elements of the channel across the batch), regardless of location.

Here's how the code looks in this case (again pseudo-code):

# t is still the incoming tensor of shape [B, H, W, C]

# but mean and stddev are computed along axes (0, 1, 2) and have shape [C]

mean = mean(t, axis=(0, 1, 2))

stddev = stddev(t, axis=(0, 1, 2))

for i in 0..B-1, x in 0..H-1, y in 0..W-1:

  out[i,x,y,:] = norm(t[i,x,y,:], mean, stddev)

In total, there are only C means and standard deviations and each one of them is computed over B*H*W values. 
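And the per-channel case in runnable NumPy (same assumed sizes as before):

import numpy as np

B, H, W, C = 8, 4, 4, 12
t = np.random.randn(B, H, W, C)

mean = t.mean(axis=(0, 1, 2))              # shape [C]
stddev = t.std(axis=(0, 1, 2))             # shape [C]
out = (t - mean) / (stddev + 1e-5)         # one mean/stddev per channel

print(mean.shape)                          # (12,): C statistics, each over B*H*W values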

Problems

Small Batch Size  If the batch size is 1, the variance is 0, which doesn't work. More generally, for small batch sizes the batch mean and variance are too noisy an estimate of the true statistics.
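For instance, with a tiny made-up batch of one sample, every dimension's variance is zero and the normalized output collapses to zero:

import numpy as np

x = np.array([[3.2, -1.5, 0.7]])   # a "batch" of a single sample
print(x.var(axis=0))               # [0. 0. 0.]
print((x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5))   # [[0. 0. 0.]]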

Recurrent Neural Network  In an RNN, the mean and variance at every time step differ from those of the previous step, because new information has accumulated, so the same batch norm statistics can't be reused. We would have to fit a separate batch norm layer for each time step, which is too expensive.