Long Short-Term Memory (LSTM)
LSTM is developed from the Recurrent Neural Network (RNN).
A naive RNN keeps a hidden state h (i.e. memory), as shown below.
x is the input for the current time step t, and h is the hidden state from the previous time step t-1. So x (current input) and h (memory) together determine the current hidden state (the updated memory) h'.
h' is simply calculated by combining Wh·h and Wx·x, then applying a sigmoid activation.
h' is also mapped linearly to y as the prediction for x. The output weight applied to h' maps it to an output layer, e.g. with 2 neurons (yes / no), and softmax is further applied to convert the output to probabilities.
As x1, x2, x3 ... form a time series, the whole process, unfolded over time, looks like the below. The memory h0 (empty), h1, h2 ... is passed on to the next time step, while a prediction y1, y2, y3 ... is output at each step.
The function f, which is composed of the weights Wh, Wi and Wo, will be trained gradually. i is for information / input (the Wx above); o is for output.
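As a concrete illustration, one step of this naive RNN could be sketched as below (a minimal numpy sketch, assuming a single sample and no bias terms; Wh, Wi, Wo are the hidden, input and output weight matrices named above):

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def rnn_step(x_t, h_prev, Wh, Wi, Wo):
    # h' = sigmoid(Wh . h + Wi . x): current input and memory jointly decide the new memory
    h_t = sigmoid(Wh @ h_prev + Wi @ x_t)
    # y = softmax(Wo . h'): map the new hidden state to an output (e.g. yes / no)
    y_t = softmax(Wo @ h_t)
    return y_t, h_t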
LSTM was originally designed to mitigate the vanishing / exploding gradient problem. Compared to the RNN, LSTM works better on longer series of data by introducing a long-term memory / cell state.
Compared to the naive RNN, which has only the hidden state (short-term memory), LSTM introduces the cell state (Ct), the long-term memory.
The hidden state h usually changes fast, but the cell state changes slowly. The key concept of the cell state / long-term memory is that it can choose to FORGET some information and also MEMORIZE some extra information.
Before diving into the calculation of Ct, the following are some basic data transformations from Xt.
Firstly we concatenate X(t) and h(t-1), which are the original input at time t and the hidden state at t-1.
Then we multiply the concatenation by different weight matrices W, Wi, Wf and Wo. The multiplication results serve different purposes.
The Z serves as a transformed data input, same as the Z value in a feedforward network. tanh is used here because it needs to map data to (-1, 1), so the input data Z has both negative and positive values (sigmoid would produce positive values only).
The Zi (i for information) denotes the extra information it wants to add to the long-term memory. Sigmoid is used here to simulate a logical gate (0-1), but it is continuous, which makes the gradient easier to compute.
The Zf (f for forget) denotes the information it wants to forget from the long-term memory.
The Zo (o for output) denotes the information it uses for predicting the current y, i.e. output for the current time step.
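A minimal sketch of these four transformations (assuming numpy, a single sample, and no bias terms; each weight matrix has shape (hidden_size, input_size + hidden_size)):

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gates(x_t, h_prev, W, Wi, Wf, Wo):
    concat = np.concatenate([x_t, h_prev])  # [X(t), h(t-1)]
    Z = np.tanh(W @ concat)    # transformed input, in (-1, 1)
    Zi = sigmoid(Wi @ concat)  # information gate, in (0, 1)
    Zf = sigmoid(Wf @ concat)  # forget gate, in (0, 1)
    Zo = sigmoid(Wo @ concat)  # output gate, in (0, 1)
    return Z, Zi, Zf, Zo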
Here comes the detailed implementation of LSTM.
Starting from the bottom, we concatenate X(t) and the previous hidden state h(t-1) and apply the mentioned data transformations to get Z, Zf, Zi and Zo.
The ⊙ symbol is the Hadamard product, i.e. element-wise multiplication.
The ⊕ symbol is element-wise addition.
Firstly, Zf is element-wise multiplied with the previous cell state / long-term memory c(t-1) to forget some information, i.e. the zero or near-zero elements in Zf suppress the corresponding elements in c(t-1).
Then Z and Zi are multiplied to pick the information to add to the long-term memory. The one or near-one elements in Zi pass on the corresponding elements of Z. The result is then added to c(t-1) after the Zf forgetting step. So essentially c(t-1) forgets the information picked out by Zf and memorizes the information picked out by Zi.
This yields the cell state c(t), i.e. current long-term memory.
The cell state c(t) is then transformed by tanh into the range (-1, 1) to serve as the input for calculating the hidden state h(t). Note the input value needs to allow both negative and positive values, so tanh is used in these scenarios.
tanh(c(t)) is multiplied with Zo, which picks the information to output, and that yields h(t).
The prediction y(t) is simply another linear transformation of h(t) by W', activated by sigmoid / any activation function.
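Putting the pieces together, one LSTM step can be sketched as below (continuing the numpy gates() sketch above; W_y is a hypothetical output weight standing in for W'):

def lstm_step(x_t, h_prev, c_prev, W, Wi, Wf, Wo, W_y):
    Z, Zi, Zf, Zo = gates(x_t, h_prev, W, Wi, Wf, Wo)
    c_t = Zf * c_prev + Zi * Z   # forget part of c(t-1), then memorize part of Z
    h_t = Zo * np.tanh(c_t)      # output gate picks what goes into the hidden state
    y_t = sigmoid(W_y @ h_t)     # prediction for this time step
    return y_t, h_t, c_t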
Here are the more formal formulas from the PyTorch documentation. The gt there is the Z calculated above.
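Restated here for reference (notation as in the torch.nn.LSTM documentation; gt corresponds to the Z above, it to Zi, ft to Zf, ot to Zo):

\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}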
A summary:
the long-term memory ct is updated through 'forget' and 'remember'. The 'remember' path needs to take in input values, so a tanh() is used in that flow.
the short-term memory ht is derived on top of ct. It maps ct back into (-1, 1) with tanh() and determines the output using the output gate (from W^o).
the prediction yt is simply calculated from ht.
Note: compared with LSTM, a simplified (less computationally expensive) alternative is the GRU (Gated Recurrent Unit).
Note for LSTM batch training
A batch of sequences is fed into the LSTM in one go to finish one round of training.
A very long sequence x1, x2, x3 ..... xn is divided into batches.
Assume the batch size is k, i.e. k sequences per batch, and the sequence length is m, i.e. look back m steps. Then the first batch would include the following sequences.
sequence 1 = x1, x2, .. xm
sequence 2 = x2, x3, ... x(m+1)
...
sequence k = xk, x(k+1), ... x(k + m -1)
The second batch would be:
sequence k+1 = x(k+1) ...
sequence k+2 = x(k+2) ...
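A hedged sketch of how such sliding-window batches could be built (the function and variable names here are illustrative, not from any library):

def make_batches(x, m, k):
    # every length-m window: [x1..xm], [x2..x(m+1)], ...
    windows = [x[i:i + m] for i in range(len(x) - m + 1)]
    # group k consecutive windows into one batch
    return [windows[i:i + k] for i in range(0, len(windows), k)]

# e.g. make_batches(list(range(1, 11)), m=3, k=2)[0] == [[1, 2, 3], [2, 3, 4]]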
Each hidden state keeps a number of features/neurons, which is the hidden size. Usually you need more features in the hidden state than the number of features in the input xt in order to memorize things, similar to mapping input neurons to many more neurons in the hidden layers of a feedforward network.
Each sequence here has its own hidden state, because the LSTM processes each sequence and accumulates memory in its hidden state separately.
So k sequences (batch size k) require k hidden states. The overall number of neurons/features required is k * hidden_size.
After a batch finishes, the hidden state and cell state will have been updated according to the processed sequences.
output, (hn, cn) = lstm(input)
where output has all the hidden states h1, h2, h3... and (hn, cn) are the last hidden state and last cell state.
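For example, with PyTorch's nn.LSTM (sequence-first layout, the default when batch_first=False), the shapes look roughly like this:

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=16)   # 1 layer, unidirectional
x = torch.randn(10, 4, 3)                      # (seq_len=10, batch=4, input_size=3)
output, (hn, cn) = lstm(x)
print(output.shape)  # torch.Size([10, 4, 16]) -> hidden state at every time step
print(hn.shape)      # torch.Size([1, 4, 16])  -> last hidden state per sequence
print(cn.shape)      # torch.Size([1, 4, 16])  -> last cell state per sequence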
When kicking off the next batch, in some implementations the LSTM could use the last hidden state from the previous batch as the starting hidden state h0, c0.
The training of the next batch would then be based on non-zero hidden states, which is wrong for independent batches. In many posts you will see a line of code self.hidden = init_hidden() to reset the hidden state before feeding in the next batch.
However, in PyTorch 1.3 it looks like the LSTM module defaults to zero hidden states if none are specified, so init_hidden probably isn't needed at every batch.
Check source code in torch.nn.modules.rnn for class LSTM
...
def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
    # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
    if hx is None:
        num_directions = 2 if self.bidirectional else 1
        zeros = torch.zeros(self.num_layers * num_directions,
                            max_batch_size, self.hidden_size,
                            dtype=input.dtype, device=input.device)
        hx = (zeros, zeros)
...
The hx (i.e. the hidden states and cell states) is reset to zeros if not specified, so it shouldn't be necessary to init it at every batch.
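A small sketch of the two ways to call it (reusing the lstm and shapes from the example above; batch1 and batch2 are placeholder input tensors):

batch1 = torch.randn(10, 4, 3)
batch2 = torch.randn(10, 4, 3)

# stateless: hx is not given, so it is reset to zeros internally (as in the source above)
output, (hn, cn) = lstm(batch1)

# stateful: explicitly carry the last states into the next batch,
# detaching so gradients don't flow back across batch boundaries
output, (hn, cn) = lstm(batch2, (hn.detach(), cn.detach()))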
There can be multiple epochs of training, which simply re-run through all the batches again to further refine the weights.
This can be done by repeating the batches in the training loop as well.
Shuffle training data
In general, when you shuffle the training data (a set of sequences), you shuffle the order in which sequences are fed to the RNN; you don't shuffle the ordering within individual sequences.
Here it assumes your network is stateless:
The network's memory only persists for the duration of a sequence. Training on sequence B before sequence A doesn't matter because the network's memory state does not persist across sequences.
On the other hand, you don't want to shuffle your sequences if it's stateful:
The network's memory persists across sequences. Sequence A will have an impact on sequence B so A should be fed to the network before sequence B. In this way the network evaluates sequence B with memory of what was in sequence A. A shuffle will destroy that ordering.
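For the stateless case, a typical way to shuffle at the sequence level is via a DataLoader (a sketch; train_dataset is a hypothetical dataset assumed to yield one (sequence, target) pair per item):

from torch.utils.data import DataLoader

# shuffle=True reorders which sequences go into each batch;
# the steps inside each sequence keep their original order
loader = DataLoader(train_dataset, batch_size=32, shuffle=True)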
Stacked LSTM
Usually one layer of LSTM units works OK, but sometimes you may want to stack multiple layers of LSTM units.
In that case, the input xt of a layer (layer > 1) is the ht from the previous layer. If dropout is introduced, xt is the ht multiplied by a delta which is 0 with probability dropout_rate.
The ht of the last layer is then transformed to yt.
In this way, we can construct a deep LSTM with many stacked layers. The benefit is the same as for deep CNNs: capturing features at different granularities and modeling more complex relationships.
Every extra layer needs to keep its own set of cell states and hidden states, so memory usage increases as layers increase.
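In PyTorch, stacking is a matter of setting num_layers (and optionally dropout between layers), e.g. (a sketch, reusing torch / nn from the earlier example):

stacked = nn.LSTM(input_size=3, hidden_size=16, num_layers=3, dropout=0.2)
output, (hn, cn) = stacked(torch.randn(10, 4, 3))
print(hn.shape)  # torch.Size([3, 4, 16]) -> one last hidden state per layer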
Bidirectional LSTM
When it comes to NLP, the prediction of yt also depends on the input xt+1, etc. because the context is important in those scenarios.
e.g. Tom had a __?__, the ? could be meal / lecture / etc.
Given a context, Tom had a ? in XYZ cafe, the ? is probably meal.
The bidirectional LSTM is basically two LSTMs: one runs over the original sequence and the other runs over the reversed sequence.
The forward pass provides cell and hidden states from earlier inputs; the backward pass provides cell and hidden states from later inputs.
Then the hidden states from both the forward and backward passes are transformed into the prediction y. As information from both earlier and later steps is used, the prediction is better for scenarios like NLP.
Obviously, stacking bidirectional LSTMs gives a multi-layer bidirectional LSTM.
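A sketch with PyTorch (reusing torch / nn from the earlier example); note the forward and backward hidden states are concatenated in the output:

bilstm = nn.LSTM(input_size=3, hidden_size=16, num_layers=2, bidirectional=True)
output, (hn, cn) = bilstm(torch.randn(10, 4, 3))
print(output.shape)  # torch.Size([10, 4, 32]) -> forward + backward h_t concatenated
print(hn.shape)      # torch.Size([4, 4, 16])  -> num_layers * num_directions = 4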