TTN: A Domain-Shift Aware Batch Normalization
in Test-Time Adaptation
Introduction
In this post, we introduce a new test-time adaptation method called test-time normalization (TTN), which performs backpropagation-free adaptation via a domain-shift aware batch normalization. Test-time adaptation (TTA) aims to overcome the performance degradation caused by the domain discrepancy between the source (train) and target (test) inputs.
Figure 1. Domain discrepancy between source and target domains. Figure adapted from [Hendrycks & Dietterich, 2019].
Test-Time Normalization Layer
As a solution to this performance degradation, we propose a new approach called the test-time normalization (TTN) layer, which interpolates between conventional batch normalization (CBN, which uses the source statistics accumulated during training) and transductive batch normalization (TBN, which uses the statistics of the current test batch). TTN combines the source and current test batch statistics using learnable interpolating weights and uses the combined statistics to standardize features.
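For a single channel, the interpolation can be sketched as follows. This is a minimal stdlib-Python sketch, not the paper's implementation: the channel-wise weight `alpha` and the cross term that accounts for the gap between the two means follow one common way to mix Gaussian statistics; see the paper for the exact formulation.

```python
import math

def ttn_statistics(alpha, mu_src, var_src, mu_test, var_test):
    """Mix source and test-batch statistics for one channel.

    alpha near 1 -> trust the current test batch (TBN-like);
    alpha near 0 -> trust the stored source statistics (CBN-like).
    The cross term accounts for the gap between the two means
    (an illustrative formulation, not necessarily the paper's).
    """
    mu = alpha * mu_test + (1 - alpha) * mu_src
    var = (alpha * var_test + (1 - alpha) * var_src
           + alpha * (1 - alpha) * (mu_test - mu_src) ** 2)
    return mu, var

def ttn_normalize(x, alpha, mu_src, var_src, gamma=1.0, beta=0.0, eps=1e-5):
    """Standardize one channel of a test batch x with TTN statistics."""
    n = len(x)
    mu_test = sum(x) / n
    var_test = sum((v - mu_test) ** 2 for v in x) / n
    mu, var = ttn_statistics(alpha, mu_src, var_src, mu_test, var_test)
    return [gamma * (v - mu) / math.sqrt(var + eps) + beta for v in x]
```

With `alpha = 0` this reduces to CBN, and with `alpha = 1` to TBN, so TTN spans the whole spectrum between the two.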
Figure 2. Test-time normalization (TTN) layer
Post-training
The key concept of TTN is to combine source and test batch statistics according to the model's domain-shift sensitivity. In other words, we put more weight on the test batch statistics in the layers and channels where more domain information is needed for accurate inference. Specifically, we optimize the interpolating weights (TTN parameters) using a frozen pre-trained model and its labeled source (training) data. At test time, we freeze the interpolating weights along with the other model parameters and make predictions on the target (test) data.
Figure 3. TTN parameters are optimized at post-training phase
Obtain prior. We first measure the domain-shift sensitivity of a given pre-trained model in an off-the-shelf manner. To this end, we introduce a gradient distance score, which measures the difference between the model's responses to clean and domain-shifted (simulated via augmentation) inputs by comparing the gradients of the model parameters. Since our focus lies on the normalization layers, we only compare the gradients of the affine parameters of the batch normalization layers. Because the difference between the two gradients stems from the domain discrepancy, we argue that if the difference is large at (layer l, channel c), then the parameter at (l, c) is strongly affected by the domain shift, i.e., it handles domain-related knowledge. We refer to this per-layer, per-channel domain-shift sensitivity as the prior.
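As an illustration, the per-channel score can be sketched as one minus the cosine similarity between the two gradient vectors of a BN affine parameter, followed by a rescaling to [0, 1] to form the prior. Both the cosine-distance choice and the min-max rescaling here are illustrative assumptions; the paper defines its own gradient distance score.

```python
import math

def gradient_distance(grad_clean, grad_aug, eps=1e-12):
    """Distance between the gradients of a BN affine parameter computed
    on clean vs. augmented (domain-shifted) inputs.

    A large distance means the parameter reacts strongly to the domain
    shift, i.e., it carries domain-related knowledge.
    (Cosine distance is used here as an illustrative choice.)
    """
    dot = sum(a * b for a, b in zip(grad_clean, grad_aug))
    na = math.sqrt(sum(a * a for a in grad_clean))
    nb = math.sqrt(sum(b * b for b in grad_aug))
    return 1.0 - dot / (na * nb + eps)

def prior_from_scores(scores):
    """Min-max rescale per-channel scores to [0, 1] to serve as the prior A."""
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [0.5 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]
```

Identical gradients give a distance near 0 (the channel is insensitive to the shift), while strongly disagreeing gradients give a distance near 2.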
Optimize the interpolating weights. We expect TTN layers to combine the source and test batch statistics in the right proportions so that features are correctly standardized, yielding accurate predictions regardless of the input domain. To this end, we optimize the interpolating weights, alpha, using a cross-entropy loss. To inject our intuition, we initialize alpha with the obtained prior A and add a mean-squared-error loss as a regularization term, which prevents alpha from drifting too far from A.
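The post-training objective can then be sketched as the cross-entropy term plus the MSE anchor to the prior. The trade-off weight `lam` below is a hypothetical hyperparameter name; the paper specifies its own weighting.

```python
def ttn_loss(ce_loss, alpha, prior, lam=1.0):
    """Post-training objective for the interpolating weights.

    ce_loss : cross-entropy on labeled source data (a scalar).
    alpha   : current per-channel interpolating weights.
    prior   : the prior A obtained from the gradient distance scores.
    lam     : hypothetical trade-off weight for the regularizer.

    The MSE term keeps alpha close to the prior A while the
    cross-entropy term tunes it for accurate predictions.
    """
    mse = sum((a - p) ** 2 for a, p in zip(alpha, prior)) / len(alpha)
    return ce_loss + lam * mse
```

Only alpha is updated by this loss; the pre-trained model itself stays frozen throughout, which is what makes the subsequent test-time phase backpropagation-free.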
Figure 4. An overview of post-training phase; (a) obtaining the prior and (b) optimizing the interpolating weights, alpha
Experimental Results
Experiments. We demonstrated that TTN shows robust and stable performance against unseen target domains in image classification (CIFAR-10-C, CIFAR-100-C, and ImageNet-C) and semantic segmentation (Cityscapes to BDD-100K, Mapillary, GTA5, and SYNTHIA) tasks. We conducted experiments under realistic but challenging scenarios including:
single domain adaptation (where the model adapts to a single corruption type and is reset for every new corruption type),
continuously changing domain adaptation (where models are not reset for different corruptions which come in a sequence),
mixed domain adaptation (where different target domains are mixed together),
class imbalanced scenario (where the ground-truth labels are not uniformly distributed; non-i.i.d.),
and source domain adaptation (where no distribution shift has occurred).
In all scenarios, we evaluated models across a wide range of test batch sizes: 200, 64, 16, 4, 2, and 1.
Results. TTN flexibly adapts to arbitrary target domains, both seen and unseen. Moreover, since TTN does not alter training or test-time schemes (it is a backpropagation-free adaptation), it is broadly applicable to other TTA methods and further improves their performance. Figure 5 shows selected results; you can find more results and details in our paper.
Figure 5. Selected experiment results. Error rates (↓) of image classification experiments on CIFAR-10-C using WideResNet-40-2.
Analysis
Visualization of optimized interpolating weights. As shown in Figure 6, the TTN parameters are optimized to use more test statistics in shallow layers and to rely more on source statistics in deep layers. One plausible explanation is that since vision models handle style (or domain) information in shallow layers and content (or semantic) information in deep layers, models require domain-related knowledge more in shallow layers than in deep layers.
Figure 6. Visualization of optimized interpolating weights
Robustness against augmentation types. We analyzed how TTN behaves when the augmentation type used in the post-training phase and the test corruption type are misaligned or perfectly aligned. We used one of the 15 corruption types in the corruption benchmark (Hendrycks & Dietterich, 2019) as data augmentation while testing on all 15 corruption types. The following figures show that obtaining the prior and optimizing the interpolating weights are robust to the choice of augmentation. You can find more details about this experiment in our paper.