Recurrent Neural Networks Beyond Back-Propagation
Artificial intelligence, and deep learning in particular, has grown tremendously in recent years, impacting our lives more closely than ever. This revolution was made possible by many ground-breaking algorithms that made the realisation of various neural networks possible. This section covers the background behind neural networks, recurrent neural networks, and backpropagation through time (BPTT).
1.1 Backpropagation
Backpropagation [5, 14, 12] is one such algorithm, and it has become the de-facto method for training neural networks. The algorithm has matured over the years and typically outperforms alternative training algorithms. However, the rising demand for computation with a lower energy footprint pushes us towards networks that mimic the human brain [9]. Backpropagation, on the other hand, is biologically implausible and cannot be used to model the human brain for the following reasons [11, 9, 16]:
• It requires precise knowledge of the non-linearity in the corresponding forward path, but in biological learning the feedback path usually runs through a different set of neurons, which will not carry the same values as the forward path.
• It is mostly linear, with non-linearity introduced only at activation functions, whereas neurons interleave linear and non-linear operations quite often.
• The loss is propagated through the same path as the forward pass, using the same weights, but the feedback paths in the brain have been found not to work in this fashion.
• It uses real values to convey information between layers, but neurons communicate through spike bursts.
• It requires switching between forward and backward propagation at fixed intervals.
1.2 Recurrent Neural Networks
A Recurrent Neural Network (RNN) [1] is a type of neural network that can process sequential data. Modern voice assistants such as Siri and Alexa, language translation models, and speech-to-text models all use RNNs as their backbone for detecting and producing sequential data. An RNN does this by retaining the state of the previous time instant (e.g. $t = 0$) and passing it as additional input alongside the inputs of the current instant ($t = 1$). It retains this internal state by feeding the output of the hidden state back as input for the next time step. Figure 1(a) depicts the structure of a basic RNN cell. An RNN takes a sequence of inputs at consecutive time steps $\{x_0, x_1, \ldots, x_t\}$ and either produces an output at each time step $\{y_0, y_1, \ldots, y_t\}$ or produces one output $y$ for the whole sequence, depending on the type of problem (sequence-to-sequence learning or classification). Figure 1(b) depicts an RNN cell unrolled in time with a sequence length of 3. The hidden state is calculated according to the following equation:
$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b)$$
where $W_{xh}$ is the input weight matrix, $x_t$ is the input at time $t$, $W_{hh}$ is the hidden weight matrix, $h_{t-1}$ is the hidden state of the previous time instant, and $b$ is the bias. Since the hidden state is multiplied repeatedly along the sequence length, the activation function is usually chosen to be sigmoid or tanh. Similarly, the output $y$ is calculated using the following equation:
$$y = \mathrm{softmax}(W_{hy} \, h_{t_{max}} + c)$$
where $W_{hy}$ is the output weight matrix and $c$ is the output bias. If the RNN is used for classification, softmax is chosen as the output activation function. The softmax function makes the output values sum to 1, with the largest value indicating the predicted class.
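To make the forward computation concrete, here is a minimal sketch of the two equations above in Python, assuming illustrative sizes (input dimension 3, hidden dimension 4, two output classes) and randomly initialised toy weights:

```python
import numpy as np

rng = np.random.default_rng(0)
W_xh = rng.normal(size=(4, 3))   # input-to-hidden weight matrix
W_hh = rng.normal(size=(4, 4))   # hidden-to-hidden weight matrix
b = np.zeros(4)                  # hidden bias

def rnn_step(x_t, h_prev):
    """One time step: h_t = tanh(W_xh x_t + W_hh h_{t-1} + b)."""
    return np.tanh(W_xh @ x_t + W_hh @ h_prev + b)

def softmax(z):
    e = np.exp(z - z.max())      # shift for numerical stability
    return e / e.sum()

W_hy = rng.normal(size=(2, 4))   # hidden-to-output weight matrix
c = np.zeros(2)                  # output bias

h = np.zeros(4)                      # initial hidden state
for x_t in rng.normal(size=(3, 3)):  # a toy sequence of length 3
    h = rnn_step(x_t, h)             # state carried across time steps
y = softmax(W_hy @ h + c)            # one output for the whole sequence
```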
1.3 Back Propagation Through Time (BPTT)
Training an RNN involves updating all three weight matrices, $W_{xh}$, $W_{hh}$ and $W_{hy}$. The network is trained with backpropagation through a method called Back-Propagation Through Time. Since the hidden state from the previous instant is given as input to the next instant, the error gradient has to be backpropagated to previous time instants as well. We start from the final time instant, apply backpropagation to find the gradient matrices at instant $t$, propagate them to the previous instant $t-1$, and continue until $t = 0$. After completing the propagation, the gradient matrices from the various time instants are accumulated and the weight matrices are updated. This is depicted in Figure 1(c).
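The sketch below illustrates this accumulation for the toy tanh RNN above (reusing its weights), assuming the loss is applied only at the final time step and that the forward states were stored during the forward pass (`hs[0]` is the initial state, `hs[t+1]` the state after consuming `xs[t]`):

```python
def bptt(xs, hs, dL_dh_final):
    """Accumulate gradients backwards through time (minimal sketch)."""
    dW_xh, dW_hh, db = np.zeros_like(W_xh), np.zeros_like(W_hh), np.zeros_like(b)
    dh = dL_dh_final                      # gradient at the final instant
    for t in reversed(range(len(xs))):
        dz = dh * (1.0 - hs[t + 1] ** 2)  # back through the tanh
        dW_xh += np.outer(dz, xs[t])      # accumulate per-instant gradients
        dW_hh += np.outer(dz, hs[t])
        db += dz
        dh = W_hh.T @ dz                  # propagate to instant t-1
    return dW_xh, dW_hh, db
```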
1.4 Issues with Backpropagation
As the gradients are propagated along the sequence length, the repeated multiplication may lead to:
• Very small gradients if the values at each time instant are < 1, which is known as vanishing gradients
• Very large gradients if the values at each time instant are > 1, which is known as exploding gradients
This imposes a restriction on the sequence length the model can process. It leaves RNNs unable to resolve long-term dependencies and restricts them to short-term applications.
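A toy computation makes the effect visible: a scalar factor stands in for the chain of Jacobian multiplications along the sequence.

```python
# Repeated multiplication over a sequence of length 100.
for factor in (0.9, 1.1):
    g = 1.0
    for _ in range(100):
        g *= factor
    print(factor, g)   # 0.9 -> ~2.7e-5 (vanishing), 1.1 -> ~1.4e4 (exploding)
```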
1.5 Target propagation
To mitigate the biological implausibility and the vanishing and exploding gradient problems in RNNs, we propose Target Propagation as an alternative training algorithm. Target propagation [2, 6] propagates a target value from the output layer to each of the hidden layers rather than a loss gradient. If each layer matches its target, the overall output matches the expected output. The loss between a layer's output and its target is used to locally update its weights, so the error-correction computations are performed within a single layer. The target for the $n$th layer is generated by taking the inverse of the target of the $(n+1)$th layer using an inverse function $G$. The weight matrix $V_{hh}$ is also trained in parallel to correctly approximate the inverse of the corresponding forward path. A later section delves further into DTP and its application in an RNN.
Figure 1: a) A simple RNN cell with one input, one hidden unit and one output. b) An RNN cell with one layer is unrolled for 3 time steps. c) Back-propagation through time unrolled for 3 time steps.
Schmidt [15] gives an overview of various RNN architectures and their use cases, and discusses methods by which the exploding gradient problem can be mitigated. One of them is gradient clipping, where the gradients are clipped whenever they exceed a certain range. This is only a partial solution, since valuable information is lost in the clipping.
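A common variant is norm-based clipping, sketched below (the threshold of 1.0 is illustrative):

```python
import numpy as np

def clip_by_norm(grad, max_norm=1.0):
    """Rescale the gradient whenever its L2 norm exceeds max_norm."""
    norm = np.linalg.norm(grad)
    if norm > max_norm:
        grad = grad * (max_norm / norm)   # direction kept, magnitude capped
    return grad
```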
Long Short-Term Memory (LSTM) is a type of RNN architecture that addresses long input lengths by using gated cells. The input to a cell is the current time-step information combined with the past data through a gating cell that controls how much memory (past data) is let into the cell for processing. By gating how much previous information is needed, one can avoid the extremely long chains of multiplications that cause exploding or vanishing gradients.
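For reference, a sketch of one standard LSTM step is given below; `params` holds one (W, U, b) triple per gate, and the forget gate `f` implements the gating of past memory described above:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step (standard formulation, sketch)."""
    (Wi, Ui, bi), (Wf, Uf, bf), (Wo, Uo, bo), (Wg, Ug, bg) = params
    i = sigmoid(Wi @ x_t + Ui @ h_prev + bi)   # input gate
    f = sigmoid(Wf @ x_t + Uf @ h_prev + bf)   # forget gate: memory to keep
    o = sigmoid(Wo @ x_t + Uo @ h_prev + bo)   # output gate
    g = np.tanh(Wg @ x_t + Ug @ h_prev + bg)   # candidate memory
    c_t = f * c_prev + i * g                   # gated memory update
    h_t = o * np.tanh(c_t)                     # exposed hidden state
    return h_t, c_t
```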
Target propagation propagates a target from a layer to its previous layer using an inverse function. Bengio [2] proposes using an auto-encoder to provide an approximate inverse function, since finding the actual inverse would be computationally intensive. Difference target propagation [7] aims to reduce the imperfections introduced by auto-encoders by propagating the difference between the inverse of the actual value and the inverse of the target value of the $(n+1)$th layer to the $n$th layer.
The work by Manchev and Spratling [10] alleviates the vanishing gradient problem by using target propagation to train RNNs. Roulet and Harchaoui [13] propose an algorithm similar to DTP with a slight modification: they define regularized inverses through a variational formulation and obtain approximate inverses via these regularized inverses. To approximate the inverse of the forward computation, a parameterized layer must be learned, which involves numerically solving an optimization problem for each layer. These optimization problems come with a computational cost that can be better controlled by using the regularized inversions presented earlier.
Another idea closely related to the target propagation approach is synthetic gradients [4]. The concept is similar to target propagation in the sense that the synaptic weight updates have no strict dependency on a backpropagated global error signal. The different layers of the network are updated individually, but instead of providing local targets and updating the network parameters to bring the activation values close to the targets, the network uses local models to approximate the true error gradients directly. Under this scheme, the model $M_t$ for network layer $t$ is trained by minimising the error between its predicted gradient $\hat{\partial}_t$ and the gradient estimated by the synthetic model of the next layer, $\hat{\partial}_{t+1}$. This process is repeated for all upstream layers until the final layer of the network, where the target gradient can be computed directly from the global error $E$. This training method allows individual segments of the network to be updated asynchronously, resulting in an architecture known as Decoupled Neural Interfaces (DNIs).
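As a rough illustration (not the architecture of [4], which uses small neural networks as the local models), a synthetic-gradient module can be pictured as a layer-local regressor from activations to gradients; the names below are hypothetical:

```python
import numpy as np

class SyntheticGradientModule:
    """Linear model M_t that predicts dE/dh_t from the activation h_t."""
    def __init__(self, dim, lr=1e-3):
        self.M = np.zeros((dim, dim))
        self.lr = lr

    def predict(self, h):
        return self.M @ h                 # synthetic gradient for this layer

    def update(self, h, grad_from_above):
        # Regress the prediction towards the gradient supplied by the
        # next layer's module (or the true gradient at the final layer).
        err = self.M @ h - grad_from_above
        self.M -= self.lr * np.outer(err, h)
```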
Figure 2: RNN with TPTT with three input sequences (one hidden layer)
To implement the RNN, we used Difference Target Propagation, first proposed by Lee et al. [7]. This is a form of target propagation [6, 3] in which the network estimates targets for each time step and propagates the targets instead of gradients. At each time step, the activation value is compared with the estimated target via a local loss function, and the weights are updated. The targets are estimated in such a way that, even though optimization is done locally, the global loss converges as the activation values approach the targets. In this section, we discuss the mathematical computations inside this RNN model, in particular how the targets are estimated for each of the time steps. The 'unrolled' RNN is shown in Fig. 2. The implementation of a recurrent neural network with target propagation was coined "TPTT", or "Target Propagation Through Time", by Manchev et al. [10], since targets are propagated through the time steps of the network.
The feed-forward pass of TPTT is the same as in BPTT. Then, for credit assignment, instead of backpropagating the gradient of the global loss, we set the target of the last layer using that gradient:
$$\hat{y} = y - \eta \frac{\partial L(y, Y)}{\partial y}$$
where $\eta$ is usually a small step size. Note that if we use the mean squared error (MSE) as the global loss and $\eta = 0.5$, we get $\hat{y} = Y$; that is, the labels of the dataset become the targets for the final layer.
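This special case is easy to verify numerically (illustrative values; with $L = \|y - Y\|^2$ the gradient is $2(y - Y)$):

```python
import numpy as np

y = np.array([0.3, 0.7])          # network output
Y = np.array([0.0, 1.0])          # label
eta = 0.5
y_hat = y - eta * 2.0 * (y - Y)   # y - eta * dL/dy for MSE
print(np.allclose(y_hat, Y))      # True: the labels become the target
```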
Next, we estimate the local targets $\hat{h}_t$ for the neurons in the hidden layer at each time step $t$. The target for the final time step, $\hat{h}_{t_{max}}$, is based on the gradient of the error with respect to the activations at $t_{max}$:
$$\hat{h}_{t_{max}} = h_{t_{max}} - \alpha_i \frac{\partial L(y, Y)}{\partial h_{t_{max}}}$$
where $\alpha_i$ is an initial learning rate. Note that TPTT has three different learning rates: this one, used to estimate the local target of the final time step, and two others used for updating the weight matrices of the feed-forward pass and the weights for estimating the targets, respectively.
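In code, setting the final-time-step target is a one-liner, assuming the gradient of the loss with respect to $h_{t_{max}}$ has already been computed from the output layer (the value of `alpha_i` is illustrative):

```python
def final_target(h_tmax, dL_dh_tmax, alpha_i=0.1):
    """h_hat_tmax = h_tmax - alpha_i * dL/dh_tmax."""
    return h_tmax - alpha_i * dL_dh_tmax
```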
An inverse of the forward function is used to estimate the targets for the earlier time steps. If $F(\cdot)$ is the function that computes the hidden state of the network at time $t$, then $F(\cdot)$ is a function of $x_t$ and $h_{t-1}$:
$$h_t = F(x_t, h_{t-1}) = \sigma(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
The inverse of $F(x_t, h_{t-1})$ should be a function $G(\cdot)$ that takes $x_t$ and $h_t$ as inputs and produces an approximation of $h_{t-1}$:
$$h_{t-1} \approx G(x_t, F(x_t, h_{t-1})) \approx G(x_t, h_t)$$
If such a function $G(\cdot)$ can be approximated, it can serve to set the local targets using:
$$\hat{h}_t = G(x_{t+1}, \hat{h}_{t+1})$$
The presented model adopts the linearly corrected formula (difference target propagation) suggested by Lee et al. [7], which stabilizes the optimization when $G(\cdot)$ is not a perfect inverse of $F(\cdot)$:
$$\hat{h}_t = h_t - G(x_{t+1}, h_{t+1}) + G(x_{t+1}, \hat{h}_{t+1})$$
If $G(\cdot)$ is an exact inverse of $F(\cdot)$, then $G(x_{t+1}, h_{t+1}) = h_t$ and $h_t - G(x_{t+1}, h_{t+1}) = h_t - h_t = 0$, so this equation reduces to the previous one. The corrected formula stabilizes the optimization because it guarantees that as $h_{t+1}$ approaches $\hat{h}_{t+1}$, $h_t$ also approaches $\hat{h}_t$.
The proposed configuration for $G(\cdot)$ is
$$G(x_{t+1}, h_{t+1}) = \sigma(W_{xh} x_{t+1} + V_{hh} h_{t+1} + c_h)$$
where $V_{hh}$ is a matrix of weights and $c_h$ is a bias term, both of which the network must learn so that the inverse approximation above holds. Plugging this into the difference-correction formula produces the final equation for the upstream targets:
$$\hat{h}_t = h_t - \sigma(W_{xh} x_{t+1} + V_{hh} h_{t+1} + c_h) + \sigma(W_{xh} x_{t+1} + V_{hh} \hat{h}_{t+1} + c_h)$$
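Continuing the toy sketch from earlier (reusing `W_xh`, and assuming tanh plays the role of $\sigma$), the backward sweep that sets all upstream targets could look like this:

```python
import numpy as np

def propagate_targets(xs, hs, h_hat_final, V_hh, c_h):
    """Difference target propagation through time (sketch).
    hs[t] is the forward state after consuming xs[t]."""
    T = len(xs)
    h_hat = [None] * T
    h_hat[T - 1] = h_hat_final
    for t in range(T - 2, -1, -1):
        g_state = np.tanh(W_xh @ xs[t + 1] + V_hh @ hs[t + 1] + c_h)
        g_target = np.tanh(W_xh @ xs[t + 1] + V_hh @ h_hat[t + 1] + c_h)
        h_hat[t] = hs[t] - g_state + g_target   # difference correction
    return h_hat
```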
Once the local targets have been specified, the network alternates between two stages: it first uses gradient descent to update the parameters of $G(\cdot)$, and then updates the feed-forward parameters. The following local loss is used to update the inverse-mapping matrices:
$$L_{inv} = \left\| G(x_t, F(x_t, h_{t-1} + \epsilon)) - (h_{t-1} + \epsilon) \right\|_2^2$$
Here, noise $\epsilon$ is injected because we do not want to estimate an inverse mapping only for the concrete values seen in training, but for a region around these values, which helps the model handle data it has not seen before. The feed-forward matrices $W_{hh}$, $W_{xh}$ and $W_{hy}$, on the other hand, are updated based on the local loss function:
$$L = \left\| F(x_t, h_{t-1}) - \hat{h}_t \right\|_2^2$$
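Again continuing the toy sketch (`rnn_step` plays the role of $F$, and tanh the role of $\sigma$ in $G$), the two local losses could be written as:

```python
import numpy as np

def inverse_loss(x_t, h_prev, V_hh, c_h, noise_std=0.1):
    """L_inv: train G to invert F in a region around h_{t-1}."""
    eps = np.random.default_rng().normal(scale=noise_std, size=h_prev.shape)
    h_noisy = h_prev + eps
    h_next = rnn_step(x_t, h_noisy)                 # F(x_t, h_{t-1} + eps)
    g = np.tanh(W_xh @ x_t + V_hh @ h_next + c_h)   # G(x_t, F(...))
    return np.sum((g - h_noisy) ** 2)

def forward_local_loss(x_t, h_prev, h_hat_t):
    """L: bring the layer's activation close to its local target."""
    return np.sum((rnn_step(x_t, h_prev) - h_hat_t) ** 2)
```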
Target estimation, $G(\cdot)$ parameter optimization, and $F(\cdot)$ parameter optimization are repeated until a chosen convergence criterion is satisfied.
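As a minimal sketch of the feed-forward half of this alternation (continuing the toy example, with the gradient of the local loss $L$ derived by hand since no autodiff library is used here):

```python
import numpy as np

def forward_local_update(x_t, h_prev, h_hat_t, lr=0.01):
    """One local gradient-descent step on L = ||F(x_t, h_{t-1}) - h_hat_t||^2."""
    global W_xh, W_hh, b
    h_t = rnn_step(x_t, h_prev)                    # F(x_t, h_{t-1})
    dz = 2.0 * (h_t - h_hat_t) * (1.0 - h_t ** 2)  # back through tanh, one layer only
    W_xh -= lr * np.outer(dz, x_t)                 # purely local updates: no
    W_hh -= lr * np.outer(dz, h_prev)              # gradient crosses time-step
    b -= lr * dz                                   # boundaries
```

The update for the parameters of $G(\cdot)$, i.e. $V_{hh}$ and $c_h$, proceeds analogously by descending on $L_{inv}$.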
[1] Abdul Manan Ahmad, Saliza Ismail, and DF Samaon. “Recurrent neural network with backpropagation through time for speech recognition”. In: IEEE International Symposium on Communications and Information Technology, 2004. ISCIT 2004. Vol. 1. IEEE. 2004, pp. 98–102.
[2] Yoshua Bengio. “How auto-encoders could provide credit assignment in deep networks via target propagation”. In: arXiv preprint arXiv:1407.7906 (2014).
[3] Yoshua Bengio et al. “Towards biologically plausible deep learning”. In: arXiv preprint arXiv:1502.04156 (2015).
[4] Max Jaderberg et al. “Decoupled neural interfaces using synthetic gradients”. In: International conference on machine learning. PMLR. 2017, pp. 1627–1635.
[5] Yann Le Cun and Françoise Fogelman-Soulié. “Modèles connexionnistes de l’apprentissage”. In: Intellectica 2.1 (1987), pp. 114–143.
[6] Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. “Deep learning”. In: Nature 521.7553 (2015), pp. 436–444.
[7] Dong-Hyun Lee et al. “Difference target propagation”. In: Machine Learning and Knowledge Discovery in Databases: European Conference, ECML PKDD 2015, Porto, Portugal, September 7-11, 2015, Proceedings, Part I 15. Springer. 2015, pp. 498–515.
[9] Timothy P Lillicrap et al. “Backpropagation and the brain”. In: Nature Reviews Neuroscience 21.6 (2020), pp. 335–346.
[10] Nikolay Manchev and Michael Spratling. “Target propagation in recurrent neural networks”. In: Journal of Machine Learning Research 21.7 (2020), pp. 1–33.
[11] Pieter R. Roelfsema and Arjen van Ooyen. “Attention-Gated Reinforcement Learning of Internal Representations for Classification”. In: Neural Computation 17.10 (Oct. 2005), pp. 2176–2214. ISSN: 0899-7667. DOI: 10.1162/0899766054615699. URL: https://doi.org/10.1162/0899766054615699.
[12] Raúl Rojas. “The backpropagation algorithm”. In: Neural Networks: A Systematic Introduction (1996), pp. 149–182.
[13] Vincent Roulet and Zaid Harchaoui. “Target propagation via regularized inversion”. In: arXiv preprint arXiv:2112.01453 (2021).
[14] David E Rumelhart, Geoffrey E Hinton, and Ronald J Williams. “Learning representations by back-propagating errors”. In: Nature 323.6088 (1986), pp. 533–536.
[15] Robin M Schmidt. “Recurrent neural networks (RNNs): A gentle introduction and overview”. In: arXiv preprint arXiv:1912.05911 (2019).
[16] Yuhang Song et al. “Can the brain do backpropagation?—exact implementation of backpropagation in predictive coding networks”. In: Advances in neural information processing systems 33 (2020), pp. 22566–22579.