pack variable-length sequences

pack_padded_sequence

Batching sequences with variable lengths when training an RNN is usually cumbersome.

A batch tensor requires all sequences to have the same length, so people usually pad the shorter sequences (e.g. with 0) up to the maximum length.

e.g. the two sequences [7, 8, 9] and [1, 2, 3, 4, 5]

would be padded into the batch [[7, 8, 9, 0, 0], [1, 2, 3, 4, 5]]
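torch.nn.utils.rnn.pad_sequence does this padding. A minimal sketch; batch_first=True is chosen here so that each row is one sequence, as in the example above (without it, the result is time-major with shape (max_len, batch)).

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seq1 = torch.tensor([7, 8, 9])
seq2 = torch.tensor([1, 2, 3, 4, 5])

# batch_first=True -> shape (batch, max_len); shorter sequences are filled with padding_value
padded = pad_sequence([seq1, seq2], batch_first=True, padding_value=0)
print(padded)
# tensor([[7, 8, 9, 0, 0],
#         [1, 2, 3, 4, 5]])
```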

Now the batch can be sent into an RNN. However, the padded zeros waste computation, and they can also change the result: the RNN keeps updating its hidden state on the padding, so the final state of a shorter sequence no longer reflects its last real element.
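To make the second point concrete, here is a minimal sketch assuming a single-layer nn.RNN with random weights (hidden_size=4 and feeding the token values as 1-dim float features are arbitrary choices for illustration). It compares the hidden state after the real last element of [7, 8, 9] with the state after the same sequence plus two padding steps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=1, hidden_size=4)  # expects input of shape (seq_len, batch, input_size)

seq = torch.tensor([7.0, 8.0, 9.0]).view(3, 1, 1)  # [7, 8, 9] as a batch of one
padded = torch.cat([seq, torch.zeros(2, 1, 1)])    # same sequence padded to length 5

with torch.no_grad():
    _, h_true = rnn(seq)        # hidden state right after the element 9
    _, h_padded = rnn(padded)   # hidden state after two further zero steps

# The zero steps still apply the recurrence h = tanh(b_ih + W_hh h + b_hh),
# so the padded run ends in a different state.
print(torch.allclose(h_true, h_padded))
```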

To process only the non-padded part of each sequence, torch.nn.utils.rnn provides pack_padded_sequence, which packs the padded batch.

It reorganizes the sequences of different lengths into an internal format of two lists (the data and batch_sizes tensors of a PackedSequence).

The first list contains all elements from all the sequences, interleaved by time step.

  e.g. the sequences above are organized into [1, 7, 2, 8, 3, 9, 4, 5]

       at each time step it takes one element from every sequence that is still active, longest sequence first (the batch is sorted by decreasing length).

The second list is the actual batch size at each time step

  e.g. for the sequences above, the batch sizes are [2, 2, 2, 1, 1]

  that means at time steps 1, 2 and 3 there are 2 elements, while at time steps 4 and 5 there is only 1 element in the batch.

  so the RNN knows the batch shrinks: at each time step, only the elements of the sequences that are still active are passed forward.
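Both lists can be rebuilt in a few lines of plain Python. This is a sketch of the idea only, not PyTorch's implementation, and pack_lists is a hypothetical helper. Within each time step the sequences are ordered longest first, which is why the result starts with 1, 7 (the same order as the PackedSequence output further down).

```python
def pack_lists(seqs):
    """Hypothetical helper: build the 'data' and 'batch_sizes' lists of a packed batch."""
    # Longest sequence first, mirroring the decreasing-length order packing relies on.
    ordered = sorted(seqs, key=len, reverse=True)
    data, batch_sizes = [], []
    for t in range(len(ordered[0])):
        # Sequences still active at step t; since lengths are sorted, they form a prefix.
        active = [s for s in ordered if len(s) > t]
        data.extend(s[t] for s in active)
        batch_sizes.append(len(active))
    return data, batch_sizes


data, batch_sizes = pack_lists([[7, 8, 9], [1, 2, 3, 4, 5]])
print(data)         # [1, 7, 2, 8, 3, 9, 4, 5]
print(batch_sizes)  # [2, 2, 2, 1, 1]
```

batch_sizes[t] is simply the number of sequences longer than t, so it depends only on the lengths, never on the values.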

  

With this representation, the RNN modules only compute on the real elements and skip the padded positions entirely.
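A conceptual sketch of that shrinking-batch loop, assuming a single-layer tanh RNN: run_packed is a hypothetical helper that steps an nn.RNNCell (given the same weights as an nn.RNN) through packed.data, feeding only batch_sizes[t] rows at step t. It is not how PyTorch's kernels are written, but its final hidden states should match what nn.RNN returns for the packed input.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

torch.manual_seed(0)
rnn = nn.RNN(input_size=1, hidden_size=4)

# An RNNCell with the same weights, so both compute
# h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh).
cell = nn.RNNCell(input_size=1, hidden_size=4)
cell.load_state_dict({
    "weight_ih": rnn.weight_ih_l0,
    "weight_hh": rnn.weight_hh_l0,
    "bias_ih": rnn.bias_ih_l0,
    "bias_hh": rnn.bias_hh_l0,
})


def run_packed(cell, packed):
    """Hypothetical helper: step an RNNCell through a PackedSequence.

    Returns each sequence's final hidden state, in packed (sorted) order.
    """
    h = torch.zeros(int(packed.batch_sizes[0]), cell.hidden_size)
    last_h = h.clone()
    offset = 0
    for bs in packed.batch_sizes.tolist():
        x_t = packed.data[offset:offset + bs]  # the bs elements of this time step
        h = cell(x_t, h[:bs])                  # finished sequences drop out of the batch
        last_h[:bs] = h                        # finished sequences keep their last state
        offset += bs
    return last_h


seqs = [torch.tensor([7.0, 8.0, 9.0]), torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])]
padded = pad_sequence(seqs).unsqueeze(-1)  # (max_len, batch, input_size) = (5, 2, 1)
packed = pack_padded_sequence(padded, torch.tensor([3, 5]), enforce_sorted=False)

with torch.no_grad():
    _, h_n = rnn(packed)  # (num_layers, batch, hidden) = (1, 2, 4), original batch order
    manual = run_packed(cell, packed)[packed.unsorted_indices]  # back to original order

print(torch.allclose(h_n[0], manual, atol=1e-5))
```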

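Example script:

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

seq1 = torch.tensor([7, 8, 9])
seq2 = torch.tensor([1, 2, 3, 4, 5])

# pad to the max length; batch_first defaults to False, so the shape is (max_len, batch) = (5, 2)
padded = pad_sequence([seq1, seq2])
print(padded)

# real length of each sequence, in the same order as the batch
lengths = torch.tensor([3, 5])

# the lengths are not in decreasing order, so let the function sort them
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)
print(packed)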

By default (enforce_sorted=True) the sequences must be sorted by length in decreasing order; otherwise, set enforce_sorted=False and the function sorts them internally.
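If you want to keep the default enforce_sorted=True, a minimal sketch is to sort the batch yourself before packing, here with Tensor.sort on the lengths and indexing the batch dimension of the time-major padded tensor.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

seq1 = torch.tensor([7, 8, 9])
seq2 = torch.tensor([1, 2, 3, 4, 5])
padded = pad_sequence([seq1, seq2])  # (max_len, batch) = (5, 2)
lengths = torch.tensor([3, 5])

# Sort by decreasing length: sorted_lengths = [5, 3], order = [1, 0]
sorted_lengths, order = lengths.sort(descending=True)
padded_sorted = padded[:, order]  # reorder the batch (column) dimension

# The lengths are now decreasing, so the default enforce_sorted=True is satisfied.
packed = pack_padded_sequence(padded_sorted, sorted_lengths)
print(packed.data)            # tensor([1, 7, 2, 8, 3, 9, 4, 5])
print(packed.batch_sizes)     # tensor([2, 2, 2, 1, 1])
print(packed.sorted_indices)  # None: no permutation was recorded
```

The trade-off: when you sort yourself, restoring the original batch order afterwards is up to you, whereas enforce_sorted=False records the permutation in sorted_indices and unsorted_indices. The PyTorch documentation notes that enforce_sorted=True is only necessary for ONNX export.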

The padded sequences (the input to the pack function), with shape (max_len, batch) = (5, 2):

tensor([[7, 1],
        [8, 2],
        [9, 3],
        [0, 4],
        [0, 5]])


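The packed sequence (the longer seq2 comes first at each time step because the batch was sorted by decreasing length; sorted_indices records that permutation):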

PackedSequence(data=tensor([1, 7, 2, 8, 3, 9, 4, 5]), batch_sizes=tensor([2, 2, 2, 1, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0]))
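The packing can be undone with pad_packed_sequence, which re-pads the data and uses unsorted_indices to put the sequences back in their original batch order. A minimal round-trip sketch reusing the same inputs:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

seq1 = torch.tensor([7, 8, 9])
seq2 = torch.tensor([1, 2, 3, 4, 5])
padded = pad_sequence([seq1, seq2])
lengths = torch.tensor([3, 5])
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)

# Inverse of packing: returns the padded tensor and the length of each sequence,
# both in the original (unsorted) batch order.
unpacked, out_lengths = pad_packed_sequence(packed)
print(unpacked.equal(padded))  # True
print(out_lengths)             # tensor([3, 5])
```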

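The packed sequence can be fed directly into nn.RNN, nn.LSTM or nn.GRU, provided its elements are float feature vectors (e.g. token ids passed through an embedding layer first).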
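A minimal end-to-end sketch, assuming the numbers are token ids: nn.Embedding turns them into float vectors, the padded embeddings are packed, passed through nn.LSTM, and unpacked again. The sizes (10 ids, 8-dim embeddings, 16-dim hidden state) are arbitrary illustrative choices. The last check compares the final hidden state of [7, 8, 9] from the packed batch with running that sequence alone, without padding.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

torch.manual_seed(0)
embed = nn.Embedding(num_embeddings=10, embedding_dim=8)  # illustrative sizes
lstm = nn.LSTM(input_size=8, hidden_size=16)

seq1 = torch.tensor([7, 8, 9])
seq2 = torch.tensor([1, 2, 3, 4, 5])
padded = pad_sequence([seq1, seq2])  # (5, 2) token ids
lengths = torch.tensor([3, 5])

with torch.no_grad():
    x = embed(padded)                                   # (5, 2, 8) float features
    packed = pack_padded_sequence(x, lengths, enforce_sorted=False)
    packed_out, (h_n, c_n) = lstm(packed)               # nn.LSTM accepts a PackedSequence
    out, out_lengths = pad_packed_sequence(packed_out)  # (5, 2, 16), zeros past each length

    # seq1 run on its own, shape (3, 1, 8), with no padding at all
    _, (h_alone, _) = lstm(embed(seq1).unsqueeze(1))

print(out.shape)  # torch.Size([5, 2, 16])
# h_n is in the original batch order, so column 0 is seq1
print(torch.allclose(h_n[:, 0], h_alone[:, 0], atol=1e-5))
```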