Saliency Maps for Deep Learning: Vanilla Gradient
Saliency maps have been getting a lot of attention lately. They are a popular visualization tool for gaining insight into why a deep learning model made an individual decision, such as classifying an image. Major papers such as Dueling DQN and adversarial examples for CNNs use saliency maps in order to convey where their models are focusing their attention.
Saliency maps¹ are usually rendered as a heatmap (see: map), where hotness corresponds to regions that have a big impact on the model’s final decision (see: saliency). They are helpful, for example, when you are frustrated by your model incorrectly classifying a certain datapoint, because you can look at the input features that led to that decision.
The model correctly labels these images as Church, Tractor, and Manta Ray, respectively. As a deep learning practitioner, it’s reassuring to see the white-hot, high-saliency pixels centered on the red tractor, rather than on the 1950s beater car, as the reason the model chose the Tractor class for this image. The desire for this confirmation has only grown in the wake of adversarial attacks, which make us question whether our models are learning real features or merely predictive noise in our dataset.
Today we’re going to take a look at what is known as Vanilla Gradient. It’s the original saliency map algorithm for supervised deep learning, from Simonyan et al. (2013). Despite showing its age in the form of noisy maps, Vanilla Gradient has passed quantitative sanity checks that some newer techniques have failed². Furthermore, it’s the simplest of the gradient-based³ approaches and very fast to run, which makes it a great place to start understanding how saliency maps work.
Modern deep learning architectures are slow to run on a CPU, so we’ll use a relatively simple Convolutional Neural Network (LeNet-5) trained on the 60k-image MNIST dataset instead of the 1.4M-image ImageNet dataset (ILSVRC).
I’m assuming some familiarity with deep learning, including how CNNs and backpropagation work. Vanilla Gradient is remarkably straightforward from there. If you find yourself getting a little confused, the above links are the fastest refreshers I’ve found.
The Setup
Below is our main.py file, which we run to generate our saliency maps. Notice our save_vanilla_gradient function at the bottom only requires three arguments: our trained network, data, and labels.
Check my GitHub repo for the data processing and network training code. They’re standard for MNIST, so we’ll skip them for now to get to the good stuff.
Implementing Vanilla Gradient
In the spirit of Karpathy’s “Yes, you should understand backprop”, we’ll start from the basics in our implementation. Like backprop, saliency maps are a leaky abstraction. When dealing with leaky abstractions, it’s important to understand the mechanisms you’re working with because errors are not straightforward to trace to a line of code or module.
The way you consistently find this species of sneaky bug is by developing a sense of the range each computational step’s values should fall in, and noticing discrepancies. So, no TensorFlow, no PyTorch. These high-level frameworks hide many steps inside GPU operations, which is great for performance but limits the familiarity-building process. Instead, we’ll build Vanilla Gradient from scratch in NumPy, with every step observable. This way you’ll know what to look for later, when you need to dive deep into model code written in a high-level framework.
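As a small illustration of that habit, here is a sketch of a range-checking helper. The describe function, the single-layer softmax model, and the random weights are all hypothetical stand-ins, not code from the repo; the point is simply to print each intermediate’s range and assert the ranges you expect.

```python
import numpy as np

def describe(name, x):
    """Print an intermediate array's shape and value range; return the stats for assertions."""
    stats = {"shape": x.shape, "min": float(x.min()),
             "max": float(x.max()), "mean": float(x.mean())}
    print(f"{name:>8}: shape={x.shape} min={stats['min']:.4f} "
          f"max={stats['max']:.4f} mean={stats['mean']:.4f}")
    return stats

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))  # shift by the row max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
x = rng.random((1, 784))                    # hypothetical flattened 28x28 image, pixels in [0, 1)
W = rng.normal(0.0, 0.01, size=(784, 10))   # hypothetical untrained weights
logits = x @ W
probs = softmax(logits)

in_stats = describe("input", x)
describe("logits", logits)
p_stats = describe("probs", probs)

# Expected ranges: a failure here points at a bug in an earlier step.
assert 0.0 <= in_stats["min"] and in_stats["max"] <= 1.0, "pixels should be scaled to [0, 1]"
assert np.allclose(probs.sum(axis=-1), 1.0), "each softmax row should sum to 1"
assert 0.0 <= p_stats["min"] and p_stats["max"] <= 1.0, "probabilities must lie in [0, 1]"
```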
Vanilla Gradient algorithm:
1) forward pass with data
2) backward pass to input layer to get the gradient
3) render the gradient as a normalized heatmap
Let’s jump into the code. Below are the first two steps, the core logic of Vanilla Gradient.
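As a minimal, self-contained sketch of those two steps (not the post’s LeNet-5 implementation), here is a hypothetical one-hidden-layer ReLU network (784-32-10, random weights) in plain NumPy. The names forward and vanilla_gradient are illustrative, and the quantity being differentiated is the target class’s softmax probability.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())          # shift for numerical stability
    return e / e.sum()

def forward(x, params):
    """Step 1: forward pass through a hypothetical 784-H-10 ReLU net, keeping intermediates."""
    W1, b1, W2, b2 = params
    a1 = x @ W1 + b1                 # hidden pre-activation, shape (H,)
    h = np.maximum(a1, 0.0)          # ReLU
    z = h @ W2 + b2                  # class logits, shape (10,)
    p = softmax(z)                   # class probabilities
    return p, (a1, h)

def vanilla_gradient(x, params, target_class):
    """Step 2: backprop d p[target_class] / d x all the way to the input pixels."""
    W1, b1, W2, b2 = params
    p, (a1, h) = forward(x, params)
    onehot = np.zeros_like(p)
    onehot[target_class] = 1.0
    dz = p[target_class] * (onehot - p)   # softmax Jacobian row: dp_c/dz_j = p_c (1[c==j] - p_j)
    dh = dz @ W2.T                        # back through the output layer
    da1 = dh * (a1 > 0)                   # back through the ReLU
    dx = da1 @ W1.T                       # the extra step: back into the input layer
    return dx                             # one derivative per pixel, same shape as x

rng = np.random.default_rng(0)
H = 32
params = (rng.normal(0, 0.05, (784, H)), np.zeros(H),
          rng.normal(0, 0.05, (H, 10)), np.zeros(10))
x = rng.random(784)                       # stand-in for a flattened 28x28 MNIST digit
saliency = vanilla_gradient(x, params, target_class=7).reshape(28, 28)
print(saliency.shape)  # (28, 28)
```

The gradient has the same shape as the input, so reshaping it back to 28x28 is all it takes to treat it as an image for step 3.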
During training, backpropagation normally stops once it reaches the first layer’s weights, because the input isn’t a trainable parameter and computing its gradient would be wasted work. Crucially, however, Vanilla Gradient continues backpropagating into the input layer to see which pixels would affect our output the most. That is, to see which pixels are most salient. Simple and elegant, no?
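To make the difference concrete, here is a sketch of the backward pass for a single dense layer y = x @ W + b, using a hypothetical helper named dense_backward (not from the repo). Training consumes only dW and db; Vanilla Gradient additionally needs dx, the one product that training skips.

```python
import numpy as np

def dense_backward(x, W, upstream):
    """Backward pass of y = x @ W + b for one example.

    upstream is dL/dy with shape (out,). Returns (dW, db, dx).
    Training only needs dW and db; Vanilla Gradient also needs dx.
    """
    dW = np.outer(x, upstream)   # dL/dW, shape (in, out): used to update the weights
    db = upstream.copy()         # dL/db, shape (out,)
    dx = W @ upstream            # dL/dx, shape (in,): skipped in training, kept for saliency
    return dW, db, dx

x = np.array([1.0, 2.0])
W = np.array([[1.0, -1.0],
              [0.5,  2.0]])
upstream = np.array([1.0, 0.0])  # L = y[0], so dL/dy = [1, 0]
dW, db, dx = dense_backward(x, W, upstream)
# dx == [1.0, 0.5] because y[0] = 1.0*x[0] + 0.5*x[1] + b[0]
```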
The backpropagation step gives us good saliency clues because it calculates the gradient of the given output class with respect to the input image. The gradient is just an array of partial derivatives, one for each pixel. A derivative here is essentially saying:
“For every unit you change this pixel by, I change the output probability of the class by about this much”
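You can read a derivative literally by nudging a pixel and comparing. The sketch below uses a hypothetical single-output model, prob = sigmoid(w · x + b) with random weights, whose input gradient is p * (1 - p) * w; for a small nudge delta, the probability should change by roughly grad[i] * delta.

```python
import numpy as np

def prob(x, w, b):
    """Hypothetical one-class model: probability = sigmoid(w . x + b)."""
    return 1.0 / (1.0 + np.exp(-(x @ w + b)))

def input_gradient(x, w, b):
    p = prob(x, w, b)
    return p * (1.0 - p) * w          # d prob / d x, one entry per pixel

rng = np.random.default_rng(1)
x, w, b = rng.random(784), rng.normal(0, 0.05, 784), 0.0
grad = input_gradient(x, w, b)

# Read one derivative literally: nudge pixel i by delta and watch the probability move.
i, delta = int(np.argmax(np.abs(grad))), 1e-3
x_nudged = x.copy()
x_nudged[i] += delta
predicted_change = grad[i] * delta
actual_change = prob(x_nudged, w, b) - prob(x, w, b)
print(f"pixel {i}: predicted {predicted_change:.3e}, actual {actual_change:.3e}")
```

The agreement only holds for small nudges; the derivative is a local, linear description of the model around this particular image.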
Those pixel-wise derivative values are what we use to generate our heatmap. Large positive values are near-red, large negative values are near-blue, and pixels with a derivative near zero (those whose changes barely affect the output class) are near-white.
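One way to implement that mapping by hand is sketched below; diverging_rgb is a hypothetical helper, not the repo’s renderer. It scales by the largest absolute derivative so that zero always lands on white.

```python
import numpy as np

def diverging_rgb(grad):
    """Map a 2D gradient to RGB in [0, 1]: positive -> red, negative -> blue, ~0 -> white.

    Normalizes by the largest absolute value so the scale is symmetric around zero.
    """
    m = np.abs(grad).max()
    g = grad / m if m > 0 else np.zeros_like(grad)   # now in [-1, 1]
    pos = np.clip(g, 0.0, 1.0)
    neg = np.clip(-g, 0.0, 1.0)
    rgb = np.ones(grad.shape + (3,))
    rgb[..., 0] -= neg          # negative values lose red  -> bluer
    rgb[..., 1] -= pos + neg    # any nonzero value loses green
    rgb[..., 2] -= pos          # positive values lose blue -> redder
    return rgb
```

If you use matplotlib instead, plt.imshow(grad, cmap="bwr", vmin=-m, vmax=m) with m = np.abs(grad).max() gives the same symmetric red-white-blue scaling.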
You could even calculate an image’s saliency map with respect to a class other than its label. For example, here we see which parts of this image of a 7 contribute to the 1 class.
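Computing a map for another class only changes which output you differentiate. Here is a sketch for a hypothetical linear softmax classifier, z = x @ W + b, where the derivative of p[k] has the closed form p[k] * (W[:, k] - W @ p); class_saliency and the random weights are illustrative only.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def class_saliency(x, W, b, target_class):
    """Gradient of p[target_class] w.r.t. x for a hypothetical linear softmax model z = x @ W + b.

    target_class can be any class, not just the image's label.
    """
    p = softmax(x @ W + b)
    # dp_k/dx = p_k * (W[:, k] - sum_j p_j W[:, j])
    return p[target_class] * (W[:, target_class] - W @ p)

rng = np.random.default_rng(2)
W, b = rng.normal(0, 0.05, (784, 10)), np.zeros(10)
x = rng.random(784)                                   # stand-in for an image labeled 7
sal_7 = class_saliency(x, W, b, 7).reshape(28, 28)    # what supports "7"
sal_1 = class_saliency(x, W, b, 1).reshape(28, 28)    # what supports "1" in the same image
```

Because the ten probabilities always sum to 1, the ten class gradients sum to zero: any pixel change that pushes one class up must pull others down. That makes a handy range-style sanity check.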
The last step, rendering, isn’t conceptually important, but I encourage you to take a peek at my GitHub repo if you’re curious about how it works.
On colormaps. One interesting rendering question is which colormap to use. The original Vanilla Gradient paper used a white-spectrum colormap, while we use a red-white-blue colormap. The advantage of a diverging colormap such as red-white-blue is that it better captures the difference between positive and negative values. This is useful for MNIST’s white digits on a black background, where a positive derivative means that brightening the pixel raises the class probability (and a negative one means it lowers it). In ImageNet, however, the meaning of the gradient’s sign turns out to be context-dependent, so researchers have found the absolute value of the gradient, rendered with a sequential colormap like white-spectrum, to be clearest.
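For the ImageNet-style alternative, here is a sketch of turning a gradient into a sequential-colormap saliency map: take the absolute value and, for color images, keep the largest value across channels (as Simonyan et al. do), then scale to [0, 1]. abs_saliency is a hypothetical helper name.

```python
import numpy as np

def abs_saliency(grad):
    """Sequential-colormap saliency: |gradient|, max over color channels, scaled to [0, 1].

    grad has shape (H, W) for grayscale or (H, W, C) for color images.
    """
    s = np.abs(grad)
    if s.ndim == 3:
        s = s.max(axis=-1)          # keep the strongest channel per pixel
    m = s.max()
    return s / m if m > 0 else s    # 0 = no influence, 1 = most salient pixel
```

Rendering the result with a grayscale colormap such as matplotlib’s cmap="gray" shows the most salient pixels in white. The trade-off is that the sign information the red-white-blue map preserves is discarded.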
Conclusion
Vanilla Gradient is a saliency map algorithm that helps you better understand your deep learning model’s decision-making. By deep diving today, you’re better equipped to debug and interpret your classifiers.
I encourage you to play around with the code, especially in a debugger. Setup takes less than 5 minutes. There’s no better way to strengthen your learning from here than to get your hands dirty and generate some saliency maps!
I appreciate your comments, pull requests, and GitHub issues, as they will help me create an even better treatment for future essays.
Thanks to Ben Mann, Yohann Abittan, Kamil Michnicki, Lukas Ferrer, Michael Staub, Sam Turner, Spencer Creasey, Steve Phillips, and Taylor Kulp-Mcdowall for reading drafts of this.
Footnotes
[1] Saliency maps are also known in the literature as pixel-attribution maps, attribution maps, and sensitivity maps, but in recent years it looks like saliency is winning the jargon war. Long live the king.
[2] About ten algorithms for generating saliency maps for deep nets have been published over the past five years, demoing ‘sharper’ maps that, excitingly, put even more dots on the tractor. However, some of these algorithms have failed a new batch of randomization tests, which showed that they produce basically the same visualizations on fresh, randomly initialized networks as on fully trained networks. Oops.
[3] The other category of saliency methods is known as perturbation-based. Essentially, these methods mess with the input in various ways (blurring, blacking out pixels, etc.) and measure how much that affects the classification output.
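For contrast with Vanilla Gradient, here is a sketch of one perturbation-based method, occlusion: black out one patch at a time and record how much the target-class probability drops. occlusion_map and its prob_fn argument (any function from an image to class probabilities) are hypothetical names. Note that it needs one forward pass per patch, versus a single forward and backward pass for Vanilla Gradient.

```python
import numpy as np

def occlusion_map(image, prob_fn, target_class, patch=4, fill=0.0):
    """Perturbation-based saliency by occlusion.

    Blacks out one patch x patch square at a time and records how much
    prob_fn(image)[target_class] drops. prob_fn maps an (H, W) image to class probabilities.
    """
    base = prob_fn(image)[target_class]
    H, W = image.shape
    heat = np.zeros_like(image, dtype=float)
    for r in range(0, H, patch):
        for c in range(0, W, patch):
            occluded = image.copy()
            occluded[r:r + patch, c:c + patch] = fill
            heat[r:r + patch, c:c + patch] = base - prob_fn(occluded)[target_class]
    return heat   # large positive = hiding this patch hurt the prediction
```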