TTN: A Domain-Shift Aware Batch Normalization
in Test-Time Adaptation

Hyesu Lim
Qualcomm AI Research, KAIST

Byeonggeun Kim
Qualcomm AI Research

Sungha Choi
Qualcomm AI Research

Introduction

In this post, we introduce a new test-time adaptation method called test-time normalization (TTN), a backpropagation-free adaptation via domain-shift aware batch normalization. Test-time adaptation (TTA) aims to overcome the performance degradation caused by the domain discrepancy between the source (training) and target (test) inputs.

Figure 1. Domain discrepancy between source and target domains. Figure adapted from [Hendrycks & Dietterich, 2019].

Test-Time Normalization Layer

As a solution to this performance degradation, we propose the test-time normalization (TTN) layer, which interpolates between conventional batch normalization (CBN, which uses the running source statistics at test time) and transductive batch normalization (TBN, which uses the statistics of the current test batch). TTN combines the source and current test-batch statistics with learnable interpolating weights and uses the combined statistics to standardize the features.

Figure 2. Test-time normalization (TTN) layer 
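To make the mixing concrete, below is a minimal PyTorch-style sketch of such a layer. It is an illustration, not the official implementation: the class name, the clamping of alpha to [0, 1], and the plain linear mixing of means and variances are our assumptions, and the paper's exact variance-combination rule may differ in detail.

```python
import torch
import torch.nn as nn


class TTNBatchNorm2d(nn.Module):
    """Illustrative TTN layer: mixes frozen source BN statistics with the
    statistics of the current test batch using per-channel weights alpha."""

    def __init__(self, bn: nn.BatchNorm2d, init_alpha=None):
        super().__init__()
        self.eps = bn.eps
        # Frozen source statistics and affine parameters from the pre-trained BN.
        self.register_buffer("src_mean", bn.running_mean.detach().clone())
        self.register_buffer("src_var", bn.running_var.detach().clone())
        self.register_buffer("weight", bn.weight.detach().clone())
        self.register_buffer("bias", bn.bias.detach().clone())
        # Interpolating weights (the only trainable TTN parameters); the paper
        # initializes them from the prior, here we default to 1 (pure TBN).
        alpha0 = init_alpha if init_alpha is not None else torch.ones_like(self.src_mean)
        self.alpha = nn.Parameter(alpha0.detach().clone())

    def forward(self, x):
        a = self.alpha.clamp(0.0, 1.0).view(1, -1, 1, 1)
        # Current test-batch statistics (TBN side).
        tbn_mean = x.mean(dim=(0, 2, 3), keepdim=True)
        tbn_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        src_mean = self.src_mean.view(1, -1, 1, 1)
        src_var = self.src_var.view(1, -1, 1, 1)
        # Interpolate between test-batch and source statistics.
        mean = a * tbn_mean + (1.0 - a) * src_mean
        var = a * tbn_var + (1.0 - a) * src_var
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return x_hat * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
```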

Post-training

The key idea of TTN is to combine source and test batch statistics according to the model’s domain-shift sensitivity. In other words, we put more weight on the test batch statistics in the layers and channels where more domain information is needed for accurate inference. Specifically, we optimize the interpolating weights (the TTN parameters) using a frozen pre-trained model and its labeled source (training) data. At test time, we freeze the interpolating weights along with the other model parameters and make predictions on the target (test) data.

Figure 3. TTN parameters are optimized in the post-training phase
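To make this setup concrete, the sketch below (reusing the TTNBatchNorm2d class from the earlier sketch) replaces every BN layer of a pre-trained model with a TTN layer and keeps only the interpolating weights trainable. The helper name convert_to_ttn and the `model` variable in the usage lines are ours, purely for illustration.

```python
import torch.nn as nn


def convert_to_ttn(module: nn.Module):
    """Recursively swap BatchNorm2d layers for TTNBatchNorm2d (see earlier sketch)
    and collect the interpolating weights alpha."""
    alpha_params = []
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            ttn = TTNBatchNorm2d(child)
            setattr(module, name, ttn)
            alpha_params.append(ttn.alpha)
        else:
            alpha_params += convert_to_ttn(child)
    return alpha_params


# Usage: `model` is the pre-trained network; only alpha stays trainable.
alpha_params = convert_to_ttn(model)
for p in model.parameters():
    p.requires_grad_(False)
for a in alpha_params:
    a.requires_grad_(True)
```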

Obtain prior. We first measure the domain-shift sensitivity of a given pre-trained model in an off-the-shelf manner. To this end, we introduce a gradient distance score, which measures how differently the model reacts to clean and domain-shifted inputs (the shift is simulated via augmentation) by comparing the gradients of the model parameters. Since our focus lies on the normalization layers, we only compare the gradients of the affine parameters of the batch normalization layers. Because the difference between the two gradients stems from the domain discrepancy, we argue that if the difference is large at (layer l, channel c), then the parameter at (l, c) is strongly affected by the domain shift, i.e., it handles domain-related knowledge. We refer to this per-layer, per-channel domain-shift sensitivity as the prior.
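A rough sketch of how such a per-channel score could be computed is shown below. The function name and the exact distance (cosine distance over the stacked (d_gamma, d_beta) gradients, rescaled to [0, 1]) are illustrative assumptions; the paper defines its own gradient distance score.

```python
import torch
import torch.nn.functional as F


def gradient_distance_prior(model, x_clean, x_aug, y, loss_fn):
    """Illustrative per-channel domain-shift sensitivity score (the "prior").

    Compares the gradients of the BN affine parameters (gamma, beta) obtained
    from clean inputs and from augmented (domain-shift simulated) inputs;
    run this on the pre-trained model before any freezing or conversion.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]

    def affine_grads(x):
        model.zero_grad()
        loss_fn(model(x), y).backward()
        # Per layer: stack (d_gamma, d_beta) into a (C, 2) tensor.
        return [torch.stack([m.weight.grad.detach().clone(),
                             m.bias.grad.detach().clone()], dim=1)
                for m in bn_layers]

    g_clean, g_aug = affine_grads(x_clean), affine_grads(x_aug)
    # Larger distance => channel is more sensitive to the simulated domain shift.
    return [(1.0 - F.cosine_similarity(gc, ga, dim=1)) / 2.0
            for gc, ga in zip(g_clean, g_aug)]
```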

Optimize the interpolating weights. We expect the TTN layers to combine the source and test batch statistics in the right proportion, so that the features are correctly standardized and yield accurate predictions regardless of the input domain. To do so, we optimize the interpolating weights, alpha, using the cross-entropy loss. To inject our intuition, we initialize alpha with the obtained prior and add a mean-squared-error loss as a regularization term, which prevents alpha from drifting too far from the prior A.

Figure 4. An overview of post-training phase; (a) obtaining the prior and (b) optimizing the interpolating weights, alpha
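Putting the pieces together, a minimal sketch of the post-training loop might look like the following, assuming alpha has been initialized from the prior (e.g., via the init_alpha argument of the earlier sketch). The regularization weight lam, step count, and optimizer choice are illustrative, not the paper's settings.

```python
from itertools import cycle

import torch
import torch.nn as nn


def post_train_alpha(model, alpha_params, prior, source_loader,
                     steps=1000, lam=1.0, lr=1e-3):
    """Optimize only the interpolating weights alpha on labeled source data:
    cross-entropy loss plus an MSE term that keeps alpha close to the prior A."""
    opt = torch.optim.SGD(alpha_params, lr=lr)
    ce = nn.CrossEntropyLoss()
    for _, (x, y) in zip(range(steps), cycle(source_loader)):
        loss = ce(model(x), y)
        # MSE regularization toward the per-channel prior.
        reg = sum(((a - p) ** 2).mean() for a, p in zip(alpha_params, prior))
        (loss + lam * reg).backward()
        opt.step()
        opt.zero_grad()
```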

Experimental Results

Experiments. We demonstrated that TTN shows robust and stable performance on unseen target domains in image classification (CIFAR-10-C, CIFAR-100-C, and ImageNet-C) and semantic segmentation (Cityscapes to BDD-100K, Mapillary, GTA5, and SYNTHIA) tasks. We conducted experiments under realistic but challenging adaptation scenarios.

In all scenarios, we evaluated models across a wide range of test batch sizes: 200, 64, 16, 4, 2, and 1.

Results. TTN flexibly adapts to arbitrary target domains, both seen and unseen. Moreover, because TTN does not alter the training or test-time schemes (it is a backpropagation-free adaptation), it is broadly applicable to other TTA methods and further improves their performance. Figure 5 shows selected results; you can find more results and details in our paper.

Figure 5. Selected experiment results. Error rates (↓) of image classification experiments on CIFAR-10-C using WideResNet-40-2.

Analysis

Visualization of optimized interpolating weights. As shown in Figure 6, the TTN parameters are optimized to use more test statistics in shallow layers and to rely more on source statistics in deep layers. One plausible explanation is that, since vision models handle style (or domain) information in shallow layers and content (or semantic) information in deep layers, they require more domain-related knowledge in shallow layers than in deep layers.

Figure 6. Visualization of optimized interpolating weights

Robustness against augmentation types. We analyzed how TTN behaves when the augmentation type used in the post-training phase and the test-time corruption type are misaligned or perfectly aligned. We used one of the 15 corruption types in the corruption benchmark (Hendrycks & Dietterich, 2019) as data augmentation while testing on all 15 corruption types. The following figures show that both obtaining the prior and optimizing the interpolating weights are insensitive to the choice of augmentation. You can find more details about this experiment in our paper.