Inspecting gradient magnitudes in context can be a powerful tool to see when recurrent units use short-term or long-term contextual understanding.
Memorization in Recurrent Neural Networks (RNNs) continues to pose a challenge in many applications. We’d like RNNs to be able to store information over many timesteps and retrieve it when it becomes relevant — but vanilla RNNs often struggle to do this.
Several network architectures have been proposed to tackle aspects of this problem, such as Long Short-Term Memory (LSTM) units and Gated Recurrent Units (GRU). To compare a recurrent unit against its alternatives, both past and recent papers, such as the Nested LSTM paper by Moniz et al., typically rely on quantitative comparisons such as accuracy and cross entropy loss.
While quantitative comparisons are useful, they only provide partial insight into how a recurrent unit memorizes. A model can, for example, achieve high accuracy and low cross entropy loss by providing highly accurate predictions in cases that only require short-term memorization, while being inaccurate at predictions that require long-term memorization. For example, when autocompleting words in a sentence, a model with only short-term understanding can still exhibit high accuracy completing the ends of words once most of the characters are present. However, without longer-term contextual understanding it will not be able to predict words when only a few characters are known.
This article presents a qualitative visualization method for comparing recurrent units with regards to memorization and contextual understanding. The method is applied to the three recurrent units mentioned above: Nested LSTMs, LSTMs, and GRUs.
The networks that will be analyzed all use a simple RNN structure, in which a recurrent unit maps the current input and the previous hidden state to a new hidden state.
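As a minimal sketch of this structure, for a single recurrent layer (the exact parameterization of the unit is precisely what differs between GRU, LSTM, and Nested LSTM):

$$
h_t = \mathrm{Unit}(x_t,\, h_{t-1}), \qquad y_t = \mathrm{Softmax}(W\, h_t + b)
$$

In the models analyzed here, two such layers are stacked, with an embedding layer before and a dense softmax layer after (see the appendix).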
In theory, this time dependency allows the network, in each iteration, to know about every part of the sequence that came before. However, the same time dependency typically causes a vanishing gradient problem, which results in long-term dependencies being ignored during training.
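The standard way to see why is via the chain rule: the gradient connecting a hidden state far in the past to a much later time step is a product of many Jacobians, sketched here for a generic recurrence $h_t = \mathrm{Unit}(x_t, h_{t-1})$:

$$
\frac{\partial h_T}{\partial h_1} = \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}
$$

When the norms of these Jacobians are typically below one, the product shrinks exponentially with the distance $T - 1$, so inputs far in the past contribute almost nothing to the weight updates.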
Several solutions to the vanishing gradient problem have been proposed over the years. The most popular are the aforementioned LSTM and GRU units, but this is still an area of active research. Both LSTM and GRU are well known and thoroughly explained in the literature. Recently, Nested LSTMs have also been proposed.
It is not entirely clear why one recurrent unit performs better than another in some applications, while in others a different type of recurrent unit performs better. Theoretically they all solve the vanishing gradient problem, but in practice their performance is highly application dependent.
Understanding why these differences occur is likely an opaque and challenging problem. The purpose of this article is to demonstrate a visualization technique that better highlights what these differences are. Hopefully, such insight can lead to a deeper understanding of the models themselves.
Comparing different recurrent units is often more involved than simply comparing accuracy or cross entropy loss. Differences in these high-level quantitative measures can have many explanations and may only reflect some small improvement in predictions that require short-term contextual understanding, while it is often the long-term contextual understanding that is of interest.
Therefore, a good problem for qualitatively analyzing contextual understanding should have a human-interpretable output and depend on both long-term and short-term contextual understanding. The typical problems that are often used, such as Penn Treebank, do not satisfy these criteria.
To this end, this article studies the autocomplete problem. Each character is mapped to a target that represents the entire word. The space leading up to the word should also map to that target. This prediction based on the space character is particularly useful for showing contextual understanding.
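As an illustration, a minimal sketch of this target construction could look as follows (the function and variable names are chosen here for illustration and are not taken from the released implementation):

```python
def autocomplete_targets(text):
    """Label every character, including the space leading up to a word,
    with the word it belongs to."""
    pairs = []
    for word in text.split(' '):
        # For simplicity a space is prepended to every word, including the first.
        for char in ' ' + word:
            pairs.append((char, word))
    return pairs

# autocomplete_targets('the context')
# -> [(' ', 'the'), ('t', 'the'), ('h', 'the'), ('e', 'the'),
#     (' ', 'context'), ('c', 'context'), ('o', 'context'), ...]
```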
The autocomplete problem is quite similar to the text8 generation problem: the only difference is that instead of predicting the next letter, the model predicts an entire word. This makes the output much more interpretable. Finally, because of its close relation to text8 generation, existing literature on text8 generation is relevant and comparable, in the sense that models that work well on text8 generation should work well on the autocomplete problem.
The autocomplete dataset is constructed from the full text8 dataset. The recurrent neural networks used to solve the problem have two layers, each with 600 units. There are three models, using GRU, LSTM, and Nested LSTM. See the appendix for more details.
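As a rough sketch of such a model, here is a Keras version of the GRU variant (the vocabulary sizes and the embedding dimension are assumptions on my part, and the released implementation may be structured differently):

```python
import tensorflow as tf
from tensorflow.keras import layers

input_vocab = 27      # a-z and space (assumed; padding handled via masking)
output_vocab = 16386  # word-level targets (assumed)

model = tf.keras.Sequential([
    layers.Embedding(input_vocab, 600),      # character embedding
    layers.GRU(600, return_sequences=True),  # first recurrent layer
    layers.GRU(600, return_sequences=True),  # second recurrent layer
    layers.Dense(output_vocab),              # logits over the word vocabulary
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
```

The LSTM variant simply swaps `layers.GRU` for `layers.LSTM`; Keras has no built-in Nested LSTM layer, so that variant requires a custom recurrent cell.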
In the recently published Nested LSTM paper, the Nested LSTM unit was, among other things, analyzed qualitatively by visualizing individual cell activations. This visualization was inspired by Karpathy et al., who used cell activations to identify cells with interpretable behavior. However, such activation plots only show what individual cells respond to; they do not show which parts of the input the model actually uses when making a prediction.
Instead, to get a better idea of how well each model memorizes and uses memory for contextual understanding, the connectivity between the desired output and the input is analyzed: how strongly the prediction at a given time step depends, in terms of gradient magnitude, on the input at each earlier time step.
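Written out, assuming the output of interest is the logit of the correct word $y_t$ and the input at time step $t'$ is the character embedding $x_{t'}$ (the exact choice of output and norm is a sketch on my part):

$$
\mathrm{connectivity}(t, t') = \left\| \frac{\partial\, \mathrm{logit}_{y_t}}{\partial\, x_{t'}} \right\|_2
$$

A large value means that the input at time step $t'$ strongly influenced the prediction at time step $t$; plotting this quantity over $t'$ for a fixed $t$ therefore shows which parts of the past the model relies on.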
Exploring the connectivity gives a surprising amount of insight into the different models’ ability for long-term contextual understanding. Try and interact with the figure below yourself to see what information the different models use for their predictions.
Let’s highlight three specific situations:
These observations show that the connectivity visualization is a powerful tool for comparing models in terms of which previous inputs they use for contextual understanding. However, it is only possible to compare models on the same dataset, and on a specific example. As such, while these observations may show that the Nested LSTM is not very capable of long-term contextual understanding in this example, they may not generalize to other datasets or hyperparameters.
From the above observations it appears that short-term contextual understanding often involves the word that is being predicted itself. That is, the models switch to using previously seen letters from the word itself, as more letters become available. In contrast, at the beginning of predicting a word, models — especially the GRU network — use previously seen words as context for the prediction.
This observation suggests a quantitative metric: measure the accuracy given how many letters of the word being predicted are already known. It is not clear that this is the best quantitative metric: it is highly problem dependent, and it does not summarize the model to a single number, which one may want for a more direct comparison.
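A sketch of such a measurement, assuming per-character arrays of predicted words, target words, and the number of letters already observed (all names here are illustrative), could look like:

```python
from collections import defaultdict

def accuracy_by_known_letters(predictions, targets, known_letters):
    """Group predictions by how many letters of the target word were already
    visible, and compute the accuracy within each group."""
    correct, total = defaultdict(int), defaultdict(int)
    for pred, target, known in zip(predictions, targets, known_letters):
        total[known] += 1
        correct[known] += int(pred == target)
    return {known: correct[known] / total[known] for known in sorted(total)}
```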
These results suggest that the GRU model is better at long-term contextual understanding, while the LSTM model is better at short-term contextual understanding. These observations are valuable, as they explain why the overall accuracy of the GRU and LSTM models is almost identical, even though the connectivity visualization shows that the GRU model is far better at long-term contextual understanding.
While more detailed quantitative metrics like this provide new insight, qualitative analyses like the connectivity figure presented in this article still have great value: the connectivity visualization gives an intuitive understanding of how the model works, which a quantitative metric cannot. It also shows that a wrong prediction can still be a useful prediction, such as a synonym or a contextually reasonable word.
Looking at overall accuracy and cross entropy loss by themselves is not that informative. Different models may prioritize either long-term or short-term contextual understanding, yet have similar accuracy and cross entropy loss.
A qualitative analysis, where one looks at how previous input is used in the prediction, is therefore also important when judging models. In this case, the connectivity visualization, together with the autocomplete predictions, reveals that the GRU model is much more capable of long-term contextual understanding than the LSTM and Nested LSTM models. In the case of LSTM, the difference is much larger than one would guess from looking at the overall accuracy and cross entropy loss alone. This observation is not that interesting in itself, as it is likely highly dependent on the hyperparameters and the specific application.
Much more valuable is that this visualization method makes it possible to intuitively understand how the models differ, to a much higher degree than accuracy and cross entropy allow. For this application, it is clear that the GRU model uses repeating words and the semantic meaning of past words to make its predictions, to a much higher degree than the LSTM and Nested LSTM models. This is both a valuable insight when choosing the final model and essential knowledge when developing better models in the future.
Many thanks to the authors of the original Nested LSTM paper, Joel Ruben Antony Moniz and David Krueger.
I am also grateful for the excellent feedback and patience from the Distill team, especially Christopher Olah and Ludwig Schubert, as well as the feedback from the peer-reviewers. Their feedback has dramatically improved the quality of this article.
Review 1 - Abhinav Sharma
Review 2 - Dylan Cashman
Review 3 - Ruth Fong
The Nested LSTM unit attempts to solve the long-term memorization problem from a more practical point of view. Where the standard LSTM unit solves the vanishing gradient problem by adding internal memory, and the GRU attempts to be a faster solution than LSTM by using no internal memory, the Nested LSTM goes in the opposite direction of the GRU by adding additional internal memory to the unit.
The additional memory is integrated into the LSTM unit by changing how the cell value $c_t$ is computed. In a vanilla LSTM, the cell value is updated as $c_t = f_t \odot c_{t-1} + i_t \odot g_t$; in a Nested LSTM, this sum is replaced by an internal memory function, which is itself implemented as an LSTM cell.
The complete set of equations then becomes:
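(The notation below is a reconstruction based on standard LSTM conventions and the Nested LSTM paper, not the article's original figure; $\odot$ denotes element-wise multiplication.)

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
g_t &= \tanh(W_g x_t + U_g h_{t-1} + b_g) \\
c_t &= \mathrm{MEMORY}\left(f_t \odot c_{t-1},\; i_t \odot g_t\right) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

Setting $\mathrm{MEMORY}(a, b) = a + b$ recovers the vanilla LSTM; the Nested LSTM instead implements $\mathrm{MEMORY}$ as another LSTM cell with its own internal cell state.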
Like in vanilla LSTM, the gate activation functions $\sigma$ are sigmoid functions.
The abstraction of how to combine the input with the cell value allows for a lot of flexibility. Using this abstraction, it is not only possible to add one extra internal memory state; the internal memory function can itself be nested, adding as many levels of internal memory as desired.
The equations defining the internal memory function mirror those of a vanilla LSTM cell, with $f_t \odot c_{t-1}$ taking the role of the previous hidden state and $i_t \odot g_t$ taking the role of the input. The gate activation functions of the internal cell are likewise sigmoid functions.
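To make the nesting concrete, here is a minimal NumPy sketch of a single Nested LSTM step under the assumptions above (the weight layout and the use of $\tanh$ are choices made for this sketch, not taken from the article):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_gates(x, h, W, U, b):
    """Input, forget, and output gates plus the candidate for an LSTM-style cell.
    W maps the input and U the hidden state to 4 * units pre-activations."""
    i, f, o, g = np.split(W @ x + U @ h + b, 4)
    return sigmoid(i), sigmoid(f), sigmoid(o), np.tanh(g)

def nested_lstm_step(x, h, c, c_inner, outer, inner):
    """One time step of a Nested LSTM (a sketch, not the reference implementation).

    Instead of the vanilla update c = f * c_prev + i * g, the two terms are fed
    to an inner LSTM cell as its previous hidden state and its input."""
    i, f, o, g = lstm_gates(x, h, *outer)

    inner_h_prev = f * c   # plays the role of the inner cell's previous hidden state
    inner_x = i * g        # plays the role of the inner cell's input
    ii, fi, oi, gi = lstm_gates(inner_x, inner_h_prev, *inner)
    c_inner = fi * c_inner + ii * gi
    c = oi * np.tanh(c_inner)  # the inner hidden state becomes the outer cell value

    h = o * np.tanh(c)
    return h, c, c_inner
```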
The autocomplete dataset is constructed from the full text8 dataset, where each observation consists of at most 200 characters and is ensured not to contain partial words. 90% of the observations are used for training, 5% for validation, and 5% for testing.
The input vocabulary is a-z, space, and a padding symbol. The output vocabulary consists of the $2^{14}$ most frequent words, plus special symbols for unknown words and padding.
The GRU and LSTM models each have 2 layers of 600 units. Similarly, the Nested LSTM model has 1 layer of 600 units, but with 2 internal memory states. Additionally, each model has an input embedding layer and a final dense layer to match the vocabulary size.
| Model | Units | Layers | Depth | Embedding parameters | Recurrent parameters | Dense parameters |
|---|---|---|---|---|---|---|
| GRU | 600 | 2 | N/A | 16,200 | 4,323,600 | 9,847,986 |
| LSTM | 600 | 2 | N/A | 16,200 | 5,764,800 | 9,847,986 |
| Nested LSTM | 600 | 1 | 2 | 16,200 | 5,764,800 | 9,847,986 |
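The parameter counts in the table can be reproduced with a short calculation, assuming a 600-dimensional embedding, one bias vector per gate, and the vocabulary sizes below (these are assumptions on my part that happen to be consistent with the table):

```python
units, emb = 600, 600
in_vocab, out_vocab = 27, 16386  # assumed input/output vocabulary sizes

embedding = in_vocab * emb                               # 16,200
gru_layer = 3 * (emb * units + units * units + units)    # 3 gate/candidate blocks
lstm_layer = 4 * (emb * units + units * units + units)   # 4 gate/candidate blocks
dense = units * out_vocab + out_vocab                    # 9,847,986

print(embedding, 2 * gru_layer, 2 * lstm_layer, dense)
# 16200 4323600 5764800 9847986
```

Note that the Nested LSTM row (1 layer, depth 2) has the same recurrent parameter count as two LSTM layers, since its internal memory function is itself an LSTM cell operating on 600-dimensional vectors.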
There are 456,896 sequences in the training dataset, and a mini-batch size of 64 observations is used. A single iteration over the entire dataset thus corresponds to 7,139 mini-batches. The training runs over the dataset twice, corresponding to 14,278 mini-batches. For training, Adam optimization is used with default parameters.
Evaluating the models on the test dataset yields the following cross entropy losses and accuracies.
| Model | Cross Entropy | Accuracy |
|---|---|---|
| GRU | 2.1170 | 52.01% |
| LSTM | 2.1713 | 51.40% |
| Nested LSTM | 2.4950 | 47.10% |
The implementation is available at https://github.com/distillpub/post--memorization-in-rnns.
If you see mistakes or want to suggest changes, please create an issue on GitHub.
Diagrams and text are licensed under Creative Commons Attribution CC-BY 4.0 with the source available on GitHub, unless noted otherwise. The figures that have been reused from other sources don’t fall under this license and can be recognized by a note in their caption: “Figure from …”.
For attribution in academic contexts, please cite this work as
Madsen, "Visualizing memorization in RNNs", Distill, 2019.
BibTeX citation
@article{madsen2019visualizing,
  author = {Madsen, Andreas},
  title = {Visualizing memorization in RNNs},
  journal = {Distill},
  year = {2019},
  note = {https://distill.pub/2019/memorization-in-rnns},
  doi = {10.23915/distill.00016}
}