Learning with no target data
Best Viewed at https://sites.google.com/andrew.cmu.edu/multigan-distillation (CMU access)
Code available here
In the previous section, we saw distillation in the presence of a miner/selector network that refines the distribution captured by the source models. In this section, we take the more classical route of distillation, i.e., distribution matching. Specifically, we study the two most commonly used loss functions: the Lp norm (L2 in our case) and the perceptual loss. In contrast to the previous section, the source GANs are kept frozen throughout the distillation process.
Dataset. We experiment with the FFHQ and LSUN Cats datasets. However, we do not use the datasets themselves; instead, we distill generators that were trained on them.
Architecture. We use pre-trained source generators and train a target generator, both based on the StyleGAN-2 architecture. As seen in Fig. 1.1.1, there are two latent spaces in StyleGAN-2 that we are interested in: the input space (512-dimensional z-space) and the style space (512-dimensional w-space) corresponding to the output of the mapping network. We experiment with distillation across the image space using either of the two latent spaces as the input and determine which is more conducive to successful distillation. Also, note that there is an additional constant input (c1) fed to the generator, which drives the structure of the generated image. In the multi-source GAN setting, we incorporate source-specific learnable scale and shift parameters for the style latent codes and source-specific constant inputs to enable the network to distinguish between the different modes of data.
Figure 1.1.1. We use a StyleGAN-2 architecture.
The training strategy is summarized in Fig. 1.1.2b-c.
Single-source GAN. As shown in Fig. 1.1.2b below, we sample the source generator (G_src) to generate ground-truth "real" images and pass the same latent input (z in Fig. 1.1.2b and w in Fig. 1.1.2c) to the target generator (G_tgt). We employ a weighted combination of an L2 loss and a perceptual loss to train this network. The intuition behind exploring both of these latent spaces is described in the experiments that follow.
Multi-source GANs. In this case, due to computational constraints, we only explore the training paradigm shown in Fig. 1.1.2c (excluding the perceptual loss). Considering that the target model may need to capture different modes of data corresponding to the different source distributions, we learn source-specific scale (γ) and shift (β) parameters to transform the w vector before passing it to the target generator. This simple learnable transformation is expected to separate each source distribution into a different subspace, making it easier for the same target generator to learn both distributions. For distillation, we compute the average L2 loss over the corresponding source-target image pairs.
Figure 1.1.2. Overview of our proposed training strategies. a) A typical StyleGAN generator architecture depicting a mapping network and a generator. b) Classical distillation using a common z-space input and an L2 loss (Euclidean distance) on the images. c) Modified distillation framework employing a common w-space input, using an L2 loss and a perceptual loss (Euclidean distance between 2048-D InceptionV3 features) on the images.
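To make the perceptual loss in c) concrete, the sketch below shows one way to implement it in PyTorch using torchvision's InceptionV3; the preprocessing assumptions (inputs already resized to 299x299 and normalized with ImageNet statistics) and the weight choice are ours, not a description of the exact training code.

```python
import torch
import torch.nn as nn
from torchvision.models import inception_v3

class InceptionPerceptualLoss(nn.Module):
    """Euclidean distance between 2048-D InceptionV3 pooled features.
    A minimal sketch; inputs are assumed to already be resized to 299x299
    and normalized with ImageNet statistics."""
    def __init__(self):
        super().__init__()
        net = inception_v3(weights="IMAGENET1K_V1")
        net.fc = nn.Identity()              # keep the 2048-D pooled features
        net.eval()
        for p in net.parameters():          # the feature extractor stays frozen
            p.requires_grad_(False)
        self.net = net

    def forward(self, x, y):
        self.net.eval()                     # guard against .train() toggles upstream
        fx, fy = self.net(x), self.net(y)   # (B, 2048) feature vectors
        return (fx - fy).norm(dim=1).mean() # average Euclidean distance
```

Freezing the parameters still lets gradients flow back to the generated images, which is all the distillation loss needs.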
In this section, we first explore distillation using a single source generator as shown in Fig. 1.1.2b-c. We identify the importance of the local-smoothness property, which enables successful distillation using limited examples. We observe that the style space (w-space) is more locally smooth than the input space (z-space), and accordingly describe a distillation paradigm based on the w-space. Finally, we incorporate a perceptual loss to improve the quality of the generated outputs.
As a baseline experiment, we distill a StyleGAN-2 generator pre-trained on the FFHQ dataset. The distillation procedure is depicted in Fig. 1.1.2b. For simplicity, and to validate the distillation procedure, we keep the source and target generators architecturally identical in this experiment. Note that distillation has practical applications in scenarios such as model compression; however, our main goal is to understand the effect of distillation, and its implications for model compression are left for future work.
In particular, we randomly sample a batch of z-vectors at the input space (Fig. 1.1.2b) and obtain the corresponding images from the source and the target generators. The target generator is penalized using an L2 loss between the corresponding images (while the source generator is kept frozen).
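The corresponding training step is simple; below is a minimal PyTorch sketch, where `G_src` and `G_tgt` are placeholders for the frozen source and trainable target StyleGAN-2 generators (their exact call signatures depend on the implementation used).

```python
import torch
import torch.nn.functional as F

def zspace_distill_step(G_src, G_tgt, optimizer, batch_size=8, z_dim=512, device="cuda"):
    """One z-space distillation step (Fig. 1.1.2b): both generators receive the
    same random z, and the target is penalized with an L2 (MSE) loss on images."""
    z = torch.randn(batch_size, z_dim, device=device)   # z ~ N(0, I)

    with torch.no_grad():                  # the source generator stays frozen
        imgs_src = G_src(z)                # "ground-truth" images from the source

    imgs_tgt = G_tgt(z)
    loss = F.mse_loss(imgs_tgt, imgs_src)  # L2 distillation loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```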
Fig. 2.2.1a shows the training images generated by the target generator after the loss converges (Fig. 2.2.1c). We observe that while the source generator is able to generate realistic images (Fig. 2.2.1b), the target generator is only able to synthesize images that look like an average face. The generated images are overly smooth and have little to no diversity.
In a way, this approach considers StyleGAN-2 as a black box generator network that takes a random z latent vector and generates an image corresponding to the data distribution.
Figure 2.2.1a. Images generated by the target model trained using an L2 distillation loss corresponding to the ground-truth images on the right.
Figure 2.2.1b. Images generated by the source GAN corresponding to the generated images on the left.
Figure 2.2.1c. Training loss (L2) curve.
To validate our implementation, we overfit the network to a set of 20 randomly selected z latent vectors. Fig. 2.2.2a-b demonstrates that the network is able to generate images from a small training set.
Figure 2.2.2a. Images generated by the target generator trained with 20 randomly selected z-samples.
Figure 2.2.2b. Corresponding images generated by the source generator.
Figure 2.2.2c. Training loss (L2) curve. The curve is smoother than that in Fig. 2.2.1c since here we overfit to only 20 samples.
As a natural next step, we repeat the experiment with a limited set of 1k training z-vectors. Clearly (Fig. 2.2.2d-e), the model is able to fit 1k images as well; however, it takes longer to show signs of convergence (reaching a loss range similar to that in Fig. 2.2.2c).
Figure 2.2.2d. Training images from the target model trained with 1k selected z-samples using an L2 loss.
Figure 2.2.2e. Training images from the source GAN corresponding to the 1000 selected z-samples.
Figure 2.2.2f. Training loss (L2) curve. The curve is smoother than that in Fig. 2.2.1c since here we overfit to only 1000 samples.
However, when we generate test images (random z inputs) from the target model trained with the 1000 fixed z-vectors, we get the results shown in Fig. 2.2.2g (the corresponding source images are in Fig. 2.2.2h). Clearly, the model trained in this manner is incapable of generalizing to new values of z.
Figure 2.2.2g. Validation images generated by the target model trained with 1k z-samples using an L2 loss.
Figure 2.2.2h. Validation images from the source GAN corresponding to the target images shown on the left.
Our first intuition regarding the non-generalizability of the generator in Sec. 2.2.2 is that perhaps the number of samples (1000) used to distill the generator is too small, and the target generator therefore overfits to them. However, if this were true, distillation should have worked in the first case, i.e., when distilling with infinitely many z-values as in Sec. 2.2.1, yet it did not. Hence we debug the model further.
We wonder whether the input space itself is highly complex, causing the model to overfit. To verify this intuition, we generate samples from the source model using z-samples drawn with different variances. Fig. 2.2.3a-c shows images corresponding to z-samples generated from a normal distribution with variances of 0, 0.0001, and 0.01, respectively. As expected, with zero variance there is no diversity in the generated images; however, even for a tiny variance of 0.0001 we see diverse images, and we observe a similar diversity with the larger variance of 0.01.
This experiment suggests that the input z-space is not locally smooth: even a small deviation from the mean can cause a significant change in the generated images. Clearly, one cannot expect any generalization on such a space.
Figure 2.2.3a. Images generated by sampling z ~ N(0, 0).
Figure 2.2.3b. Images generated by sampling z ~ N(0, 0.0001).
Figure 2.2.3c. Images generated by sampling z ~ N(0, 0.01).
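The probe used for Fig. 2.2.3a-c is just a few lines; a sketch, assuming `G_src` is a callable that maps a batch of z-vectors to images.

```python
import torch

@torch.no_grad()
def sample_with_variance(G_src, variance, n=16, z_dim=512, device="cuda"):
    """Sample z ~ N(0, variance * I) and generate images from the frozen source.
    variance = 0 collapses every z onto the mean (all-zeros) vector."""
    z = (variance ** 0.5) * torch.randn(n, z_dim, device=device)
    return G_src(z)

# Variances used for Fig. 2.2.3a-c: 0, 0.0001, and 0.01.
# imgs = [sample_with_variance(G_src, v) for v in (0.0, 1e-4, 1e-2)]
```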
This leads us to question whether treating the StyleGAN-2 generator as a black box is practical. In particular, we note that the mapping network in StyleGAN-2, which maps the z-space (input space) to the w-space (style space), is highly non-linear, which could be the cause of the lack of local smoothness in the z-space. To validate this hypothesis, we plot images generated by varying inputs in the w-space.
In particular, we sample w-vectors with different truncation values [2], as shown in Fig. 2.2.3d-f. We find that changing the truncation value does not cause the drastic changes observed in the z-space above, while reasonably large truncation values still yield diverse images. Therefore, we argue that the w-space (style space) is much more locally smooth than the z-space, which should be conducive to successful distillation.
Figure 2.2.3d. Images generated by sampling w-space with truncation 0.1.
Figure 2.2.3e. Images generated by sampling w-space with truncation 0.2.
Figure 2.2.3f. Images generated by sampling w-space with truncation 0.5.
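For reference, a sketch of the truncation procedure [2] used to probe the w-space; `mapping` and `w_avg` denote the frozen source mapping network and its tracked mean style vector, both assumed to be available from the pre-trained model.

```python
import torch

@torch.no_grad()
def sample_truncated_w(mapping, w_avg, truncation, n=16, z_dim=512, device="cuda"):
    """Truncation trick [2]: w' = w_avg + psi * (M(z) - w_avg).
    Smaller truncation (psi) pulls styles toward the mean w, trading
    diversity for fidelity."""
    z = torch.randn(n, z_dim, device=device)
    w = mapping(z)
    return w_avg + truncation * (w - w_avg)
```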
Having observed a higher degree of smoothness in the w-space, we argue that performing distillation using w-space inputs should enable better generalization: since the w-space is smooth, the network can learn from a limited set of data points and interpolate reliably over the remainder of the space.
We model this as follows. Given a set of z-vectors in the input space, we first map them to the corresponding w-vectors using the mapping network of the source generator. This set of w-vectors is passed to both the source generator and the target generator, as shown in Fig. 1.1.2c. We compute the L2 and perceptual losses between the corresponding images from the source and target generators to train the target generator. Throughout this process, the source model (including the mapping network) is kept frozen. In the single-source setting, we do not employ the learnable scale (γ) and shift (β) parameters. We train the model with 100k z-vectors sampled randomly at the beginning and kept fixed throughout training.
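A sketch of this w-space distillation step is shown below; `perceptual_loss` is the frozen feature-space loss (e.g., the InceptionV3 sketch above), and the mapping/generator call signatures are again placeholders for the actual implementation.

```python
import torch
import torch.nn.functional as F

def wspace_distill_step(M_src, G_src, G_tgt, perceptual_loss, optimizer,
                        z_batch, w_perc=0.2):
    """One w-space distillation step (Fig. 1.1.2c): the frozen source mapping
    network supplies w, and the target is trained with L2 + perceptual losses."""
    with torch.no_grad():
        w = M_src(z_batch)         # frozen mapping network: z -> w
        imgs_src = G_src(w)        # frozen source generator

    imgs_tgt = G_tgt(w)
    loss = F.mse_loss(imgs_tgt, imgs_src) \
        + w_perc * perceptual_loss(imgs_tgt, imgs_src)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Setting w_perc = 0 recovers the L2-only phase; the perceptual term can then be enabled for fine-tuning, as described below.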
The results are shown in Fig. 2.2.4a-d. Specifically, Fig. 2.2.4a shows the training images generated by the target generator (Fig. 2.2.4b shows the source images used for supervision). We observe high-quality results, indicating that the model fits the training data well. More interestingly, we find that the target model also generalizes to test data (randomly sampled z). Specifically, we randomly sample z-vectors (instead of using the 100k training z-vectors) and obtain the corresponding w-vectors from the source mapping network. These vectors are then passed through the target generator to synthesize novel images. Fig. 2.2.4c shows the images generated in this manner, which follow a distribution similar to that of the original source images (Fig. 2.2.4d). Note that the generated images faithfully represent the source data, although they are a little blurry. This validates our intuition that the distilled model generalizes better thanks to the smooth w-space, which is more conducive to distillation.
Figure 2.2.4a. Predicted images from the target generator when we input the w-vectors used in training.
Figure 2.2.4b. Corresponding images from the source generator.
Figure 2.2.4c. Predicted images from the target generator when we input w obtained from randomly sampled z.
Figure 2.2.4d. Corresponding images from the source generator.
In our experiment in Sec. 2.2.4, we use a weighted combination of the L2 loss and the perceptual loss, with weights of 1 and 0.2, respectively. The model in Sec. 2.2.4 is first trained using the L2 loss alone until convergence and then fine-tuned with the perceptual loss.
As shown in Fig. 2.2.5a-b, the model fine-tuned with the perceptual loss (Fig. 2.2.5b) produces sharper results than the model trained with only the L2 loss (Fig. 2.2.5a). For instance, it captures finer details around the hair, eyes, and teeth. This suggests that an unweighted L2 loss alone leads to overly smooth results, as the model weighs all regions of the image equally, whereas the perceptual loss guides the model by penalizing errors in salient regions more heavily.
Figure 2.2.5a. Images generated by the target generator trained with only L2 loss.
Figure 2.2.5b. Images generated by the target generator trained with L2 + Perceptual loss.
Figure 2.2.5c. Images generated by the source generator.
To summarize, we find that the generator distilled using w-space inputs synthesizes high-quality images. Notably, we do not use any adversarial training; simple distillation suffices. We find that the key to successfully distilling a GAN is to select the right latent space, one that exhibits local smoothness. Note that this setup is highly useful for many knowledge-transfer applications of GANs, because it does not require adversarial training and thereby avoids the associated training difficulties.
We compared the outputs of our distilled target generator with those of the source generator (for the same input latents w) and observed reasonable consistency between the generated images. Thus, we expect the target generator to exhibit behavior similar to the source generator (e.g., w-space interpolations, style mixing, etc.), albeit with some loss in quality. Improving the quality of these generations by adding additional supervision, such as adversarial losses, would be an interesting direction to explore in the future.
We now extend the distillation framework to the multi-source setting. We use two pre-trained StyleGAN-2 generators (FFHQ and LSUN Cats). The setup follows Fig. 1.1.2c, with multiple source generators (G_src) and their corresponding mapping networks (M_src), and we train a randomly initialized target generator G_tgt. We refer to the two source distributions as src1 and src2. Note that compared to the miner approach discussed in the "Learning with Limited Target Data" section, here we fuse the two distributions into a single GAN, requiring less storage and computation at inference.
We sample a set of 100k z-vectors, which are passed to both source mapping networks (M_src1, M_src2). Thus we obtain the corresponding w-vectors w_src1 = M_src1(z) and w_src2 = M_src2(z), which are fed to the source generators to extract the ground-truth images G_src1(w_src1) and G_src2(w_src2). Further, these w-vectors are transformed through learnable scale (γ) and shift (β) parameters (512-D tensors; one scale and one shift value per dimension of the w-vector) that are learned separately for each source domain. Thus, we obtain two sets of style vectors, w_1 = γ_1 * w_src1 + β_1 and w_2 = γ_2 * w_src2 + β_2, each capturing the style subspace of the corresponding data distribution.
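A minimal sketch of this per-source transformation as a PyTorch module; the scale and shift vectors are the only new learnable parameters apart from the target generator itself.

```python
import torch
import torch.nn as nn

class SourceStyleAffine(nn.Module):
    """Learnable per-source scale/shift of the 512-D style vector:
    w_k = gamma_k * w_src_k + beta_k."""
    def __init__(self, num_sources=2, w_dim=512):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_sources, w_dim))  # identity init
        self.beta = nn.Parameter(torch.zeros(num_sources, w_dim))

    def forward(self, w, source_idx):
        # w: (batch, w_dim); broadcasting applies the per-source parameters
        return self.gamma[source_idx] * w + self.beta[source_idx]
```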
While the style codes generated in this manner help the target network distinguish between the styles of the two source distributions, we note that the constant input (Fig. 1.1.1) is responsible for giving structure to the generated image. Thus, to learn the different modes of source data, we must use different constant inputs in the target generator while generating the corresponding images. We handle this by reusing the constant inputs learned by the corresponding source models, c_src1 and c_src2. We do not impose a learnable transformation on these constant inputs.
To summarize, to generate an image from the target generator corresponding to source distribution src1, we take the style vector w_1 and the constant input c_src1 and pass them to the target generator to produce G_tgt(w_1, c_src1). Finally, we compute the L2 loss between the generated images and the ground-truth images G_src1(w_src1) from the source model src1. The same process is employed for src2. For backpropagation, we average the L2 losses arising from the corresponding pairs of images. Note that this implicitly assigns equal weight to each source distribution; the exploration of weighting schemes is left as future work. A sketch of this training step is shown below.
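The sketch below puts the pieces together for one multi-source training step (L2 loss only, as in our experiments); `M_srcs`, `G_srcs`, and `consts` are the per-source frozen mapping networks, generators, and constant inputs, `affine` is the module above, and `G_tgt(w, const)` is a placeholder for however the implementation injects the constant input.

```python
import torch
import torch.nn.functional as F

def multi_source_distill_step(M_srcs, G_srcs, consts, G_tgt, affine, optimizer, z_batch):
    """One multi-source distillation step: average the per-source L2 losses."""
    losses = []
    for k, (M_src, G_src, c_src) in enumerate(zip(M_srcs, G_srcs, consts)):
        with torch.no_grad():
            w_src = M_src(z_batch)          # frozen source mapping network
            imgs_src = G_src(w_src)         # frozen "ground-truth" images
        w_k = affine(w_src, k)              # per-source scale/shift (gamma, beta)
        imgs_tgt = G_tgt(w_k, c_src)        # source-specific constant input
        losses.append(F.mse_loss(imgs_tgt, imgs_src))

    loss = torch.stack(losses).mean()       # equal weighting of the sources
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```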
Due to computational constraints, we were unable to incorporate the perceptual loss for multi-source distillation, as it requires substantial GPU memory for backpropagation. However, we believe that incorporating the perceptual loss would further improve the quality of the results, especially for hard-to-learn distributions such as LSUN Cats, which exhibit high diversity in poses and backgrounds.
We train a single StyleGAN-2 model using the distillation strategy described in Sec. 3.1 above. Below, we visualize randomly sampled results from the distilled target generator. Specifically, during inference, we first sample random z-vectors, which are passed to the corresponding source mapping networks to obtain w_src1 = M_src1(z) and w_src2 = M_src2(z). These are then transformed using the learned scale and shift parameters, w_1 = γ_1 * w_src1 + β_1 and w_2 = γ_2 * w_src2 + β_2, which are used along with the corresponding source-learned constants c_src1 and c_src2 to generate the target images G_tgt(w_1, c_src1) and G_tgt(w_2, c_src2).
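In code, sampling the fused generator for one source domain reduces to the following sketch (same placeholder names as above).

```python
import torch

@torch.no_grad()
def sample_fused(G_tgt, M_src, affine, c_src, source_idx, n=16, z_dim=512, device="cuda"):
    """Sample the fused target generator for a single source domain."""
    z = torch.randn(n, z_dim, device=device)
    w = affine(M_src(z), source_idx)   # learned per-source gamma/beta transform
    return G_tgt(w, c_src)             # matching source-learned constant input
```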
The results are shown in Fig. 3.2a-d for randomly sampled z-vectors. In the case of FFHQ faces (Fig. 3.2a and Fig. 3.2b), we see a faithful, if slightly blurry, reconstruction of the source images. Importantly, these images are unhindered by the presence of the Cats distribution. We attribute this to the two source-specific attributes: 1) we use different constant inputs in the target generator for each distribution, which provides a different structure to the generated data, and 2) we learn different scale and shift parameters for the style input of each distribution, which ensures that different styles are learned by the target model. Thus, we are able to successfully fuse the two source GANs into a single target GAN, as intended.
In the case of the generated cat images, we observe a bokeh-like effect around the cat faces. While these images look reasonable in terms of the cat pose (the object of interest) and the overall background, they exhibit significant blur in comparison to the FFHQ faces. We speculate that the reason for this effect is that the LSUN Cats images contain high diversity in poses, backgrounds, and foreground textures, leading to highly varied details that are hard to model using an L2 loss alone.
Nevertheless, the target generator is still able to capture the local semantics around the cat faces, as the facial region is relatively consistent across the images (e.g., most images show the cat looking straight at the camera, so both eyes, the ears, the nose, etc. are visible). Furthermore, the semantics of the cat faces also match those of the human faces in the FFHQ distribution, enabling the model to leverage the regularity in facial structure and extract finer details.
Figure 3.2a. FFHQ images generated by the target generator on randomly sampled z-vectors.
Figure 3.2b. Corresponding FFHQ images generated by the source generator.
Figure 3.2c. LSUN Cats images generated by the target generator on randomly sampled z-vectors.
Figure 3.2d. Corresponding LSUN Cats images generated by the source generator.
As discussed previously, we could also fine-tune the fused generator in the limited-target-data setting. In this case, the selector would choose among domains by selecting the appropriate learnable parameters γ and β and the corresponding constant input to the generator, rather than choosing between multiple generators.
In Table 4.1, we report quantitative metrics to evaluate the quality of the generated images for both the single- and multi-GAN distillation settings.
Before moving to the evaluation discussion, we would like to highlight that the single-generator distillation (with perceptual loss) and the multi-generator distillation experiments were still training at the time of submission of this report. Also, we could not train the multi-generator distillation experiment with the perceptual loss due to a lack of compute resources.
The first two rows of the table demonstrate the generalizability of our model. In particular, we compare the PSNR and SSIM metrics for images generated from the training latent vectors against those generated from random latent vectors. As we can see, FFHQ-test performs reasonably well compared to FFHQ-train.
Comparing row 1 vs. row 3 and row 2 vs. row 4, we can clearly see that training with the perceptual loss helps generate better-quality images. We expect the network to perform even better if trained until convergence.
Further, in the last two rows, we report the image-quality metrics for the multi-generator distillation setting. Note that while a single generator is trained to generate both domains (FFHQ and LSUN_Cat), we evaluate them independently. Comparing rows 2 and 5, we see that the fused generator is able to capture most details from the source generator. Again, we expect the network to perform even better if trained until convergence.
Due to a lack of compute, we could not train single-generator distillation on LSUN_Cat, so there is no fair comparison available for the last row of the table. However, the reconstruction results for LSUN_Cat are significantly poorer than those for FFHQ. We suspect this is because of the higher diversity in the LSUN Cats dataset.
Table 4.1. We compare image reconstruction metrics (PSNR and SSIM) for the single- and multi-GAN distillation settings. Higher is better for both metrics.
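For completeness, a sketch of how the paired PSNR/SSIM numbers in Table 4.1 can be computed with scikit-image, assuming matched source/target images as HxWx3 float arrays in [0, 1].

```python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def paired_metrics(imgs_tgt, imgs_src):
    """Average PSNR/SSIM over matched (target, source) image pairs."""
    psnrs, ssims = [], []
    for tgt, src in zip(imgs_tgt, imgs_src):
        psnrs.append(peak_signal_noise_ratio(src, tgt, data_range=1.0))
        ssims.append(structural_similarity(src, tgt, channel_axis=-1, data_range=1.0))
    return float(np.mean(psnrs)), float(np.mean(ssims))
```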
[1] Hinton et al, “Distilling the Knowledge in a Neural Network”, NeurIPS Deep Learning and Representation Learning Workshop (2015).
[2] Karras et al, “Analyzing and Improving the Image Quality of StyleGAN”, arXiv:1912.04958 (2020).
[3] Addepalli et al, “DeGAN: Data-Enriching GAN for Retrieving Representative Samples”, AAAI (2020).
[4] Kurmi et al, “Domain Impression: A Source Data Free Domain Adaptation Method”, WACV (2021).
[5] Kundu et al, “Universal Source-Free Domain Adaptation”, CVPR (2020).
[6] Kundu et al, “Towards Inheritable Models for Open-Set Domain Adaptation”, CVPR (2020).
[7] Wang et al, “Adversarial Learning of Portable Student Networks”, AAAI (2018).
[8] Chen et al, “Distilling Portable Generative Adversarial Networks for Image Translation”, AAAI (2020).
[9] Wang et al, “KDGAN: Knowledge Distillation with Generative Adversarial Networks”, NeurIPS (2018).
[10] Chang et al, “TinyGAN: Distilling BigGAN for Conditional Image Generation”, ACCV (2020).
[11] Aguinaldo et al, “Compressing GANs using Knowledge Distillation”, arXiv:1902.00159 (2019).
[12] Isola et al, “Image-to-Image Translation with Conditional Adversarial Nets”, CVPR (2017).
[13] Lin et al, “Anycost GANs for Interactive Image Synthesis and Editing”, CVPR (2021).
[14] Sankaranarayanan et al, “Generate To Adapt: Aligning Domains using Generative Adversarial Networks”, CVPR (2018).
[15] Li et al, “GAN Compression: Efficient Architectures for Interactive Conditional GANs”, CVPR (2020).
[16] Li et al, “Semantic Relation Preserving Knowledge Distillation for Image-to-Image Translation”, ECCV (2020).
[17] Lopes et al, “Data-Free Knowledge Distillation for Deep Neural Networks”, NeurIPS Workshop on Learning with Limited Data (2017).
[18] Wang et al, “MineGAN: Effective Knowledge Transfer from GANs to Target Domains with Few Images”, CVPR (2020).