Luke Roberto, David Slayback, David Klee
Our final project for CS7250 was an online tool that allows people to visualize the features learned within a CNN! Below you can watch a video about the tool. The tool can be found at this link.
Deep Neural Networks (DNNs) have recently been applied with great success to a number of fields, including computer vision, AI, text processing, and translation. The main advantage of DNNs over previous machine learning (ML) techniques is that they abstract away one of the hardest parts of developing an application: feature selection. To some extent, the architecture of any neural network must be tailored to its specific task, but this is a far cry from picking specific, relevant features from among hundreds if not millions of possible choices.
While this is a great advantage in terms of rapidly developing usable applications, it also leads to DNNs being difficult to interpret. Instead of being able to point to “housing price per square foot” or “concentration of edges in the upper corner of a picture” as features, we are often left saying “it works” without understanding why. What is a DNN learning? How? What does each layer represent? Each neuron? Each function? If we can’t answer these questions, we can’t know why DNNs make the mistakes they do, or when, or how we might fix them.
The purpose of this project is to interactively visualize convolutional neural networks (CNNs), a subclass of DNNs, in an intuitive format that allows for deep learning practitioners to understand their models beyond a black box level and develop an intuition of the method to their madness. We will prioritize visualizing pretrained models, but ideally would like to visualize the training process as well. With this visualization (and potential extensions), we hope to increase understanding of and trust in these heretofore arcane techniques so as to better serve human-compatible ML and AI.
Since we are visualizing a neural network, our definition of the “dataset” will refer to the combination of the network itself and the data that it is trained on. The initial dataset we decided to work with was the CIFAR-10 classification dataset, linked here, along with a pretrained CNN to visualize. We chose this dataset because of the small number of classes the network needs to classify, and because a network does not have to be very large to perform well on it. The dataset is easy to download and relatively small on disk, and the instructions on the website show how to correctly load it in Python. An example of the classes and corresponding pictures is shown below.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. This means the dataset contains categorical data (the 10 classes), quantitative data (activations, kernels, and feature maps in a network), and ordinal data (the layer index in a network).
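For reference, the Python version of the dataset is distributed as pickled batch files; the sketch below follows the loading instructions from the CIFAR-10 page (the path depends on where the archive is extracted):

```python
import pickle

def unpickle(path):
    # Each batch file is a pickled dict with byte-string keys:
    # b'data' is a 10000 x 3072 uint8 array, b'labels' is a list of 10000 ints.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='bytes')

batch = unpickle('cifar-10-batches-py/data_batch_1')  # path to one extracted training batch
images = batch[b'data'].reshape(-1, 3, 32, 32)        # each row holds the R, G, B planes of a 32x32 image
labels = batch[b'labels']
```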
The person we interviewed was a fellow researcher in our lab. He specifically works in deep reinforcement learning, with an emphasis on problems relating to abstraction, sparse rewards, and long time horizons. The context of his research is robotic grasping and manipulation, so vision is a major component. He uses CNNs to help a robot learn visuomotor policies from dense visual inputs.
The interview went rather well; there was a good mix of expected conversation and novel points brought up by the interviewee. Most of our questions were targeted at how CNNs get used by an average researcher, since we wanted insight into what methods are currently being used to visualize the data. We expected him to say that the features and filters learned by a CNN tend to be pretty complex and hard to interpret (especially deeper into the network). One topic that came up that was a little less expected was the use of this project for debugging. When asked how he debugs his models, the methods he listed were pretty low level, with heavy use of matplotlib and inspection of only the surface layers. He even noted that a lot of the time he would simply make networks deeper, hoping that they would learn around whatever issues there were. This gave us a much better perspective on how a user would use a visualization tool for these types of networks.
We had a multi-step design process that started with several group sketches. Each team member highlighted several traits that should be visualized, including hierarchy, feature maps, kernels, and the structure of the data. Once we had all of these sketches (first picture), we discussed which traits of each sketch we found most salient to the task analysis.
We converged on an initial mockup that we pitched to the class, shown below (second and third pictures). This example served as an anchor that we would use for the rest of the process, and most of the structure from this initial design has held true.
For the actual programming, we split the work into backend and frontend, with one developer working on a Python backend to interface with the deep learning models/dataset and return parsed data, and two developers working on different parts of the frontend to display that data.
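As a rough sketch of how such a backend might serve parsed data to the frontend, the snippet below assumes Flask and PyTorch; the route, the toy sequential model, and the placeholder image are purely illustrative stand-ins rather than the project's actual API:

```python
from flask import Flask, jsonify
import torch
import torch.nn as nn

app = Flask(__name__)

# Hypothetical stand-in for the pretrained CIFAR-10 CNN the real backend loads.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
    nn.Flatten(), nn.Linear(32 * 32 * 32, 10),
)
model.eval()

@app.route('/activations/<int:layer_idx>/<int:image_idx>')
def activations(layer_idx, image_idx):
    # Run one image through the network and return the requested layer's
    # feature maps as nested lists the frontend can render.
    x = torch.rand(1, 3, 32, 32)  # placeholder; the real backend would look up CIFAR-10 image image_idx
    with torch.no_grad():
        for i, module in enumerate(model):
            x = module(x)
            if i == layer_idx:
                return jsonify(x.squeeze(0).tolist())
    return jsonify([])
```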
Once we started writing actual code, we found that we had picked simple enough layouts that putting it all together was relatively easy. We started to realize that the original idea of the class-based statistics on the left of the main screen was not very helpful, so we decided to replace it with a chord diagram that showed the misclassifications of the network. Along with this we added a bar chart to show the breakdown of classifications by class. We also realized that the networks we were dealing with were quite large, so we added a scroll bar in the top panel so that users could view smaller subsets of layers at a time.
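Both charts are driven by a simple aggregation of the network's predictions over the test set, along the lines of the sketch below (the function name and argument shapes are ours, not the project's actual code):

```python
import numpy as np

def confusion_counts(true_labels, pred_labels, n_classes=10):
    # counts[i, j] = number of class-i images the network labeled as class j.
    # Off-diagonal entries feed the misclassification chord diagram, while
    # summing over a row or column gives the per-class bar chart breakdown.
    counts = np.zeros((n_classes, n_classes), dtype=int)
    for t, p in zip(true_labels, pred_labels):
        counts[t, p] += 1
    return counts
```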
In the future, we'd like to improve the home screen UI, as well as offer a larger variety of datasets and networks. Ideally, we'd like to allow users to input their own pretrained or partially trained models in place of the current dropdown of presets, but that may be difficult because models are often defined specifically for particular formats of particular datasets.
Additionally, we'd like to visualize the training process as well, but this would be more computationally demanding and would add another level of complexity to the visualization.
Overall, we're quite proud of the final tool and feel it offers useful insights for CNN practitioners in terms of understanding the errors that their networks make and how those errors propagate through the network structure. We intend to open source the project to allow others to use and/or contribute to the work.