Fighting Fire with Fire: Avoiding DNN Shortcuts through Priming

Accepted at ICML 2022

Abstract

Across applications spanning supervised classification and sequential control, deep learning has been reported to find "shortcut" solutions that fail catastrophically under minor changes in the data distribution. In this paper, we show empirically that DNNs can be coaxed to avoid poor shortcuts by providing an additional "priming" feature computed from key input features, usually a coarse output estimate. Priming relies on approximate domain knowledge of these task-relevant key input features, which is often easy to obtain in practical settings. For example, one might prioritize recent frames over past frames in a video input for visual imitation learning, or salient foreground over background pixels for image classification. On NICO image classification, MuJoCo continuous control, and CARLA autonomous driving, our priming strategy works significantly better than several popular state-of-the-art approaches for feature selection and data augmentation. We connect these empirical findings to recent theoretical results on DNN optimization, and argue theoretically that priming distracts the optimizer away from poor shortcuts by creating better, simpler shortcuts.

Shortcut Learning in DNNs

Image Classification

Shortcut: Background Contexts

Behavioral Cloning

Shortcut: Previous Action (Copycat Issue)

The shortcut issue arises because DNNs prefer simpler solutions over the intended ones, which harms the models' generalization performance!
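As a toy illustration of why such shortcuts are attractive to an optimizer, consider the copycat issue in behavioral cloning. The smooth action sequence below is invented for illustration (it is not from the paper): because expert actions are temporally correlated, simply repeating the previous action already achieves a much lower training error than any input-independent baseline, so a learner can minimize training loss without ever attending to the observations.

```python
import math

# Hypothetical expert trajectory: smooth, temporally correlated actions.
actions = [math.sin(0.05 * t) for t in range(1000)]

def mse(preds, targets):
    return sum((p - y) ** 2 for p, y in zip(preds, targets)) / len(preds)

# "Copycat" shortcut: predict the previous expert action.
copycat_err = mse(actions[:-1], actions[1:])

# Input-independent baseline: always predict the mean action.
mean_a = sum(actions) / len(actions)
baseline_err = mse([mean_a] * (len(actions) - 1), actions[1:])

# The shortcut looks far "better" on training data, despite ignoring the
# observations entirely -- which is exactly why it generalizes poorly.
print(copycat_err < baseline_err / 10)
```

At test time the policy's own (imperfect) past actions replace the expert's, so this apparent advantage collapses; that gap between training and deployment is the copycat failure mode.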

Methodology: PrimeNet

Motivation:

Besides “labeling” the scene as containing a “zebra”, the parent also points to the animal and its stripes.

Humans often manage to generalize far beyond the training data by exploiting such additional knowledge!

PrimeNet

The key input k(x) is extracted from the raw input according to additional domain knowledge and fed to a prime module, which produces a coarse estimate of the final output, called the prime variable. The raw input and the prime variable are then fed into the main module to make the shortcut-free prediction.

The prime variable serves as a good shortcut that primes the main module away from the bad ones!
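The data flow above can be sketched as follows. This is a minimal illustration, not the authors' implementation: the key-input extractor, prime module, and main module below are hypothetical stand-ins for learned networks. For a copycat-prone video input, domain knowledge says the most recent frame is the key input, so the prime module maps it to a coarse action estimate, and the main module consumes both the full frame stack and the prime variable:

```python
def key_input(x):
    """Domain knowledge (hypothetical here): for video inputs, the most
    recent frame carries the task-relevant information."""
    return x[-1]

def prime_module(k):
    """Stand-in for a small learned network: maps the key input to a
    coarse output estimate -- the prime variable."""
    return sum(k) / len(k)  # placeholder computation

def main_module(x, prime):
    """Stand-in for the main network: sees the raw input AND the prime
    variable, and refines the coarse estimate into the final prediction."""
    correction = (sum(x[-1]) - sum(x[0])) / len(x[-1])  # placeholder
    return prime + 0.1 * correction

def primenet(x):
    prime = prime_module(key_input(x))  # coarse estimate from key input
    return main_module(x, prime)        # final prediction

# Usage: a stack of 4 "frames", each a flat list of pixel values.
frames = [[0.0, 0.1], [0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]
prediction = primenet(frames)
```

The design point is that the prime variable gives the optimizer an easy, benign path to low training loss that already routes through the task-relevant key input, so the main module only needs to learn a residual refinement rather than rediscovering (or shortcutting around) the full mapping.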

Experiment Results

Testing beds

Image Classification

NICO

Behavioral Cloning

CARLA

Behavioral Cloning

MuJoCo

Evaluation Results

Image Classification

Table 1. Image classification accuracies on the NICO dataset.

Behavioral Cloning

Table 2. Online testing results in the CARLA and MuJoCo environments.

If you think our work is helpful, please consider citing our paper in your publications:

@inproceedings{wen2022priming,
  author    = {Chuan Wen and Jianing Qian and Jierui Lin and Jiaye Teng and Dinesh Jayaraman and Yang Gao},
  booktitle = {International Conference on Machine Learning},
  title     = {Fighting Fire with Fire: Avoiding DNN Shortcuts through Priming},
  year      = {2022}
}