Manifold Graph with Learned Prototypes for Semi-Supervised Image Classification

Chia-Wen Kuo†, Chih-Yao Ma†, Jia-Bin Huang‡, Zsolt Kira†
† Georgia Tech, ‡ Virginia Tech


Abstract

Recent advances in semi-supervised learning methods rely on estimating categories for unlabeled data using a model trained on the labeled data (pseudo-labeling) and using the unlabeled data for various forms of consistency-based regularization. In this work, we propose to additionally and explicitly leverage the structure of the data manifold through a Manifold Graph constructed over the image instances within the feature space. Specifically, we propose an architecture based on graph networks that jointly optimizes feature extraction, graph connectivity, and feature propagation and aggregation to unlabeled data in an end-to-end manner. Further, we present a novel Prototype Generator that produces a diverse set of prototypes compactly representing each category, which supports feature propagation. To evaluate our method, we first contribute a strong baseline that combines two consistency-based regularizers and already achieves state-of-the-art results, especially with fewer labels. We then show that, when combined with these regularizers, the proposed method facilitates the propagation of information from generated prototypes to image data and further improves results. We provide extensive qualitative and quantitative experimental results on semi-supervised benchmarks, demonstrating the improvements arising from our design and showing that our method achieves state-of-the-art performance compared with existing single-model methods and is comparable with ensemble methods. Specifically, we achieve error rates of 3.35% on SVHN, 8.27% on CIFAR-10, and 33.83% on CIFAR-100. With far fewer labels, we surpass the state of the art by a significant margin of 41% relative error decrease on average.

Proposed Concept

The main idea of our work is to learn a set of class-specific prototypes during training that compactly represent the images in the data manifold. To classify an image, we construct a Manifold Graph that leverages the structure of the data manifold to propagate and aggregate feature information to unlabeled data. The feature of the test image is thus refined into a better representation that improves classification, as illustrated in the sketch below.
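A minimal PyTorch sketch of this idea is given below. The module names (PrototypeGenerator, ManifoldGraph), the dimensions, the noise-conditioned generator, and the scaled dot-product form of the learned adjacency are illustrative assumptions, not the exact architecture from the paper.

```python
import torch
import torch.nn as nn


class PrototypeGenerator(nn.Module):
    """Produces a diverse set of prototype features for each class (sketch)."""
    def __init__(self, num_classes, protos_per_class, feat_dim, noise_dim=32):
        super().__init__()
        self.num_classes = num_classes
        self.protos_per_class = protos_per_class
        self.noise_dim = noise_dim
        self.class_embed = nn.Embedding(num_classes, feat_dim)
        self.net = nn.Sequential(
            nn.Linear(feat_dim + noise_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim))

    def forward(self):
        # One prototype per (class, noise sample); the noise encourages diversity.
        labels = torch.arange(self.num_classes).repeat_interleave(self.protos_per_class)
        noise = torch.randn(labels.numel(), self.noise_dim)
        protos = self.net(torch.cat([self.class_embed(labels), noise], dim=1))
        return protos, labels


class ManifoldGraph(nn.Module):
    """Refines image features by propagating prototype features over a learned graph (sketch)."""
    def __init__(self, feat_dim):
        super().__init__()
        self.query = nn.Linear(feat_dim, feat_dim)
        self.key = nn.Linear(feat_dim, feat_dim)

    def forward(self, img_feats, proto_feats):
        # Learned adjacency between image nodes and prototype nodes.
        logits = self.query(img_feats) @ self.key(proto_feats).t()
        adj = torch.softmax(logits / img_feats.size(1) ** 0.5, dim=1)
        # Propagate and aggregate prototype features onto each image node,
        # then fuse with the original feature (residual refinement).
        refined = img_feats + adj @ proto_feats
        return refined, adj
```

In this sketch, the refined features would be passed to the usual classification head, and the prototype labels returned by the generator could supervise the prototypes with a standard cross-entropy loss.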

Quantitative Results

Comparison with state-of-the-art single-model methods

We compare our implemented VAT and Π-VAT baseline methods, as well as our proposed Manifold Graph method, against representative single-model SSL methods on the SVHN, CIFAR-10, and CIFAR-100 benchmark datasets (error rate in percent, averaged over 3 runs with standard deviations).

Fewer labels on SVHN and CIFAR-10

The Π-VAT baseline that we developed is already very strong; our proposed Manifold Graph method with learned prototypes consistently improves upon it. A sketch of how such a Π-VAT unlabeled-data loss can be composed is shown below.
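The sketch below shows one way to combine a Π-model consistency term (agreement between two stochastic forward passes, assuming the model applies stochastic augmentation or dropout) with a standard VAT term in PyTorch. The hyperparameters (xi, eps, loss weights) and the exact weighting schedule are assumptions, not the values used in our experiments.

```python
import torch
import torch.nn.functional as F


def pi_consistency_loss(model, x):
    # Pi-model: two stochastic forward passes on the same batch should agree.
    p1 = F.softmax(model(x), dim=1)
    p2 = F.softmax(model(x), dim=1)
    return F.mse_loss(p1, p2.detach())


def vat_loss(model, x, xi=1e-6, eps=8.0, n_power=1):
    # Virtual adversarial training: find the perturbation direction that most
    # changes the prediction (power iteration), then penalize that change.
    with torch.no_grad():
        p = F.softmax(model(x), dim=1)
    d = torch.randn_like(x)
    for _ in range(n_power):
        d = xi * F.normalize(d.reshape(len(x), -1), dim=1).reshape(x.shape)
        d.requires_grad_()
        p_hat = F.log_softmax(model(x + d), dim=1)
        adv_dist = F.kl_div(p_hat, p, reduction='batchmean')
        d = torch.autograd.grad(adv_dist, d)[0].detach()
    r_adv = eps * F.normalize(d.reshape(len(x), -1), dim=1).reshape(x.shape)
    p_hat = F.log_softmax(model(x + r_adv), dim=1)
    return F.kl_div(p_hat, p, reduction='batchmean')


def pi_vat_loss(model, x_unlabeled, w_pi=1.0, w_vat=1.0):
    # Combined unlabeled-data loss; weights are illustrative assumptions.
    return w_pi * pi_consistency_loss(model, x_unlabeled) + \
           w_vat * vat_loss(model, x_unlabeled)
```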

SVHN

CIFAR-10

Qualitative Results

t-SNE visualization of the data manifold, including the generated prototypes and the entire validation set. Stars represent prototypes, circles represent image features, and the same color represents the same class. The generated prototypes align well with the clusters of image features while retaining a good degree of diversity.
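A minimal sketch of how such a figure can be produced with scikit-learn and matplotlib is given below; the array names (img_feats, img_labels, proto_feats, proto_labels) and the t-SNE settings are illustrative assumptions.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


def plot_manifold(img_feats, img_labels, proto_feats, proto_labels):
    # Embed image features and prototypes jointly so they share one 2-D space.
    all_feats = np.concatenate([img_feats, proto_feats], axis=0)
    emb = TSNE(n_components=2, init='pca', perplexity=30).fit_transform(all_feats)
    img_emb, proto_emb = emb[:len(img_feats)], emb[len(img_feats):]

    # Circles: image features; stars: generated prototypes; color encodes class.
    plt.scatter(img_emb[:, 0], img_emb[:, 1], c=img_labels,
                cmap='tab10', s=8, marker='o', alpha=0.5)
    plt.scatter(proto_emb[:, 0], proto_emb[:, 1], c=proto_labels,
                cmap='tab10', s=120, marker='*', edgecolors='k')
    plt.axis('off')
    plt.show()
```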

Adjacency matrix learned by the Manifold Graph. Darker colors represent stronger edges (higher edge weights). The model learns to build stronger edges between nodes of the same class, yielding a sparser, locally connected graph in the data manifold.
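A short sketch of how such a heatmap can be drawn is shown below, assuming `adj` is an (N, N) NumPy array of learned edge weights and `labels` an (N,) array of class indices; sorting nodes by class makes the within-class block structure visible.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_adjacency(adj, labels):
    order = np.argsort(labels)            # group nodes of the same class together
    adj_sorted = adj[order][:, order]
    plt.imshow(adj_sorted, cmap='Blues')  # darker = stronger edge weight
    plt.colorbar(label='edge weight')
    plt.show()
```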

Resources

GitHub

@article{kuo2019manifold,
  title   = {Manifold Graph with Learned Prototypes for Semi-Supervised Image Classification},
  author  = {Chia-Wen Kuo and Chih-Yao Ma and Jia-Bin Huang and Zsolt Kira}, 
  journal = {arXiv preprint arXiv:1906.05202},
  year    = {2019}
}