Adaptive Risk Minimization: Learning to Adapt to Domain Shift

Marvin Zhang*, Henrik Marklund*, Nikita Dhawan*, Abhishek Gupta, Sergey Levine, Chelsea Finn


Abstract

A fundamental assumption of most machine learning algorithms is that the training and test data are drawn from the same underlying distribution. However, this assumption is violated in almost all practical applications: machine learning systems are regularly tested under distribution shift, due to changing temporal correlations, atypical end users, or other factors. In this work, we consider the problem setting of domain generalization, where the training data are structured into domains and there may be multiple test time shifts, corresponding to new domains or domain distributions. Most prior methods aim to learn a single robust model or invariant feature space that performs well on all domains. In contrast, we aim to learn models that adapt at test time to domain shift using unlabeled test points. Our primary contribution is to introduce the framework of adaptive risk minimization (ARM), in which models are directly optimized for effective adaptation to shift by learning to adapt on the training domains. Compared to prior methods for robustness, invariance, and adaptation, ARM methods provide performance gains of 1-4% test accuracy on a number of image classification problems exhibiting domain shift.

Overview

The standard assumption in empirical risk minimization (ERM) is that the data distribution at test time will match the training distribution. When this assumption does not hold, i.e., when there is distribution shift, the performance of standard ERM methods can deteriorate significantly. As an example that we will study in detail, consider a handwriting classification system that, after training on data from past users, is deployed to new end users. Each new user represents a new test distribution that differs from the training distribution, so every test setting involves dealing with shift. This example can be characterized as an instance of the domain generalization problem, in which the training data are provided in domains and the distributions at test time represent new domains. In practice, training domains are generally constructed using metadata, which exists for many commonly used datasets. Thus, this domain assumption is applicable to a wide range of realistic distribution shift problems.

Imagine building a handwriting classification system that is ultimately deployed to end users. Certain users may be more challenging due to greater input shift, leading to more misclassifications. In these cases, performing unlabeled adaptation using that user's other examples may allow the classifier to achieve greater accuracy for that user.

In this work, we focus on methods that aim to adapt at test time to domain shift. To do so, we study problems in which it is both feasible and helpful (and perhaps even necessary) to assume access to a batch or stream of inputs at test time. Leveraging this assumption does not require labels for any test data and is feasible in many practical setups. For example, in handwriting classification, we typically observe not just single handwritten characters from an end user but collections of characters, such as sentences or paragraphs.

Unlabeled adaptation has been shown empirically to be useful for distribution shift problems, such as dealing with image corruptions. Taking inspiration from these findings, we propose, and evaluate on, a number of problems for which adaptation is beneficial in dealing with domain shift.

We introduce the framework of adaptive risk minimization (ARM), which proposes the following objective: optimize the model such that it can maximally leverage the unlabeled adaptation phase to handle domain shift. To do so, we instantiate a set of methods that, given a set of training domains, meta-learn a model that is adaptable to these domains. These methods are straightforward extensions of existing meta-learning approaches, thereby demonstrating that tools from the meta-learning toolkit can be readily adapted to tackle domain shift.

From the training dataset, we construct training distributions that simulate shift. For example, a training distribution may place uniform mass on only a single user's examples. We use these distributions to learn a model that is adaptable to domain shift via a form of meta-learning, in which an adaptation model has the opportunity to adapt the model parameters using the examples drawn from each distribution. Assuming the appropriate differentiability, this allows us to meta-train the model for post-adaptation performance. Crucially, this adaptation is performed using only unlabeled data, mimicking the test time adaptation we wish to perform; labels are used only to evaluate the adapted model during meta-training.
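To make the meta-training loop concrete, here is a minimal toy sketch, not the authors' implementation. It assumes a hypothetical one-dimensional problem in which each domain shifts its inputs by an unknown offset and the label is the un-shifted value. The adaptation model `adapt` sees only the unlabeled inputs of a single-domain batch, and the prediction model `predict` uses the resulting context. Labels enter only through the post-adaptation loss. All names (`make_domains`, `adapt`, `predict`, `theta["a"]`, `theta["w"]`) are hypothetical, and central finite differences stand in for the automatic differentiation a real implementation would use.

```python
import random


def make_domains(num_domains, points_per_domain, rng):
    """Hypothetical toy domains: input x = z + offset, label y = z."""
    domains = []
    for _ in range(num_domains):
        offset = rng.uniform(-5.0, 5.0)  # domain-specific shift
        zs = [rng.gauss(0.0, 1.0) for _ in range(points_per_domain)]
        domains.append([(z + offset, z) for z in zs])
    return domains


def adapt(theta, xs):
    """Adaptation model h: sees only unlabeled inputs and returns a context."""
    return theta["a"] * sum(xs) / len(xs)


def predict(theta, x, context):
    """Prediction model f: uses the input together with the adapted context."""
    return theta["w"] * (x - context)


def post_adaptation_loss(theta, batch):
    xs = [x for x, _ in batch]  # labels are NOT used to adapt
    context = adapt(theta, xs)
    # Labels are used only to score the adapted predictions (the meta-training signal).
    return sum((predict(theta, x, context) - y) ** 2 for x, y in batch) / len(batch)


def meta_train(domains, steps, batch_size, lr, rng, eps=1e-4):
    theta = {"w": 0.5, "a": 0.0}
    for _ in range(steps):
        # Simulated shift: every batch is drawn from a single training domain.
        batch = rng.sample(rng.choice(domains), batch_size)
        grads = {}
        for k in theta:
            # Central finite differences stand in for autodiff in this sketch.
            up, down = dict(theta), dict(theta)
            up[k] += eps
            down[k] -= eps
            grads[k] = (post_adaptation_loss(up, batch)
                        - post_adaptation_loss(down, batch)) / (2 * eps)
        for k in theta:
            theta[k] -= lr * grads[k]
    return theta


rng = random.Random(0)
domains = make_domains(num_domains=20, points_per_domain=200, rng=rng)
theta = meta_train(domains, steps=3000, batch_size=32, lr=0.005, rng=rng)
print(theta)  # both parameters should end up close to 1.0
```

In this toy, meta-training drives `theta["a"]` toward 1, so the learned adaptation step subtracts the batch mean and removes each domain's offset. This is exactly the kind of behavior that training for post-adaptation performance is meant to discover.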

We propose three methods for the ARM problem setting, inspired by prior methods in contextual meta-learning, test time adaptation, and gradient-based meta-learning. First, the ARM-CML method meta-learns a context network that processes each unlabeled example separately to produce a context; these contexts are then averaged together and used as an additional input to the model. Second, in the ARM-BN method, the model is trained to adapt by computing batch normalization statistics on the batch of inputs, replacing the standard test time procedure of using the running statistics computed over the course of training. Finally, in the ARM-LL method, the goal is to learn model parameters that are amenable to gradient updates on an unsupervised loss function, so that the model can quickly adapt to the test batch. For full details on each ARM method, please refer to the paper.
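The test time adaptation step of each method can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not the released code. `context_net` and `prediction_net` are hypothetical stand-ins for the learned networks, `batch_norm` omits the learnable scale and shift, and `ll_adapt` takes one finite-difference gradient step on a scalar parameter using a hypothetical fixed `unsup_loss` in place of the loss ARM-LL uses.

```python
import math


def cml_context(context_net, xs):
    """ARM-CML: embed each unlabeled input separately, then average the contexts."""
    contexts = [context_net(x) for x in xs]
    return [sum(col) / len(contexts) for col in zip(*contexts)]


def cml_predict(prediction_net, context_net, xs):
    context = cml_context(context_net, xs)
    # The averaged context is appended to every input as an additional input.
    return [prediction_net(list(x) + context) for x in xs]


def batch_norm(batch, running_mean, running_var, use_batch_stats, eps=1e-5):
    """Per-feature normalization (learnable scale/shift omitted).

    use_batch_stats=True normalizes with the test batch's own statistics (ARM-BN);
    False uses running training statistics (the standard test time procedure).
    """
    n = len(batch)
    if use_batch_stats:
        mean = [sum(col) / n for col in zip(*batch)]
        var = [sum((v - m) ** 2 for v in col) / n
               for col, m in zip(zip(*batch), mean)]
    else:
        mean, var = running_mean, running_var
    return [[(v - m) / math.sqrt(s + eps) for v, m, s in zip(row, mean, var)]
            for row in batch]


def ll_adapt(theta, xs, unsup_loss, alpha=0.1, eps=1e-5):
    """ARM-LL-style step: one gradient step on an unsupervised loss of the
    unlabeled batch (scalar parameter, finite-difference gradient)."""
    grad = (unsup_loss(theta + eps, xs) - unsup_loss(theta - eps, xs)) / (2 * eps)
    return theta - alpha * grad


def centering_head(v):
    # Hypothetical prediction network: input feature minus context feature.
    return v[0] - v[1]


def identity_context(x):
    # Hypothetical context network: passes the input through unchanged.
    return [x[0]]


def squared_distance_loss(t, xs):
    # Hypothetical unsupervised loss: mean squared distance to the inputs.
    return sum((t - x) ** 2 for x in xs) / len(xs)


print(cml_predict(centering_head, identity_context, [[1.0], [2.0], [3.0]]))  # [-1.0, 0.0, 1.0]
print(batch_norm([[10.0], [12.0], [14.0]], [0.0], [1.0], use_batch_stats=True))
print(ll_adapt(0.0, [1.0, 2.0, 3.0], squared_distance_loss))  # approximately 0.4
```

In the shifted batch example, normalizing with the batch's own statistics re-centers the inputs around zero. Normalizing with the running statistics (mean 0, variance 1) would leave them near 10, 12, and 14.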

Our experiments demonstrate that the proposed ARM methods, by leveraging the meta-training phase, are able to consistently outperform prior methods for handling shift by 1-4% test accuracy in image classification settings including benchmarks for federated learning and image classifier robustness.

Experiments

We propose four image classification problems, which we believe can supplement existing benchmarks for domain shift. A key characteristic of these problems, which distinguishes them from prior benchmarks, is the potential for adaptation to improve test performance. We also present results on datasets from the WILDS benchmark.

First, we study a modified version of MNIST in which images are rotated in 10 degree increments, from 0 to 130 degrees. We use only 108 training points for each of the 2 smallest domains (120 and 130 degrees) and 324 points each for rotations 90 to 110 degrees, whereas the overall training set contains 32292 points. In this setting, we hypothesize that adaptation can specialize the model to specific domains, in particular the rare domains in the training set. At test time, we generate images from the MNIST test set with a given rotation, and we report each method's worst case and average accuracy across domains.
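The domain structure and the evaluation metrics for rotated MNIST can be written down directly. The sketch below counts the rotation domains and checks the arithmetic of the stated training set sizes. It leaves the per-domain split of the remaining points unspecified, since it is not given here. It then computes worst case and average accuracy, assuming (our assumption) that the average is the unweighted mean of per-domain accuracies. The helper name `worst_case_and_average` is hypothetical.

```python
# Rotated MNIST domains: 0 to 130 degrees in 10 degree increments.
ROTATIONS = list(range(0, 140, 10))  # 14 domains

TOTAL_TRAIN = 32292
KNOWN_SIZES = {120: 108, 130: 108, 90: 324, 100: 324, 110: 324}
# Points left for rotations 0 to 80; their per-domain split is not stated here.
REMAINING = TOTAL_TRAIN - sum(KNOWN_SIZES.values())


def worst_case_and_average(correct_by_domain):
    """Map each domain to a list of 0/1 outcomes; return (worst case, average).

    The average is the unweighted mean of per-domain accuracies (an assumption).
    """
    accs = [sum(c) / len(c) for c in correct_by_domain.values()]
    return min(accs), sum(accs) / len(accs)


print(len(ROTATIONS), REMAINING)  # 14 31104
print(worst_case_and_average({0: [1, 1, 1, 1], 130: [1, 0, 0, 1]}))  # (0.5, 0.75)
```

Reporting the minimum over domains alongside the mean exposes failures on rare domains, such as the 120 and 130 degree rotations, that a single pooled accuracy would hide.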

Second, we use the FEMNIST dataset, a version of the extended MNIST (EMNIST) dataset that associates each handwritten character with the user who wrote it. EMNIST consists of images of handwritten uppercase and lowercase letters, in addition to digits. We construct a training set of 62732 examples from 262 users, where the smallest user has 104 examples. The test set consists of 8439 examples from 35 users not seen at training time, and the smallest user has 140 examples. We measure each method's worst case and average test accuracy across users. As discussed below, adaptation may help on this problem by specializing the model to each user and by resolving ambiguous data points.
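Domains in FEMNIST come directly from per-example metadata (the writer). The following is a minimal sketch of turning such metadata into domains and computing the kind of summary statistics quoted above. It assumes hypothetical records stored as dictionaries with `"user"` and `"label"` fields; the real dataset format may differ.

```python
from collections import defaultdict


def group_by_domain(examples, key="user"):
    """Build domains from metadata: one domain per distinct value of `key`."""
    domains = defaultdict(list)
    for ex in examples:
        domains[ex[key]].append(ex)
    return dict(domains)


def summarize(domains):
    """Number of domains, total number of examples, and size of the smallest domain."""
    sizes = [len(v) for v in domains.values()]
    return {"num_domains": len(sizes), "num_examples": sum(sizes), "smallest": min(sizes)}


def users_disjoint(train_domains, test_domains):
    """Test users must not appear at training time."""
    return not (set(train_domains) & set(test_domains))


train = group_by_domain([
    {"user": "u1", "label": "a"},
    {"user": "u1", "label": "2"},
    {"user": "u2", "label": "l"},
])
test = group_by_domain([{"user": "u3", "label": "I"}])
print(summarize(train))  # {'num_domains': 2, 'num_examples': 3, 'smallest': 1}
print(users_disjoint(train, test))  # True
```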

Finally, we evaluate the proposed methods and all comparisons on modified versions of CIFAR-10-C and Tiny ImageNet-C, which augment the CIFAR-10 and Tiny ImageNet datasets, respectively, with common image corruptions that vary in type and severity. Prior work has shown that carefully designed test time adaptation procedures are effective for dealing with corruptions. One possible reason for this phenomenon is that convolutional networks typically rely on texture, which is distorted by corruptions; adaptation can thus help the model specialize to each corruption type. We modify the training protocol to fit the ARM problem setting by using a set of 56 corruptions for the training data, and we define each corruption to be a domain. We use a disjoint set of 22 corruptions for the test data, and we measure worst case and average accuracy across the test corruptions.

In this table, we summarize worst case (WC) and average (Avg) top-1 accuracy on rotated MNIST, FEMNIST, CIFAR-10-C, and Tiny ImageNet-C, where means and standard errors are reported across three separate runs of each method. Horizontal lines separate methods that make use of (from top to bottom): neither, training domains, test batches, or both. ARM methods consistently achieve greater robustness, as measured by WC accuracy, as well as higher Avg accuracy than prior methods.

Across all of the proposed problems, ARM methods increase both worst case and average accuracy compared to all other methods. ARM-CML performs well across all tasks, and despite its simplicity, ARM-BN achieves the best performance overall on the corrupted image testbeds, demonstrating the effectiveness of meta-training on top of an already strong adaptation procedure. Compared to other prior methods, such as those for test time adaptation, ARM methods rely less on favorable inductive biases and consistently attain better results.

In the streaming setting, ARM methods reach strong performance on rotated MNIST (left) and Tiny ImageNet-C (right), after fewer than 10 and 50 data points, respectively, despite meta-training with batch sizes of 50 and 100. This highlights that the trained models are able to adapt with small test batches and can operate successfully in the standard streaming evaluation setting.

We investigate the effectiveness of ARM methods in the streaming test time setting, where test points are observed one at a time rather than in batches. In the paper, we also study whether the training domain assumption can be loosened, by instead using unsupervised learning techniques to discover domain structure in the training data.

When we cannot access a batch of test points all at once and the points are instead observed in a streaming fashion, we can augment the proposed ARM methods to perform sequential model updates. For example, ARM-CML and ARM-BN can update their average context and normalization statistics, respectively, after observing each new test point. In the plots above, we study this test setting for the rotated MNIST and Tiny ImageNet-C problems. Models trained with ARM-CML and ARM-BN both come close to their original worst case and average accuracy after observing only 10 and 50 data points for rotated MNIST and Tiny ImageNet-C, respectively, well before reaching the training batch sizes of 50 and 100. This result demonstrates that ARM methods are applicable to problems where test points must be observed one at a time, provided that the model is permitted to adapt using each point.
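The sequential updates described above amount to maintaining running statistics that are refreshed after each test point. The sketch below is illustrative, not the paper's code. It keeps an incremental average of per-point contexts for an ARM-CML-style model and uses Welford's algorithm (our choice of technique) to track the mean and population variance of a single scalar feature for an ARM-BN-style model; a real network would track these per channel. The class names are hypothetical.

```python
class StreamingContext:
    """ARM-CML style: running average of the per-point contexts seen so far."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim

    def update(self, context):
        self.n += 1
        # Incremental mean: m_n = m_{n-1} + (c_n - m_{n-1}) / n
        self.mean = [m + (c - m) / self.n for m, c in zip(self.mean, context)]
        return self.mean


class StreamingNormStats:
    """ARM-BN style: Welford's algorithm for the running mean and variance."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the current mean

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    @property
    def var(self):
        # Population variance, matching how batch statistics are usually computed.
        return self.m2 / self.n if self.n else 0.0


ctx = StreamingContext(dim=1)
for c in ([1.0], [3.0]):
    ctx.update(c)
stats = StreamingNormStats()
for x in (1.0, 2.0, 3.0, 4.0):
    stats.update(x)
print(ctx.mean, stats.mean, stats.var)  # [2.0] 2.5 1.25
```

Because both updates are exact incremental forms of the batch quantities, the streamed statistics after n points equal those of a batch containing the same n points. This is why a model meta-trained on batches can be evaluated one point at a time.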

We present a qualitative example of how ARM methods can improve test accuracy by adapting to specific users. We visualize a batch of 50 examples from a random FEMNIST test user and highlight an ambiguous example. When given a test batch of only 2 images, an ERM trained model and an ARM-CML trained model both incorrectly classify this example as "2". However, when given access to the entire batch of 50 images, which contains examples of classes "2" and "a" from this user, the ARM-CML trained model successfully adapts its prediction to "a", the correct label. In general, we find that most instances of adaptation in FEMNIST occur for similarly ambiguous examples, e.g., "l" versus "I", though not all examples are interpretable.

Results on the WILDS image testbeds. Different methods are best suited for different problems, motivating the need for a wide range of methods. ARM-BN struggles on FMoW but performs well on the other datasets, in particular RxRx1.

Finally, we present results on the WILDS benchmark. We evaluate BN adaptation and ARM-BN on these testbeds. We see that, on these real world distribution shift problems, different methods perform well for different problems. CORAL, a method for invariance, performs best on the iWildCam animal classification problem, whereas no methods outperform ERM by a significant margin on the FMoW or PovertyMap satellite imagery problems. ARM-BN performs particularly poorly on the FMoW problem. However, it performs well on PovertyMap and significantly improves performance on the RxRx1 problem of treatment classification from medical images. On the other medical imagery problem of Camelyon17 tumor identification, adaptation in general boosts performance dramatically. These results indicate the need to consider a wide range of tools, including meta-learning and adaptation, for combating distribution shift.