
LSTM back propagation: following the flows of variables

First of all, the summary of this article is: please just download the PowerPoint slides I made and be patient in following the equations.

I am not supposed to use so much mathematics when I write articles on Data Science Blog. However, using little mathematics when I talk about LSTM backprop is like writing German without caring about “der,” “die,” and “das,” or speaking little English in English classes (which most high school English teachers in Japan do), or writing Japanese without using any Chinese characters (which looks like terrible handwriting by a drug addict). In short, that is ridiculous. And all the precise equations of LSTM backprop, written out on a blog, are not a comfortable thing to see. So basically the whole of this article is an advertisement for my PowerPoint slides, sponsored by DATANOMIQ, and here I can just give you some tips to get ready for the most tiresome part of understanding LSTM.

*This article is the fifth article of “A gentle introduction to the tiresome part of understanding RNN.”

 *In this article “Densely Connected Layers” is written as “DCL,” and “Convolutional Neural Network” as “CNN.”

1. Chain rules

This article is virtually an article on the chain rules of differentiation. Even if you have a clear understanding of chain rules, I recommend that you take a look at this section. If you have written down all the equations of back propagation of DCL, you will have seen what chain rules are. Even simple chain rules for backprop of normal DCL can be difficult for some people, but when it comes to backprop of LSTM, it is pure torture. I think using graphical models will help you understand what chain rules are like. Graphical models are basically used to describe the relations of variables and functions in probabilistic models, so to be exact I am going to use “something like graphical models” in this article. Note that this is not a common way to explain chain rules.

First, let’s think about the simplest type of chain rule. Assume that you have a function f=f(x)=f(x(y)), and that the relations of the functions are displayed as the graphical model at the left side of the figure below. Variables are a type of function, so you should think that every node in graphical models denotes a function. The purple arrows at the right side of the chart show how information propagates in differentiation.

Next, suppose you have a function f which has two variables x_1 and x_2, and both of these variables in turn share two variables y_1 and y_2. When you take the partial derivative of f with respect to y_1 or y_2, the formula is a little tricky. Let’s think about how to calculate \frac{\partial f}{\partial y_1}. The variable y_1 propagates to f via x_1 and x_2. In this case the partial derivative has two terms as below.

In chain rules, you have to think about all the routes through which a variable can propagate. If you generalize chain rules, the result is like below, and you need to understand chain rules in this way to understand any type of back propagation.

The figure above shows that if you calculate the partial derivative of f with respect to y_i, the partial derivative has n terms in total because y_i propagates to f via n variables. In order to understand backprop of LSTM, you constantly have to care about the flow of variables, which I showed as purple arrows above.
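As a small numerical sanity check of the two-route case above, the sketch below compares the chain-rule sum with a finite-difference estimate. The concrete functions x_of_y and f_of_x are hypothetical examples chosen only for illustration; they do not come from the slides.

```python
import numpy as np

def x_of_y(y):
    # Hypothetical intermediate variables x_1 and x_2, both depending on y_1 and y_2.
    y1, y2 = y
    return np.array([y1 * y2, y1 + y2 ** 2])

def f_of_x(x):
    # Hypothetical outer function f(x_1, x_2).
    x1, x2 = x
    return x1 ** 2 + np.sin(x2)

def df_dy1_chain(y):
    # Chain rule: sum over both routes y_1 -> x_1 -> f and y_1 -> x_2 -> f.
    y1, y2 = y
    x1, x2 = x_of_y(y)
    df_dx1, df_dx2 = 2 * x1, np.cos(x2)
    dx1_dy1, dx2_dy1 = y2, 1.0
    return df_dx1 * dx1_dy1 + df_dx2 * dx2_dy1

def df_dy1_numeric(y, eps=1e-6):
    # Central finite difference with respect to y_1 only.
    step = np.array([eps, 0.0])
    return (f_of_x(x_of_y(y + step)) - f_of_x(x_of_y(y - step))) / (2 * eps)

y = np.array([0.7, -1.3])
chain_value = df_dy1_chain(y)
numeric_value = df_dy1_numeric(y)
print(chain_value, numeric_value)
```

If one of the two routes is dropped from df_dy1_chain, the two printed values no longer agree, which is exactly the mistake the graphical view is meant to prevent.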

2. Chain rules in LSTM

I would like you to remember the figure below, which I used in the second article to show how errors propagate backward during backprop of simple RNNs. After forward propagation, first of all, you need to calculate \frac{\partial J}{\partial \boldsymbol{\theta}^{(t)}}, gradients of the error function with respect to parameters, at every time step. But you have to be careful that even though these gradients depend on time steps, the parameters \boldsymbol{\theta} do not depend on time steps.

*As I mentioned in the second article, I personally think \frac{\partial J}{\partial \boldsymbol{\theta}^{(t)}} should rather be denoted as (\frac{\partial J}{\partial \boldsymbol{\theta}})^{(t)} because the parameters themselves do not depend on time. The textbook by MIT Press also partly uses the former notation. You are likely to encounter this type of notation, so I think it is not bad to get ready for both.

The errors at time step (t) propagate backward to all the \boldsymbol{h} ^{(s)}, (s \leq t). Conversely, in order to calculate \frac{\partial J}{\partial \boldsymbol{\theta}^{(t)}} you need the errors flowing from J^{(s)},  (s \geq t). In the chart, you need the purple arrows of errors for the gradient in the purple frame, the orange arrows for the gradient in the orange frame, and the red arrows for the gradient in the red frame. And you need to sum up \frac{\partial J}{\partial \boldsymbol{\theta}^{(t)}} to calculate \frac{\partial J}{\partial \boldsymbol{\theta}} = \sum_{t}{\frac{\partial J}{\partial \boldsymbol{\theta}^{(t)}}}, and you need this gradient \frac{\partial J}{\partial \boldsymbol{\theta}} to update the parameters once.

At an RNN block level, the flows of errors and how to renew parameters are the same in LSTM backprop, but the flow of errors inside each block is much more complicated in LSTM backprop. And in this article and my PowerPoint slides, I use a special notation to denote errors: \delta \star  ^{(t)}= \frac{\partial J^{(t)}}{\partial \star}

* Again, please be careful about what \delta \star  ^{(t)} means. Neurons depend on time steps, but parameters do not depend on time steps. So if \star are neurons,  \delta \star  ^{(t)}= \frac{\partial J}{ \partial \star ^{(t)}}, but when \star are parameters, \delta \star  ^{(t)}= \frac{\partial J^{(t)}}{ \partial \star} should rather be denoted like \delta \star  ^{(t)}= (\frac{\partial J}{ \partial \star})^{(t)}. In the Space Odyssey paper \boldsymbol{\star} are not used as parameters, but in my PowerPoint slides and some other materials, \boldsymbol{\star} are also used as parameters.

As I wrote in the last article, you calculate \boldsymbol{f}^{(t)}, \boldsymbol{i}^{(t)}, \boldsymbol{z}^{(t)}, \boldsymbol{o}^{(t)} as below. Unlike in the last article, I have added the terms of peephole connections to the equations below, and I have also added the variables \bar{\boldsymbol{f}}^{(t)}, \bar{\boldsymbol{i}}^{(t)}, \bar{\boldsymbol{z}}^{(t)}, \bar{\boldsymbol{o}}^{(t)} for convenience.

  • \boldsymbol{\bar{f}}^{(t)}=\boldsymbol{W}_{for} \cdot \boldsymbol{x}^{(t)} + \boldsymbol{R}_{for} \cdot \boldsymbol{y}^{(t-1)} + \boldsymbol{p}_{for}\odot \boldsymbol{c}^{(t-1)} + \boldsymbol{b}_{for}
  • \boldsymbol{\bar{i}}^{(t)}=\boldsymbol{W}_{in} \cdot \boldsymbol{x}^{(t)} + \boldsymbol{R}_{in} \cdot \boldsymbol{y}^{(t-1)} + \boldsymbol{p}_{in}\odot \boldsymbol{c}^{(t-1)} + \boldsymbol{b}_{in}
  • \boldsymbol{\bar{z}}^{(t)}=\boldsymbol{W}_z \cdot \boldsymbol{x}^{(t)} + \boldsymbol{R}_z \cdot \boldsymbol{y}^{(t-1)} + \boldsymbol{b}_z
  • \boldsymbol{\bar{o}}^{(t)}=\boldsymbol{W}_{out} \cdot \boldsymbol{x}^{(t)} + \boldsymbol{R}_{out} \cdot \boldsymbol{y}^{(t-1)} + \boldsymbol{p}_{out}\odot \boldsymbol{c}^{(t)} + \boldsymbol{b}_{out}
  • \boldsymbol{f}^{(t)}=\sigma( \boldsymbol{\bar{f}}^{(t)})
  • \boldsymbol{i}^{(t)}=\sigma(\boldsymbol{\bar{i}}^{(t)})
  • \boldsymbol{z}^{(t)}=tanh(\boldsymbol{\bar{z}}^{(t)})
  • \boldsymbol{o}^{(t)}=\sigma(\boldsymbol{\bar{o}}^{(t)})

With the Hadamard product operator \odot, the renewed cell and the output are calculated as below.

  • \boldsymbol{c}^{(t)} = \boldsymbol{z}^{(t)}\odot \boldsymbol{i}^{(t)} + \boldsymbol{c}^{(t-1)} \odot \boldsymbol{f}^{(t)}
  • \boldsymbol{y}^{(t)} = \boldsymbol{o}^{(t)} \odot tanh(\boldsymbol{c}^{(t)})
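The forward equations above can be sketched in NumPy as follows. This is only a minimal illustration under assumptions: the dictionary layout of the parameters, the helper names init_params and lstm_forward_step, and the random initialization are all hypothetical choices made here, not something taken from the slides or the Space Odyssey paper.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def init_params(n_in, n_hidden, rng):
    # Hypothetical parameter container: W (input), R (recurrent), p (peephole), b (bias).
    gates = ["for", "in", "z", "out"]
    return {
        "W": {g: rng.normal(scale=0.1, size=(n_hidden, n_in)) for g in gates},
        "R": {g: rng.normal(scale=0.1, size=(n_hidden, n_hidden)) for g in gates},
        "p": {g: rng.normal(scale=0.1, size=n_hidden) for g in ["for", "in", "out"]},
        "b": {g: np.zeros(n_hidden) for g in gates},
    }

def lstm_forward_step(x, y_prev, c_prev, params):
    W, R, p, b = params["W"], params["R"], params["p"], params["b"]
    # Pre-activations; the forget and input gates peek at c^(t-1).
    f_bar = W["for"] @ x + R["for"] @ y_prev + p["for"] * c_prev + b["for"]
    i_bar = W["in"] @ x + R["in"] @ y_prev + p["in"] * c_prev + b["in"]
    z_bar = W["z"] @ x + R["z"] @ y_prev + b["z"]
    f, i, z = sigmoid(f_bar), sigmoid(i_bar), np.tanh(z_bar)
    # Renewed cell: element-wise (Hadamard) products.
    c = z * i + c_prev * f
    # The output gate peeks at the renewed cell c^(t).
    o_bar = W["out"] @ x + R["out"] @ y_prev + p["out"] * c + b["out"]
    o = sigmoid(o_bar)
    y = o * np.tanh(c)
    cache = {"f": f, "i": i, "z": z, "o": o, "f_bar": f_bar,
             "i_bar": i_bar, "z_bar": z_bar, "o_bar": o_bar}
    return y, c, cache

rng = np.random.default_rng(0)
params = init_params(n_in=3, n_hidden=4, rng=rng)
x = rng.normal(size=3)
y_prev, c_prev = np.zeros(4), np.zeros(4)
y, c, cache = lstm_forward_step(x, y_prev, c_prev, params)
```

Returning the intermediate values in cache is a design choice for this sketch: every error in the backward pass reuses the gate outputs and pre-activations computed here.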

In this article I would rather give instructions on how to read my PowerPoint slides. Just as in general backprop, you need to calculate gradients of error functions with respect to parameters, such as \delta \boldsymbol{W}_{\star}, \delta \boldsymbol{R}_{\star}, \delta \boldsymbol{p}_{\star}, \delta \boldsymbol{b}_{\star}, where \star is one of \{z, in, for, out \}. And just as in backprop of simple RNNs, in order to calculate gradients with respect to parameters, you need to calculate errors of neurons, that is, gradients of error functions with respect to neurons, such as \delta \boldsymbol{f}^{(t)}, \delta \boldsymbol{i}^{(t)}, \delta \boldsymbol{z}^{(t)}, \delta \boldsymbol{o}^{(t)}.

*Again and again, keep in mind that neurons depend on time steps, but parameters do not depend on time steps.

When you calculate gradients with respect to neurons, you can first calculate \delta \boldsymbol{y}^{(t)}, but the equation for this error is the most difficult, so I recommend that you put it aside for now. After calculating \delta \boldsymbol{y}^{(t)}, you can next calculate \delta \bar{\boldsymbol{o}}^{(t)}= \frac{\partial J^{(t)}}{ \partial \bar{\boldsymbol{o}}^{(t)}}. If you see the LSTM block below as a graphical model of the kind I introduced, the information of \bar{\boldsymbol{o}}^{(t)} flows like the purple arrows. That means \bar{\boldsymbol{o}}^{(t)} affects J only via \boldsymbol{y}^{(t)}, and this structure is the same as the first graphical model which I introduced above. And if you calculate \delta \bar{\boldsymbol{o}}^{(t)} element-wise, you get the equation \delta \bar{o}_{k}^{(t)}=\frac{\partial J}{\partial \bar{o}_{k}^{(t)}}= \frac{\partial J}{\partial y_{k}^{(t)}} \frac{\partial y_{k}^{(t)}}{\partial \bar{o}_{k}^{(t)}}.
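Since y_k = \sigma(\bar{o}_k) \cdot tanh(c_k), the factor \frac{\partial y_{k}^{(t)}}{\partial \bar{o}_{k}^{(t)}} is tanh(c_k) \cdot \sigma'(\bar{o}_k). Below is a minimal sketch of this single-route chain rule with a finite-difference check. The function name delta_o_bar and the toy loss (a weighted sum of the outputs with hypothetical weights w, so that \delta y = w) are assumptions made only for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def delta_o_bar(delta_y, c, o_bar):
    # delta_o_bar_k = delta_y_k * tanh(c_k) * sigma'(o_bar_k), with sigma' = o * (1 - o).
    o = sigmoid(o_bar)
    return delta_y * np.tanh(c) * o * (1.0 - o)

rng = np.random.default_rng(0)
c, o_bar, w = rng.normal(size=4), rng.normal(size=4), rng.normal(size=4)

def toy_loss(o_bar_value):
    # Hypothetical loss J = sum_k w_k * y_k, so that dJ/dy = w. The cell c is held fixed.
    return np.sum(w * sigmoid(o_bar_value) * np.tanh(c))

analytic = delta_o_bar(w, c, o_bar)
eps = 1e-6
numeric = np.array([
    (toy_loss(o_bar + eps * e) - toy_loss(o_bar - eps * e)) / (2 * eps)
    for e in np.eye(4)
])
print(analytic, numeric)
```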

*The k is an index of an element of vectors. If you can calculate element-wise gradients, it is easy to understand that as differentiation of vectors and matrices.

Next you can calculate \delta \boldsymbol{c}^{(t)}, and chain rules are very important in this process. The flow from \boldsymbol{c}^{(t)} to J can be roughly divided into two streams: the one which flows to J via \boldsymbol{y}^{(t)}, and the one which flows to J via \boldsymbol{c}^{(t+1)}. The stream from \boldsymbol{c}^{(t)} to \boldsymbol{y}^{(t)} also has two branches: the one via \bar{\boldsymbol{o}}^{(t)} and the one which directly converges on \boldsymbol{y}^{(t)}. Likewise, the stream from \boldsymbol{c}^{(t)} to \boldsymbol{c}^{(t+1)} has three branches: the ones via \bar{\boldsymbol{f}}^{(t+1)} and \bar{\boldsymbol{i}}^{(t+1)}, and the one which directly converges on \boldsymbol{c}^{(t+1)}.

If you see these flows as a graphical model, it looks like the figure below.

According to this graphical model, you can calculate \delta \boldsymbol{c} ^{(t)} element-wise as below.
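Written out from the five routes described above and the forward equations, the element-wise sum would look like the following. This is a reconstruction under the notation of this article, not a copy of the original figure, so please check it against the slides or the Space Odyssey paper.

```latex
\delta c_{k}^{(t)} = \frac{\partial J}{\partial c_{k}^{(t)}}
 = \frac{\partial J}{\partial y_{k}^{(t)}} \frac{\partial y_{k}^{(t)}}{\partial c_{k}^{(t)}}
 + \frac{\partial J}{\partial \bar{o}_{k}^{(t)}} \frac{\partial \bar{o}_{k}^{(t)}}{\partial c_{k}^{(t)}}
 + \frac{\partial J}{\partial c_{k}^{(t+1)}} \frac{\partial c_{k}^{(t+1)}}{\partial c_{k}^{(t)}}
 + \frac{\partial J}{\partial \bar{i}_{k}^{(t+1)}} \frac{\partial \bar{i}_{k}^{(t+1)}}{\partial c_{k}^{(t)}}
 + \frac{\partial J}{\partial \bar{f}_{k}^{(t+1)}} \frac{\partial \bar{f}_{k}^{(t+1)}}{\partial c_{k}^{(t)}}

 = \delta y_{k}^{(t)} \, o_{k}^{(t)} \left(1 - \tanh^{2}(c_{k}^{(t)})\right)
 + \delta \bar{o}_{k}^{(t)} \, (p_{out})_{k}
 + \delta c_{k}^{(t+1)} \, f_{k}^{(t+1)}
 + \delta \bar{i}_{k}^{(t+1)} \, (p_{in})_{k}
 + \delta \bar{f}_{k}^{(t+1)} \, (p_{for})_{k}
```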

* TO BE VERY HONEST I still do not fully understand why we can apply chain rules like above to calculate \delta \boldsymbol{c}^{(t)}. When you apply the formula of chain rules, which I showed in the first section, to this case, you have to be careful about where to apply the partial differential operator \frac{\partial}{ \partial c_{k}^{(t)}}. In the case above, in the part \frac{\partial y_{k}^{(t)}}{\partial c_{k}^{(t)}} the partial differential operator only affects tanh(c_{k}^{(t)}) of o_{k}^{(t)} \cdot tanh(c_{k}^{(t)}), and in the part \frac{\partial c_{k}^{(t+1)}}{\partial c_{k}^{(t)}}, the partial differential operator \frac{\partial}{\partial c_{k}^{(t)}} only affects the part c_{k}^{(t)} of the term c_{k}^{(t)} \cdot f_{k}^{(t+1)}. In the \frac{\partial \bar{o}_{k}^{(t)}}{\partial c_{k}^{(t)}} part it affects only (p_{out})_{k} \cdot c_{k}^{(t)}, in the \frac{\partial \bar{i}_{k}^{(t+1)}}{\partial c_{k}^{(t)}} part only (p_{in})_{k} \cdot c_{k}^{(t)}, and in the \frac{\partial \bar{f}_{k}^{(t+1)}}{\partial c_{k}^{(t)}} part only (p_{for})_{k} \cdot c_{k}^{(t)}. But some other parts, which are not affected by \frac{\partial}{ \partial c_{k}^{(t)}}, are also functions of c_{k}^{(t)}. For example, o_{k}^{(t)} of o_{k}^{(t)} \cdot tanh(c_{k}^{(t)}) is also a function of c_{k}^{(t)}. And I am still not sure about the logic behind where to apply those partial differential operators.

*But at least, these are the only decent equations for LSTM backprop which I could find, and a frequently cited paper on LSTM uses an implementation based on these equations. Computer science is more about practical skills than rigid mathematical logic. If you have any comments or advice on this point, please let me know.

Calculating \delta \bar{\boldsymbol{f}}^{(t)}, \delta \bar{\boldsymbol{i}}^{(t)}, \delta \bar{\boldsymbol{z}}^{(t)} is also relatively straightforward, just like calculating \delta \bar{\boldsymbol{o}}^{(t)}. They all use the first type of chain rule in the first section. Thereby you can get these gradients: \delta \bar{f}_{k}^{(t)}=\frac{\partial J}{ \partial \bar{f}_{k}^{(t)}} =\frac{\partial J}{\partial c_{k}^{(t)}} \frac{\partial c_{k}^{(t)}}{ \partial \bar{f}_{k}^{(t)}}, \delta \bar{i}_{k}^{(t)}=\frac{\partial J}{\partial \bar{i}_{k}^{(t)}} =\frac{\partial J}{\partial c_{k}^{(t)}} \frac{\partial c_{k}^{(t)}}{ \partial \bar{i}_{k}^{(t)}}, and \delta \bar{z}_{k}^{(t)}=\frac{\partial J}{\partial \bar{z}_{k}^{(t)}} =\frac{\partial J}{\partial c_{k}^{(t)}} \frac{\partial c_{k}^{(t)}}{ \partial \bar{z}_{k}^{(t)}}.
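With c = z \odot i + c^{(t-1)} \odot f, the three local derivatives are c^{(t-1)}_k \sigma'(\bar{f}_k), z_k \sigma'(\bar{i}_k) and i_k (1 - tanh^2(\bar{z}_k)). The sketch below computes the three errors from a given \delta c. The function names gate_deltas and cell, and the random test values, are hypothetical and only serve as an illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def cell(c_prev, f_bar, i_bar, z_bar):
    # c^(t) = z * i + c^(t-1) * f, all products element-wise.
    return np.tanh(z_bar) * sigmoid(i_bar) + c_prev * sigmoid(f_bar)

def gate_deltas(delta_c, c_prev, f_bar, i_bar, z_bar):
    f, i, z = sigmoid(f_bar), sigmoid(i_bar), np.tanh(z_bar)
    d_f_bar = delta_c * c_prev * f * (1.0 - f)   # via c_prev * sigma(f_bar)
    d_i_bar = delta_c * z * i * (1.0 - i)        # via z * sigma(i_bar)
    d_z_bar = delta_c * i * (1.0 - z ** 2)       # via tanh(z_bar) * i
    return d_f_bar, d_i_bar, d_z_bar

rng = np.random.default_rng(0)
c_prev, f_bar, i_bar, z_bar, w = (rng.normal(size=4) for _ in range(5))
# Here delta_c is taken to be w, i.e. a hypothetical loss J = sum_k w_k * c_k.
d_f_bar, d_i_bar, d_z_bar = gate_deltas(w, c_prev, f_bar, i_bar, z_bar)
```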

All the gradients which we have calculated use the error \delta \boldsymbol{y}^{(t)}, but when it comes to calculating \delta \boldsymbol{y}^{(t)}... I can only say “Please be patient. I did my best in my PowerPoint slides to explain that.” It is not the kind of process which I want to explain on WordPress. In conclusion you get an error like this: \delta \boldsymbol{y}^{(t)}=\frac{\partial J^{(t)}}{\partial \boldsymbol{y}^{(t)}} + \boldsymbol{R}_{for}^{T} \delta \bar{\boldsymbol{f}}^{(t+1)} + \boldsymbol{R}_{in}^{T}\delta \bar{\boldsymbol{i}}^{(t+1)} + \boldsymbol{R}_{out}^{T}\delta \bar{\boldsymbol{o}}^{(t+1)} + \boldsymbol{R}_{z}^{T}\delta \bar{\boldsymbol{z}}^{(t+1)}, and the flows of \boldsymbol{y}^{(t)} are as below.
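The formula for \delta \boldsymbol{y}^{(t)} translates almost literally into code. The sketch below assumes that the recurrent matrices and the errors of the next time step are stored in dictionaries keyed by gate name; this layout and the function name delta_y are hypothetical.

```python
import numpy as np

def delta_y(dJ_dy, R, next_deltas):
    # dJ_dy: direct gradient of J^(t) with respect to y^(t).
    # next_deltas: errors of the pre-activations at time step t+1.
    total = dJ_dy.astype(float)  # astype returns a copy, so dJ_dy is not modified
    for gate in ["for", "in", "out", "z"]:
        total += R[gate].T @ next_deltas[gate]
    return total

n = 3
rng = np.random.default_rng(0)
R = {g: rng.normal(size=(n, n)) for g in ["for", "in", "out", "z"]}
dJ_dy = rng.normal(size=n)
# At the last time step there is no t+1, so the next errors are all zero.
zero_deltas = {g: np.zeros(n) for g in R}
last_step_delta_y = delta_y(dJ_dy, R, zero_deltas)
```

At the last time step only the direct term survives, which is why the backward pass starts there and then moves toward earlier time steps.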

Combining the gradients we have got so far, we can calculate gradients with respect to parameters. For concrete results, please check the Space Odyssey paper or my PowerPoint slide.

3. How LSTMs tackle exploding/vanishing gradients problems

*If you are allergic to mathematics, you should not read this section or download my PowerPoint slide.

*Part of this section is more or less subjective, so if you really want to know how LSTMs mitigate the problems, I highly recommend that you also refer to other materials. But at least I did my best for this article.

LSTMs do not completely solve vanishing gradient problems. They mitigate vanishing/exploding gradient problems. I am going to roughly explain why they can tackle those problems. I think you can find many explanations on that topic, but many of them seem to have some mathematical mistakes (even the slides used in a lecture at Stanford University), and I could not fully agree with some statements. I also could not find any papers or materials which show the whole picture of how LSTMs can tackle those problems. So in this article I am only going to give instructions on the most mainstream way to explain this topic.

First let’s see how gradients actually “vanish” or “explode” in simple RNNs. As I explained in the second article of this series, simple RNNs propagate forward as in the equations below.

  • \boldsymbol{a}^{(t)} = \boldsymbol{b} + \boldsymbol{W} \cdot \boldsymbol{h}^{(t-1)} + \boldsymbol{U} \cdot \boldsymbol{x}^{(t)}
  • \boldsymbol{h}^{(t)}= g(\boldsymbol{a}^{(t)})
  • \boldsymbol{o}^{(t)} = \boldsymbol{c} + \boldsymbol{V} \cdot \boldsymbol{h}^{(t)}
  • \hat{\boldsymbol{y}} ^{(t)} = f(\boldsymbol{o}^{(t)})

At every time step, you get an error function J^{(t)}. Let’s consider the gradient of J^{(t)} with respect to \boldsymbol{h}^{(k)}, that is, the error flowing from J^{(t)} to \boldsymbol{h}^{(k)}. This error is the one used most often to calculate gradients of the parameters.

If you calculate this error more concretely, \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(k)}} = \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(t)}} \frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{h}^{(t-1)}} \cdots \frac{\partial \boldsymbol{h}^{(k+2)}}{\partial \boldsymbol{h}^{(k+1)}} \frac{\partial \boldsymbol{h}^{(k+1)}}{\partial \boldsymbol{h}^{(k)}} = \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(t)}} \prod_{k< s \leq t} \frac{\partial \boldsymbol{h}^{(s)}}{\partial \boldsymbol{h}^{(s-1)}}, where \frac{\partial \boldsymbol{h}^{(s)}}{\partial \boldsymbol{h}^{(s-1)}} = \boldsymbol{W} ^T \cdot diag(g'(\boldsymbol{b} + \boldsymbol{W}\cdot \boldsymbol{h}^{(s-1)} + \boldsymbol{U}\cdot \boldsymbol{x}^{(s)})) = \boldsymbol{W} ^T \cdot diag(g'(\boldsymbol{a}^{(s)})).

* If you see the figure as a type of graphical model, you should be able to understand why chain rules can be applied as in the equation above.

*According to this paper \frac{\partial \boldsymbol{h}^{(s)}}{\partial \boldsymbol{h}^{(s-1)}}  = \boldsymbol{W} ^T \cdot diag(g'(\boldsymbol{a}^{(s)})), but it seems that many study materials and web sites are mistaken in this point.

Hence \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(k)}} = \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(t)}} \prod_{k< s \leq t} \boldsymbol{W} ^T \cdot diag(g'(\boldsymbol{a}^{(s)})). Because the norm of a product of matrices is at most the product of their norms, taking norms gives the inequality \left\lVert \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(k)}} \right\rVert \leq \left\lVert \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(t)}} \right\rVert \left\lVert \boldsymbol{W} ^T \right\rVert ^{(t - k)} \prod_{k< s \leq t} \left\lVert diag(g'(\boldsymbol{a}^{(s)}))\right\rVert. I will not go into further detail, but it is known that, as this inequality suggests, the repeated multiplication of the same weight matrix tends to make the gradient either shrink exponentially toward 0 or grow exponentially toward infinity.
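To get a feeling for this, the sketch below multiplies the factors \boldsymbol{W} ^T \cdot diag(g'(\boldsymbol{a}^{(s)})) over 30 time steps with g = tanh and measures the spectral norm of the product. The weight matrices are scaled orthogonal matrices and the pre-activations are random numbers in [-0.5, 0.5]; both are artificial assumptions chosen so that the two regimes are clearly visible, not values from a trained RNN.

```python
import numpy as np

def jacobian_product_norm(W, pre_activations):
    # Spectral norm of the product of W^T diag(g'(a^(s))) over all given steps, g = tanh.
    product = np.eye(W.shape[0])
    for a in pre_activations:
        product = product @ (W.T @ np.diag(1.0 - np.tanh(a) ** 2))
    return np.linalg.norm(product, 2)

rng = np.random.default_rng(0)
n, steps = 5, 30
Q, _ = np.linalg.qr(rng.normal(size=(n, n)))  # random orthogonal matrix, norm 1
pre_activations = rng.uniform(-0.5, 0.5, size=(steps, n))

small_norm = jacobian_product_norm(0.5 * Q, pre_activations)  # ||W|| = 0.5
large_norm = jacobian_product_norm(3.0 * Q, pre_activations)  # ||W|| = 3.0
print(small_norm, large_norm)
```

With the small weights the product is bounded by 0.5^{30}, so the error practically disappears, while with the large weights it grows by many orders of magnitude.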

We have seen that the error \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(k)}} is the main factor causing vanishing/exploding gradient problems. In case of LSTM, \frac{\partial J^{(t)}}{\partial \boldsymbol{c}^{(k)}} is an equivalent. For simplicity, let’s calculate only \frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}}, which is equivalent to \frac{\partial \boldsymbol{h}^{(t)}}{\partial \boldsymbol{h}^{(t-1)}} of simple RNN backprop.

* Just as I noted above, you have to be careful of which part the partial differential operator \frac{\partial}{\partial \boldsymbol{c}^{(t-1)}} affects in the chain rule above. That is, you need to calculate \frac{\partial}{\partial \boldsymbol{c}^{(t-1)}} (\boldsymbol{c}^{(t-1)} \odot \boldsymbol{f}^{(t)}), and the partial differential operator only affects \boldsymbol{c}^{(t-1)}. I think this is not a correct mathematical notation, but please forgive me for doing this for convenience.

If you continue calculating the equation above more concretely, you get the equation below.
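The equation of the original figure is not reproduced here. As a rough sketch under a simplifying assumption, if you keep \boldsymbol{y}^{(t-1)} fixed and only follow the direct route and the peephole routes of the forward equations, you get the following element-wise result; the routes through \boldsymbol{y}^{(t-1)} and the recurrent matrices are deliberately left out.

```latex
\frac{\partial c_{k}^{(t)}}{\partial c_{k}^{(t-1)}}
 = f_{k}^{(t)}
 + c_{k}^{(t-1)} \, \sigma'(\bar{f}_{k}^{(t)}) \, (p_{for})_{k}
 + z_{k}^{(t)} \, \sigma'(\bar{i}_{k}^{(t)}) \, (p_{in})_{k}

\frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}}
 \approx diag(\boldsymbol{f}^{(t)})
 + diag\left(\boldsymbol{c}^{(t-1)} \odot \sigma'(\bar{\boldsymbol{f}}^{(t)}) \odot \boldsymbol{p}_{for}\right)
 + diag\left(\boldsymbol{z}^{(t)} \odot \sigma'(\bar{\boldsymbol{i}}^{(t)}) \odot \boldsymbol{p}_{in}\right)
```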

I cannot mathematically explain why, but it is known that this characteristic of the gradients of LSTM backprop mitigates the vanishing/exploding gradient problem. We have seen that if you take a norm of \frac{\partial J^{(t)}}{\partial \boldsymbol{h}^{(k)}}, it is equal to or smaller than a repeated multiplication of the norm of the same weight matrix, and that soon leads to the vanishing/exploding gradient problem. But according to the equation above, even if you take a norm of repeatedly multiplied \frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}}, its norm cannot be bounded by a simple value like a repeated multiplication of the norm of the same weight matrix. The outputs of each gate differ from time step to time step, and that adjusts the value of \frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}}.

*I personally guess the term diag(\boldsymbol{f}^{(t)}) is very effective. The value of diag(\boldsymbol{f}^{(t)}) can directly adjust the value of \frac{\partial \boldsymbol{c}^{(t)}}{\partial \boldsymbol{c}^{(t-1)}}. And as a matter of fact, it is known that the performance of LSTM drops the most when you get rid of forget gates.

When it comes to tackling exploding gradient problems, there is a much easier technique called gradient clipping. This algorithm is very simple: you just have to adjust the size of the gradient so that its norm stays under a threshold at every time step. Imagine that you decide in which direction to move by calculating gradients, but when the footstep is going to be too big, you just shrink the footstep to the threshold size you have set. In pseudocode, the gradient clipping part can be written with only two lines, as below.

*\boldsymbol{g} is the gradient at the time step, and threshold is the maximum size of the “step.”

The figure below, cited from the Deep Learning textbook from MIT Press, is a good and straightforward explanation of gradient clipping. It is known that a strongly nonlinear function, such as the error function of an RNN, can have very steep or very flat areas. If you visualize the idea in 3-dimensional space, as the surface of a loss function J with two variables w, b, that means the loss function J has flat areas and very steep cliffs like in the figure. Without gradient clipping, at the left side, you can see that the black dot all of a sudden climbs the cliff and could jump to an unexpected area. But with gradient clipping, you avoid such “big jumps” on error functions.

Source: Ian Goodfellow, Yoshua Bengio, and Aaron Courville, Deep Learning (2016), MIT Press, p. 409
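A minimal NumPy sketch of norm-based gradient clipping is shown below. The function name clip_gradient and the example values are hypothetical; deep learning frameworks provide their own utilities for this.

```python
import numpy as np

def clip_gradient(g, threshold):
    # If the norm of g exceeds the threshold, rescale g so that its norm equals the threshold.
    norm = np.linalg.norm(g)
    if norm > threshold:
        g = g * (threshold / norm)
    return g

big_step = clip_gradient(np.array([3.0, 4.0]), threshold=1.0)    # norm 5 is rescaled to norm 1
small_step = clip_gradient(np.array([0.3, 0.4]), threshold=1.0)  # norm 0.5 stays as it is
```

The direction of the gradient is kept and only its length changes, which matches the picture of shortening a footstep that is too big.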

 

I am glad that I have finally finished this article series. I am not sure how many readers will have read through all of the articles in this series, including my PowerPoint slides. I bet not so many. I spent a great deal of my time making these contents, but sadly, even while I was studying LSTM, it was becoming old-fashioned, at least in the natural language processing (NLP) field: a very promising algorithm named Transformer has been taking the place of LSTM. Deep learning is a very fast-changing field. I would also like to make illustrative introductions to the attention mechanism in NLP, from the seq2seq model to Transformer. And I think LSTM will still remain as one of the algorithms for sequence data processing, like hidden Markov models or particle filters.
