
Tensorflow — Dealing with imbalanced data

When you face imbalanced data in a classification task, there are several things to consider:

  1. Collect more data, trying to balance the classes.
  2. If collecting more data is not an option, you can manually select a training set with balanced classes.
  3. If manually balancing the training set is also impossible, you can penalize the majority class, i.e. down-weight its contribution to the loss. In a CNN, that means re-weighting your loss function!

In the following, I will show how to do loss re-weighting in Tensorflow. In a segmentation task, we essentially classify each pixel of the whole image into a certain label. If the labels are imbalanced and we use a fully-convolutional network (FCN, which takes the whole image as input and generates a full-resolution labeled image), we cannot manually select balanced classes as we would in patch-wise approaches. Therefore, we use the median-frequency re-weighting introduced by D. Eigen and R. Fergus (Predicting depth, surface normals and semantic labels with a common multi-scale convolutional architecture):

weight(c) = median_freq / freq(c)

where freq(c) is the number of pixels of class c divided by the total number of pixels in images where c is present, and median_freq is the median of all class frequencies. Therefore the dominant labels will be assigned the lowest weights, which balances the training process.
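
As a concrete illustration, here is a minimal sketch (not from the original post) of how these per-class weights could be precomputed with NumPy from the training label maps; the function name median_freq_weights and the argument label_maps are hypothetical:

import numpy as np

def median_freq_weights(label_maps, num_classes):
    # label_maps: iterable of 2-D integer arrays (one per training image).
    # Returns an array of shape (num_classes,) with weight(c) = median_freq / freq(c).
    pixel_counts = np.zeros(num_classes)   # pixels of class c over all images
    image_totals = np.zeros(num_classes)   # total pixels of images where c is present
    for lm in label_maps:
        present = np.unique(lm)
        pixel_counts += np.bincount(lm.ravel(), minlength=num_classes)
        image_totals[present] += lm.size
    freq = pixel_counts / np.maximum(image_totals, 1)   # freq(c); guard against absent classes
    median_freq = np.median(freq[freq > 0])
    return np.where(freq > 0, median_freq / np.maximum(freq, 1e-12), 0.0)

The resulting vector can then be used as the coefficients in the loss below.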

Knowing the above, we can see that we basically only need to multiply our cross-entropy loss by a coefficient chosen according to the label class.

softmax = tf.nn.softmax(logits)
# coefficients is a length-num_classes vector of median-frequency class weights
cross_entropy = -tf.reduce_sum(tf.multiply(labels * tf.log(softmax + epsilon), coefficients), axis=[1])
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

Note that because a certain class may not appear in a batch at all, we add a small value epsilon to the softmax to stabilize the cross-entropy. The full code:

import tensorflow as tf

def weighted_loss(logits, labels, num_classes, coefficients):
    # coefficients: per-class weights, e.g. computed with median-frequency balancing
    with tf.name_scope('loss_1'):
        logits = tf.reshape(logits, (-1, num_classes))
        epsilon = tf.constant(value=1e-10)
        logits = logits + epsilon
        # construct one-hot label array
        label_flat = tf.reshape(labels, (-1, 1))
        labels = tf.reshape(tf.one_hot(label_flat, depth=num_classes), (-1, num_classes))
        softmax = tf.nn.softmax(logits)
        cross_entropy = -tf.reduce_sum(tf.multiply(labels * tf.log(softmax + epsilon), coefficients), axis=[1])
        cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
        tf.add_to_collection('losses', cross_entropy_mean)
        loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return loss
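
For completeness, here is a hypothetical usage sketch (not part of the original post) showing how the pieces could be wired together with the weighted_loss and median_freq_weights functions above; my_fcn_model, train_label_maps and the placeholder shapes are assumptions for illustration only:

num_classes = 12
class_weights = median_freq_weights(train_label_maps, num_classes)  # numpy array, see above
coefficients = tf.constant(class_weights, dtype=tf.float32)

images = tf.placeholder(tf.float32, [None, 360, 480, 3])   # input image batch
labels = tf.placeholder(tf.int32, [None, 360, 480])        # per-pixel class ids
logits = my_fcn_model(images, num_classes)                  # any FCN producing (N, H, W, num_classes) scores
loss = weighted_loss(logits, labels, num_classes, coefficients)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)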