Generative adversarial network (GAN)
A GAN is used to generate images that look similar to real images.
It has two subnetworks, a generator and a discriminator, which are trained against each other (adversarially). It works as follows:
The generator network produces a fake image from random noise (text-conditioned variants also take an embedding of a text description of the image as input).
A real image is sampled from the training dataset and fed into the discriminator network.
The discriminator network receives both real and generated images and classifies each one as 'real' or 'fake'.
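To make the three roles concrete, here is a minimal sketch in Python. It assumes each "image" is a single number, a hypothetical affine `generator(z, a, b)` and a hypothetical logistic-regression `discriminator(x, w, c)`; real GANs use deep (usually convolutional) networks, but the data flow is the same. Fixed noise values stand in for random samples so the output is reproducible.

```python
import math

def sigmoid(t: float) -> float:
    # Numerically stable logistic function: maps a score to a probability.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def generator(z: float, a: float, b: float) -> float:
    # Hypothetical toy generator: turns a noise value z into a fake "image"
    # (here just one number) with an affine map.
    return a * z + b

def discriminator(x: float, w: float, c: float) -> float:
    # Hypothetical toy discriminator: returns P(x is real) via logistic regression.
    return sigmoid(w * x + c)

def classify(x: float, w: float, c: float) -> str:
    # Label an input as 'real' or 'fake' by thresholding the probability at 0.5.
    return "real" if discriminator(x, w, c) >= 0.5 else "fake"

# Real "images" come from the dataset; fakes come from noise via the generator.
real_images = [3.5, 4.0, 4.5]
noise = [-1.0, 0.0, 1.0]
fake_images = [generator(z, a=1.0, b=0.0) for z in noise]

print([classify(x, w=1.0, c=-2.0) for x in real_images])  # ['real', 'real', 'real']
print([classify(x, w=1.0, c=-2.0) for x in fake_images])  # ['fake', 'fake', 'fake']
```

With these hand-picked parameters the discriminator already separates the two groups; in training, the parameters `a`, `b`, `w`, `c` would be learned instead.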
Training alternates between two separate processes, one for the discriminator and one for the generator.
First, the discriminator is trained on real images and generated fake images. Its loss function measures how well it separates the two, penalizing real images classified as fake and fake images classified as real.
After one or a few epochs of training, the discriminator learns to tell real images from fake ones.
Initially this is easy, because the generator is still producing random, poor-quality images.
After the discriminator has been trained for a while, training switches to the generator.
The generator uses a different loss function, which rewards it when the discriminator makes mistakes: the loss is low when the discriminator classifies generated images as real, so it acts roughly as the inverse of the discriminator's accuracy. Minimizing it pushes the generator to produce images that look real enough to fool the discriminator.
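The two loss functions can be written down directly. The sketch below works on discriminator logits (scores before the sigmoid, so D = sigmoid(logit)) and assumes the binary cross-entropy losses of the original GAN formulation, with the widely used "non-saturating" generator loss -log D(G(z)). That generator loss is a differentiable stand-in for "the inverse of the discriminator's accuracy" rather than the accuracy itself; the function names are hypothetical.

```python
import math
from typing import Sequence

def softplus(u: float) -> float:
    # Numerically stable log(1 + exp(u)).
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def discriminator_loss(real_logits: Sequence[float], fake_logits: Sequence[float]) -> float:
    # Binary cross-entropy with target 1 for real and 0 for fake images.
    # -log D(x) == softplus(-logit) and -log(1 - D(x)) == softplus(logit).
    real_term = sum(softplus(-t) for t in real_logits) / len(real_logits)
    fake_term = sum(softplus(t) for t in fake_logits) / len(fake_logits)
    return real_term + fake_term

def generator_loss(fake_logits: Sequence[float]) -> float:
    # Non-saturating generator loss -log D(G(z)): small when the discriminator
    # scores the fakes as real, i.e. when it is making mistakes.
    return sum(softplus(-t) for t in fake_logits) / len(fake_logits)

# An undecided discriminator (D = 0.5 everywhere).
print(discriminator_loss([0.0, 0.0], [0.0, 0.0]))  # 2 * log(2), about 1.386
print(generator_loss([0.0, 0.0]))                  # log(2), about 0.693

# A confident discriminator that is right about everything:
# its own loss is near 0, while the generator loss is large.
print(discriminator_loss([5.0], [-5.0]))
print(generator_loss([-5.0]))
```

Working with logits and `softplus` avoids taking `log(0)` when the discriminator becomes very confident.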
A few epochs of generator training make its images more realistic.
Once the generator has improved, training switches back to the discriminator.
This improves the discriminator further, because it now has to detect better fakes.
Training keeps alternating between the discriminator and the generator until the discriminator can no longer tell fake images from real ones.
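The alternating schedule can be sketched independently of the networks. In the sketch below, `train_gan`, `d_step`, `g_step` and `d_accuracy` are hypothetical names for caller-supplied callbacks, `k_d` and `k_g` are the number of discriminator and generator steps per round (the "one or a few epochs" in the text), and the loop stops once the discriminator's accuracy is within `tol` of 50%. That stopping rule mirrors the description above; in practice training often just runs for a fixed budget, because the accuracy rarely settles exactly at 50%.

```python
from typing import Callable, List

def train_gan(
    d_step: Callable[[], None],
    g_step: Callable[[], None],
    d_accuracy: Callable[[], float],
    k_d: int = 1,
    k_g: int = 1,
    max_rounds: int = 1000,
    tol: float = 0.02,
) -> List[str]:
    """Alternate discriminator and generator training; return the phases run."""
    phases: List[str] = []
    for _ in range(max_rounds):
        # Phase 1: update only the discriminator, k_d times.
        for _ in range(k_d):
            d_step()
            phases.append("D")
        # Phase 2: update only the generator, k_g times (discriminator frozen).
        for _ in range(k_g):
            g_step()
            phases.append("G")
        # Stop once the discriminator does no better than a coin flip.
        if abs(d_accuracy() - 0.5) < tol:
            break
    return phases

# Usage with stub callbacks: the accuracy drifts towards 50% over four rounds.
accuracies = iter([0.95, 0.80, 0.62, 0.51])
phases = train_gan(lambda: None, lambda: None, lambda: next(accuracies), k_d=2, k_g=1)
print("".join(phases))  # DDGDDGDDGDDG
```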
If training goes well, the discriminator's accuracy ends up at around 50%, no better than flipping a coin.
At that point, the generator produces images that look very similar to the real images.
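The 50% figure has a precise counterpart in the original minimax formulation of GANs. Writing p_data for the distribution of real images, p_z for the noise distribution and p_g for the distribution of generated images, the discriminator maximizes V(D, G). Assuming a discriminator flexible enough to reach its optimum for a fixed generator, that optimum can be found pointwise, and it equals 1/2 everywhere once the generator matches the data. This is an idealized result; real training only approximates it.

```latex
V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
        + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

% For fixed G: V = \int \big[ p_{\text{data}}(x) \log D(x) + p_g(x) \log(1 - D(x)) \big] \, dx.
% Pointwise, \alpha \log y + \beta \log(1 - y) is maximized at y = \alpha / (\alpha + \beta), so

D^{*}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}

% A perfect generator has p_g = p_{\text{data}}, which gives

D^{*}(x) = \frac{p_{\text{data}}(x)}{2\, p_{\text{data}}(x)} = \frac{1}{2} \quad \text{for every } x
```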
Illustration of the GAN network.
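The two subnetworks each have their own loss function. The real images form the training dataset that the generator learns to imitate; it never sees them directly, only through the discriminator's feedback.
The input to the generator is random noise (as in this illustration) or, for text-conditioned generation, a description text encoded as an embedding (a representation from a language model).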
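A sketch of building the generator's input, assuming the noise is drawn from a standard normal distribution and that a text embedding, when present, is simply concatenated to the noise vector (one common way to condition a generator). `make_generator_input`, `noise_dim` and the example embedding values are hypothetical; a real system would obtain the embedding from a language model.

```python
import random
from typing import List, Optional, Sequence

def make_generator_input(
    noise_dim: int,
    text_embedding: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[float]:
    if rng is None:
        rng = random.Random()
    # Sample a random noise vector z ~ N(0, I).
    z = [rng.gauss(0.0, 1.0) for _ in range(noise_dim)]
    # Unconditional generator: noise only.
    if text_embedding is None:
        return z
    # Text-conditioned generator (assumed scheme): noise followed by the embedding.
    return z + list(text_embedding)

rng = random.Random(0)
unconditional = make_generator_input(4, rng=rng)
conditional = make_generator_input(4, text_embedding=[0.1, -0.3, 0.7], rng=rng)
print(len(unconditional), len(conditional))  # 4 7
```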
Discriminator training
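Real and generated images are fed into the discriminator network, and the discriminator loss function measures how well it classifies them as real or fake.
This step involves only the discriminator: backpropagation updates the discriminator network's weights, and the generated images are treated as fixed inputs, so the generator does not change.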
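In the toy one-number setting (a hypothetical logistic discriminator D(x) = sigmoid(w*x + c), binary cross-entropy loss, plain gradient descent with a hypothetical learning rate `lr`), a discriminator step can be sketched as below. The gradients are written out by hand; the generated images arrive as plain numbers, so nothing can flow back into the generator, which is the "discriminator only" rule above. A deep-learning framework would get the same effect by detaching the generator's output.

```python
import math
from typing import Sequence, Tuple

def sigmoid(t: float) -> float:
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def softplus(u: float) -> float:
    # Numerically stable log(1 + exp(u)).
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def d_loss(w: float, c: float, real: Sequence[float], fake: Sequence[float]) -> float:
    # Binary cross-entropy: real images should score 1, fake images 0.
    real_term = sum(softplus(-(w * x + c)) for x in real) / len(real)
    fake_term = sum(softplus(w * x + c) for x in fake) / len(fake)
    return real_term + fake_term

def d_step(w: float, c: float, real: Sequence[float], fake: Sequence[float],
           lr: float = 0.05) -> Tuple[float, float]:
    # One gradient-descent step on d_loss. Only the discriminator's (w, c)
    # change; `fake` holds fixed numbers, so the generator is untouched.
    gw = gc = 0.0
    for x in real:
        g = sigmoid(w * x + c) - 1.0  # derivative of -log(sigmoid(t)) w.r.t. t
        gw += g * x / len(real)
        gc += g / len(real)
    for x in fake:
        g = sigmoid(w * x + c)        # derivative of -log(1 - sigmoid(t)) w.r.t. t
        gw += g * x / len(fake)
        gc += g / len(fake)
    return w - lr * gw, c - lr * gc

real = [3.5, 4.0, 4.5]   # samples from the dataset
fake = [-1.0, 0.0, 1.0]  # generator outputs, treated as constants here
w, c = 0.0, 0.0
before = d_loss(w, c, real, fake)
w, c = d_step(w, c, real, fake)
after = d_loss(w, c, real, fake)
print(before > after)  # True: the discriminator got better at separating them
```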
Generator training
The generator loss function is low when the discriminator mistakes generated images for real ones, so it behaves roughly as the inverse of the discriminator's accuracy; the lower, the better.
Backpropagation passes through the discriminator to compute the gradients, but only the generator network's weights are updated. The discriminator's weights are kept unchanged during this step.
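In the same toy setting, a generator step can be sketched as below, assuming the non-saturating loss -log D(G(z)) and a hypothetical affine generator G(z) = a*z + b. The chain rule goes through the discriminator (the factor `w` in the gradient), but `w` and `c` are only read, never updated, which is what keeping the discriminator's weights unchanged means here.

```python
import math
from typing import Sequence, Tuple

def sigmoid(t: float) -> float:
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def softplus(u: float) -> float:
    # Numerically stable log(1 + exp(u)).
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))

def g_loss(a: float, b: float, w: float, c: float, noise: Sequence[float]) -> float:
    # Non-saturating generator loss: mean of -log D(G(z)).
    return sum(softplus(-(w * (a * z + b) + c)) for z in noise) / len(noise)

def g_step(a: float, b: float, w: float, c: float, noise: Sequence[float],
           lr: float = 0.05) -> Tuple[float, float]:
    # One gradient-descent step on g_loss with respect to the generator's (a, b).
    ga = gb = 0.0
    for z in noise:
        x = a * z + b                              # fake image G(z)
        dloss_dx = (sigmoid(w * x + c) - 1.0) * w  # backprop through the frozen discriminator
        ga += dloss_dx * z / len(noise)            # dx/da = z
        gb += dloss_dx / len(noise)                # dx/db = 1
    return a - lr * ga, b - lr * gb               # w and c are left unchanged

noise = [-1.0, 0.0, 1.0]
a, b = 1.0, 0.0    # generator parameters
w, c = 1.0, -2.0   # discriminator that already prefers larger values
before = g_loss(a, b, w, c, noise)
a, b = g_step(a, b, w, c, noise)
after = g_loss(a, b, w, c, noise)
print(before > after, b > 0.0)  # True True: fakes moved toward what D calls real
```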