Training on Thin Air: Improve Image Classification with Generated Data

Yongchao Zhou, Hshmat Sahak, Jimmy Ba

[ArXiv | GitHub]

Abstract

Acquiring high-quality data for training discriminative models is a crucial yet challenging aspect of building effective predictive systems. In this paper, we present Diffusion Inversion, a simple yet effective method that leverages the pre-trained generative model Stable Diffusion to generate diverse, high-quality training data for image classification. Our approach captures the original data distribution and ensures data coverage by inverting images to the latent space of Stable Diffusion, and it generates diverse novel training images by conditioning the generative model on noisy versions of these vectors. We identify three key components that allow our generated images to successfully supplant the original dataset, leading to a 2-3x improvement in sample complexity and a 6.5x decrease in sampling time. Our approach also consistently outperforms generic prompt-based steering methods and a KNN retrieval baseline across a wide range of datasets, with especially strong results in specialized fields such as medical imaging. Finally, we demonstrate that our approach is compatible with widely used data augmentation techniques, and that the generated data reliably supports various neural architectures and enhances few-shot learning performance.

Method

Stable Diffusion, a model trained on billions of image-text pairs, boasts a wealth of generalizable knowledge. To harness this knowledge for specific classification tasks, we propose a two-stage method that guides a pre-trained generator, G, towards the target domain dataset. In the first stage, we map each image to the model’s latent space, generating a dataset of latent embedding vectors. Then, we produce novel image variants by running the inverse diffusion process conditioned on perturbed versions of these vectors.
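To make the two stages concrete, below is a minimal sketch using the Hugging Face diffusers library. The checkpoint name, embedding shape, optimizer settings, and perturbation scale sigma are illustrative assumptions, not the paper's exact configuration.

# Sketch of the two-stage procedure: (1) invert each image into a conditioning
# embedding under a frozen Stable Diffusion model, (2) sample novel variants
# conditioned on noisy copies of that embedding. Hyperparameters are assumptions.
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.vae.requires_grad_(False)
pipe.unet.requires_grad_(False)

def invert_image(image, steps=1000, lr=1e-2):
    """Stage 1: optimize a conditioning embedding that reconstructs `image`.

    `image` is a (1, 3, H, W) tensor scaled to [-1, 1]; the frozen VAE and
    U-Net supply the standard denoising (epsilon-prediction) loss.
    """
    with torch.no_grad():
        latents = pipe.vae.encode(image.to("cuda")).latent_dist.sample() * 0.18215
    # One learnable token sequence per image (77 x 768 for SD v1).
    emb = torch.randn(1, 77, 768, device="cuda", requires_grad=True)
    opt = torch.optim.Adam([emb], lr=lr)
    for _ in range(steps):
        noise = torch.randn_like(latents)
        t = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (1,), device="cuda")
        noisy = pipe.scheduler.add_noise(latents, noise, t)
        pred = pipe.unet(noisy, t, encoder_hidden_states=emb).sample
        loss = F.mse_loss(pred, noise)
        opt.zero_grad(); loss.backward(); opt.step()
    return emb.detach()

def generate_variants(emb, n=4, sigma=0.1):
    """Stage 2: run the sampler conditioned on noisy copies of the embedding."""
    noisy_embs = emb.repeat(n, 1, 1) + sigma * torch.randn(n, 77, 768, device="cuda")
    return pipe(prompt_embeds=noisy_embs, num_inference_steps=50).images

The key design choice is that the generator stays frozen: only the per-image conditioning vector is optimized, so each vector serves as a compact handle on the original image that can be perturbed to produce novel, in-distribution variants.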

Results

We pinpoint three vital components that allow models trained on generated data to surpass those trained on real data: 1) a high-quality generative model, 2) a sufficiently large dataset size, and 3) a steering method that considers distribution shift and data coverage. Moreover, we show that our generated data is compatible with various standard data augmentation strategies and can enhance model performance across numerous popular neural architectures.

A High-quality Generator is Needed

Our method outperforms both GAN and GAN Inversion techniques when models are trained on generated datasets of the same size as the original real dataset, highlighting the importance of a high-quality pre-trained generator.

A Sufficiently Large Dataset is Crucial

The test accuracy of ResNet18 increases as more generated data is incorporated, eventually exceeding the performance of the model trained on the entire real dataset.

When comparing models trained on real vs. generated images across datasets of equal size, our approach demonstrates substantially improved performance when the dataset is small. In large-dataset scenarios, it achieves the same test accuracy using 2-3x less real data.

Distribution Shift and Coverage Matter

Diffusion Inversion (DI) consistently excels at handling distribution shift and data coverage across three medical imaging datasets. DI also outperforms constructing new datasets via KNN retrieval from the LAION dataset.

Our method improves few-shot learning performance, yielding results similar to LECF. 

Our method demonstrates significantly better scalability than LECF on STL10 using only generated data.

Comparison against Data Augmentation

Data Augmentation Techniques on STL10: Our approach, combined with the default augmentation (random crop and flip), consistently outperforms the alternatives, and it can be improved further by adding other standard data augmentation techniques, as sketched below.
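As a concrete example, here is a minimal sketch of mixing generated images with the real STL10 training set under the default crop-and-flip augmentation; the generated-data folder path and batch size are assumptions for illustration.

# Combine real STL10 data with generated images under standard augmentation.
import torch
from torchvision import datasets, transforms

default_aug = transforms.Compose([
    transforms.RandomCrop(96, padding=4),      # STL10 images are 96x96
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

real = datasets.STL10("data", split="train", download=True, transform=default_aug)
generated = datasets.ImageFolder("generated/stl10", transform=default_aug)  # hypothetical path
train_set = torch.utils.data.ConcatDataset([real, generated])
loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)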

Evaluation on Various Architectures 

Models trained on the synthetic dataset significantly outperform those trained on the real dataset across a diverse range of neural architectures on STL10.

Generated Images

Standard Datasets

CIFAR10

CIFAR100

STL10

ImageNette

Specialized Datasets

PathMNIST

BloodMNIST

DermaMNIST

EuroSAT

BibTeX

@article{zhou2023training,
      title={Training on Thin Air: Improve Image Classification with Generated Data},
      author={Yongchao Zhou and Hshmat Sahak and Jimmy Ba},
      year={2023},
      eprint={2305.15316},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgement

We would like to thank Keiran Paster, Silviu Pitis, Harris Chan, Yangjun Ruan, Michael Zhang, Leo Lee, and Honghua Dong for their valuable feedback. Jimmy Ba was supported by NSERC Grant [2020-06904], the CIFAR AI Chairs program, the Google Research Scholar Program, and an Amazon Research Award. Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, and companies sponsoring the Vector Institute for Artificial Intelligence.