I have been working on building a minimum viable end-to-end network for image classification as a starting point for my project. I recently had the following curious experience: After training my network for the very first time, the model showed over 90% accuracy on my validation set. It looked like I was off to a great start! :-D
Alas, as I took a closer look at my training loop, I recalled that the network had only trained on a handful of batches. Worse, when I looked at my predictions to get a feel for where the network performed well, I found that most of them were NaNs! So this network had barely had a chance to learn anything and was outputting NaNs instead of predictions during training. Yet it was somehow reaching "high" validation performance. My conclusion at that point: "Deep learning is truly magical."
Upon reflection, I realized that the high validation-set performance with barely any training might have come from my imbalanced classes: 7.6% of my examples were targets, while 92.4% were non-targets. A network that predicted "non-target" for every single example would therefore already reach 92.4% accuracy, comfortably above 90%.
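A quick sanity check would have caught this. The sketch below, in Python with NumPy, computes the accuracy of a trivial classifier that always predicts the majority class. The label array is a hypothetical stand-in with the same 7.6%/92.4% split. Balanced accuracy is the mean of per-class recall, and scikit-learn also provides it as `balanced_accuracy_score`. It was not used in the original experiment, but it is one metric that does not reward this trivial strategy.

```python
import numpy as np

def majority_baseline_accuracy(labels):
    """Accuracy of a 'classifier' that always predicts the most common class."""
    _, counts = np.unique(labels, return_counts=True)
    return counts.max() / counts.sum()

def balanced_accuracy(labels, preds):
    """Mean of per-class recall; a constant prediction scores 1 / n_classes."""
    recalls = [np.mean(preds[labels == c] == c) for c in np.unique(labels)]
    return float(np.mean(recalls))

# Hypothetical labels: 7.6% targets (1), 92.4% non-targets (0).
labels = np.array([1] * 76 + [0] * 924)
always_non_target = np.zeros_like(labels)

print(majority_baseline_accuracy(labels))            # 0.924
print(balanced_accuracy(labels, always_non_target))  # 0.5
```

Accuracy rewards the "always non-target" strategy with 92.4%. Balanced accuracy scores the same strategy at 0.5, i.e. chance level for two classes.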
After some investigation, and after consulting my fellow scholar Alethea and my mentor Johannes, it turned out that the NaNs in my training-set output appeared because my activations were exploding, an effect of "exploding gradients" (see this blog post for an introduction to the phenomenon). One way to get exploding gradients is to fail to normalize your input. If some values in your input are very large, the gradients and weight updates scale up with them, and the learned weights in your network can correspondingly become very large. At some point the values overflow to infinity or NaN, and the network can no longer learn.
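To make the mechanism concrete, here is a toy sketch in Python with NumPy. It is not the network from this post, just gradient descent on a one-variable linear regression, where the gradient for the weight is proportional to the input. With inputs on the order of 10,000 and an assumed learning rate of 0.1, the updates blow up to inf/NaN. Z-scoring the same input lets the identical loop converge. The data, the learning rate, and the `fit_line` helper are all hypothetical.

```python
import numpy as np

def fit_line(x, y, lr=0.1, steps=200):
    """Fit y ~ w * x + b with plain gradient descent on mean squared error."""
    w, b = 0.0, 0.0
    # Let overflow to inf/NaN happen silently instead of printing warnings.
    with np.errstate(over="ignore", invalid="ignore"):
        for _ in range(steps):
            err = (w * x + b) - y
            grad_w = 2.0 * np.mean(err * x)  # proportional to the scale of x
            grad_b = 2.0 * np.mean(err)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b

rng = np.random.default_rng(0)
x_raw = rng.uniform(0.0, 10_000.0, size=1_000)  # hypothetical unnormalized input
y = 3.0 * x_raw + 5.0

w_raw, _ = fit_line(x_raw, y)
x_norm = (x_raw - x_raw.mean()) / x_raw.std()  # z-scored input
w_norm, _ = fit_line(x_norm, y)

print("raw input:       ", w_raw)   # non-finite (inf or nan): the updates blew up
print("normalized input:", w_norm)  # finite, close to 3 * x_raw.std()
```

In a real network the same effect shows up layer by layer. The principle is identical, though: large inputs produce large gradients, and a learning rate that is fine for unit-scale data becomes far too large.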
It turns out that normalizing your input is quite important… …Important, but a one-line fix, right? Surely you can just call transforms.Normalize() and call it a day? Well, kind of. I did try that, and it did indeed fix the NaN problem: my gradients were no longer exploding, and my network was learning and predicting tractable numbers. But blindly normalizing your input can introduce subtler problems, which impair your network's performance without giving you any obvious indication of why. Specifically, if your dataset includes extreme outliers, they dominate the mean and standard deviation used for normalization. The important variability in the rest of the data is comparatively much smaller, so it gets compressed into a tiny range, making it harder for the network to learn from that information. At worst, if the outliers are many orders of magnitude off, that important variability in your "actual data" can be lost entirely to limited floating-point precision. Hence, it's best to first check for, and remove, outliers before you normalize your data.
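To see the compression effect in numbers, here is a small NumPy sketch with synthetic data. The normally distributed "signal", the ten artifact values at 1e6, and the cutoff of 100 are all hypothetical. Z-scoring with the outliers present shrinks the spread of the real signal to well below 0.001. Removing the outliers first preserves a spread of 1.

```python
import numpy as np

def standardize(values):
    """Z-score: subtract the mean, then divide by the standard deviation."""
    return (values - values.mean()) / values.std()

rng = np.random.default_rng(0)
signal = rng.normal(0.0, 1.0, size=10_000)  # the variability we care about
artifacts = np.full(10, 1e6)                # hypothetical outliers, ~6 orders of magnitude off
raw = np.concatenate([signal, artifacts])

# Normalizing with the outliers present: the std is dominated by the artifacts,
# so the real signal ends up squeezed into a tiny range.
spread_with_outliers = standardize(raw)[: signal.size].std()

# Dropping outliers first (here with a hypothetical cutoff of |x| < 100),
# then normalizing: the signal keeps a spread of about 1.
cleaned = raw[np.abs(raw) < 100]
spread_without_outliers = standardize(cleaned).std()

print(spread_with_outliers, spread_without_outliers)
```
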
How do you check for outliers? This is where traditional research and “big data” or deep learning approaches diverge. In my PhD research in neuroscience, I checked for outliers by looking at histograms of the data for every single recording electrode, assessing abnormalities, and selecting a cutoff. With datasets sufficiently large for deep learning, this kind of manual checking is no longer feasible.
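For comparison, the bookkeeping part of that manual per-electrode procedure might look like the sketch below. It assumes the recording is a NumPy array of shape (electrodes, samples). Several details are hypothetical: the synthetic data, the 60-Hz sine added to one electrode, the cutoff of 10, and the 1% flagging threshold. In the manual workflow, the cutoff was selected after looking at the histograms, which is exactly the step that does not scale.

```python
import numpy as np

def electrode_report(data, cutoff, bins=50):
    """Per electrode (row of `data`): a histogram for visual inspection and
    the fraction of samples whose absolute value exceeds `cutoff`."""
    report = []
    for channel in data:
        counts, edges = np.histogram(channel, bins=bins)
        frac_over = float(np.mean(np.abs(channel) > cutoff))
        report.append({"counts": counts, "edges": edges, "frac_over": frac_over})
    return report

rng = np.random.default_rng(0)
fs = 1000                                        # hypothetical sampling rate (Hz)
t = np.arange(10 * fs) / fs                      # 10 s of samples
data = rng.normal(0.0, 1.0, size=(4, t.size))    # 4 hypothetical electrodes
data[2] += 500.0 * np.sin(2 * np.pi * 60 * t)    # electrode 2: strong 60-Hz line noise

report = electrode_report(data, cutoff=10.0)     # hypothetical cutoff
flagged = [i for i, r in enumerate(report) if r["frac_over"] > 0.01]
print(flagged)  # [2]
```

The `counts` and `edges` arrays can be handed to any plotting library for the visual check. The loop over every electrode is what becomes infeasible at deep-learning scale.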
So you sample. I ended up using StratifiedShuffleSplit from scikit-learn to draw a sample of 1 in 1,000 training examples, stratified so that it covered all subjects and both classes. In other words, I treated the sampling as a train-test split procedure with a very, very small test set. In this sample, the offending outliers all turned out to come from two specific data files from a single subject. Upon investigation, I found that they were driven by very powerful 60-Hz line noise in those electrodes. The probable cause is that those specific contacts had detached from the brain during the recording session. It was interesting to observe that this kind of "data science" or preprocessing problem is still relevant in deep learning.
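A minimal sketch of that sampling step, assuming the per-example metadata lives in NumPy arrays. The `subject` and `label` arrays below are hypothetical synthetic data: 100,000 examples, 10 subjects, roughly 7.6% targets. Stratifying on a combined subject-and-class code is an assumed way, not necessarily the original one, to make every subject and both classes appear in the sample. `StratifiedShuffleSplit.split` only needs the labels, so an array of zeros stands in for the features.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

rng = np.random.default_rng(0)
n_examples = 100_000
subject = rng.integers(0, 10, size=n_examples)          # hypothetical: 10 subjects
label = (rng.random(n_examples) < 0.076).astype(int)    # ~7.6% targets

# One stratum per (subject, class) pair, so every combination is represented.
strata = subject * 2 + label

# A single split whose tiny "test set" is our 1-in-1,000 sample.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.001, random_state=0)
_, sample_idx = next(splitter.split(np.zeros(n_examples), strata))

print(len(sample_idx), np.bincount(label[sample_idx]))
```

The returned `sample_idx` would then be used to load just those examples and inspect their histograms. This is the step where artifacts like the line noise described above can be spotted without looking at the full dataset.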
So this was a frustrating (yet partially funny) experience. The good news is that this type of struggle is at the core of what I need to be learning in the OpenAI Scholars Program: I’m told that knowing how neural networks fail and how to troubleshoot them is one of the most important skills you need in order to be a deep learning practitioner. It helps to remind myself of that when I feel like I’m stuck in an endless recursive loop instead of making “progress”. It’s not just hard and dreary and recursive: it’s building an essential muscle.
To end this story on a meta note: when your work feels like stomping in place because you keep having to go "back", it can help to visualize your path to a completed project starting from the end. Picture your path from project start to finish as traveling along a piece of yarn that is rolled up into a ball. You start from the outside and unroll it on your way in. In the end, you have a straight line of unrolled yarn, along which you can track your progress from start to finish. The times in the middle of your project when you felt like you were moving backwards, you were really just moving left and right on the ball of yarn. But you were still moving forward along your disentangled straight line of yarn. You just couldn't see that until you were done disentangling it. (I don't know if that helps, but here's a ball of yarn: [image: a ball of yarn])