Understanding Generative Adversarial Networks

Use A Simple GAN To Generate MNIST images

Generative Adversarial Networks

November 3, 2017

Understanding Generative Adversarial Networks: Use A Simple GAN To Generate MNIST images

If you see the above image that does not make much sense, this article is for you. I explain how GANs (Generative Adversarial Networks) work using a simple project that generates hand-written digit images.

I used Keras on TensorFlow, and the notebook code is available on my Github.

1 Background

GAN (Generative Adversarial Network) is a framework proposed by Ian Goodfellow, Yoshua Bengio, and others in 2014.

We can train a GAN to generate images from random noises. For example, we can train a GAN to generate digit images that look like handwritten ones from MNIST database.

MNIST hand-written digit images

A GAN has two parts: the generator that generates images and the discriminator that classifies real and fake images.

2 The Generator

The input to the generator is a series of randomly generated numbers called latent samples. Once trained, the generator can produce digit images from latent samples.

Our generator is a simple fully-connected network that takes a latent sample (100 randomly generated numbers) and produces 784 data points which we can reshape into a 28 x 28 digit image which is the size used by all MNIST digit images.

generator = Sequential([
    Dense(128, input_shape=(100,)),
], name='generator')

The summary output is as follows:

Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 128)               12928     
leaky_re_lu_1 (LeakyReLU)    (None, 128)               0         
dense_2 (Dense)              (None, 784)               101136    
activation_1 (Activation)    (None, 784)               0         
Total params: 114,064
Trainable params: 114,064
Non-trainable params: 0

We use the tanh activation, which How to Train a GAN? Tips and tricks to make GANs work recommends. It also means we must rescale the MNIST images between -1 and 1. More details are on my Github.

3 How to Train the Generator?

Without training, the generator produces garbage images only.

To train the generator, we need to train a GAN. Before talking about GAN, we shall discuss the discriminator.

4 The Discriminator

The discriminator is a classifier trained using supervised learning. It classifies whether an image is real (1) or not (0).

We train the discriminator using both the MNIST images and the generated images.

If the input image is from the MNIST database, the discriminator should classify it as real.

If the input image is from the generator, the discriminator should classify it as fake.

The discriminator is also a simple fully-connected neural network.

discriminator = Sequential([
    Dense(128, input_shape=(784,)),
], name='discriminator')

The summary output is as follows:

Layer (type)                 Output Shape              Param #   
dense_3 (Dense)              (None, 128)               100480    
leaky_re_lu_2 (LeakyReLU)    (None, 128)               0         
dense_4 (Dense)              (None, 1)                 129       
activation_2 (Activation)    (None, 1)                 0         
Total params: 100,609
Trainable params: 100,609
Non-trainable params: 0

The last activation is sigmoid to tell us the probability of whether the input image is real or not. So, the output can be any value between 0 and 1.

5 The GAN

We connect the generator and the discriminator to produce a GAN.

Generative Adversarial Networks

Keras has an easy way to connect two models as follows:

gan = Sequential([


The structure of the network is shown below:

Layer (type)                 Output Shape              Param #   
generator (Sequential)       (None, 784)               114064    
discriminator (Sequential)   (None, 1)                 100609    
Total params: 214,673
Trainable params: 214,673
Non-trainable params: 0

Now that we know the generator, the discriminator, and the GAN, we shall discuss how to train the generator.

6 Training the GAN means Training the Generator

When we feed a latent sample to the GAN, the generator internally produces a digit image, which the discriminator receives for classification. If the generator does a good job, the discriminator returns a value close to 1 (high probability of the image being real).

We feed latent samples to the GAN while setting the expected outcome (label) to 1 (real) as we expect the generator to produce a realistic image and the discriminator to say it is real or close to real.

However, the generator initially produces garbage images, and the loss value is high. So, the back-propagation updates the generator’s weights to produce more realistic images as the training continues, which is how we train the generator via training the GAN.

There is one catch in training the generator via the GAN. We do not want the discriminator’s weights to be affected because we are using the discriminator as merely a classifier.

For this reason, we set the discriminator non-trainable during the generator training.

Generative Adversarial Networks

7 The Train Loop

Let’s not forget that we must train the discriminator to do an excellent job as a classifier of real and fake images.

We train the discriminator and the generator in turn in a loop as follows:

Step 1) Set the discriminator trainable

Step 2) Train the discriminator with the real MNIST digit images and the images generated by the generator to classify the real and fake images.

The Discriminator Training

Step 3) Set the discriminator non-trainable

Step 4) Train the generator as part of the GAN. We feed latent samples into the GAN, let the generator produce digit images, and use the discriminator to classify the image.

The Generator Training (via GAN)

Ideally, The loop should continue until they are trained well and can not be improved any further.

8 But does it work?

The result of the simple GAN is not outstanding. Some of them look pretty good, but others are not.

As it turns out, training a GAN requires many hacks as per How to Train a GAN? Tips and tricks to make GANs work, such as label smoothing and other techniques.

There are all sorts of empirical quirks. If I train the discriminator much faster than the generator, the generator gives up learning. In some cases, the generator learns to deceive the discriminator and makes the discriminator unable to learn to classify correctly.

I tried different hacks, and the below plot of the loss values is what I could achieve after about one-day of experiments (complete details on my Github).

The generator should have lower loss values than the above. I believe we can improve the performance if we use more complex networks like DCGAN (Deep Convolutional GAN).

This article uses the simple GAN example to show how GAN works in principle. Once you know, it should be easier to understand other GAN articles and implementations.

Moreover, there are many kinds of GANs (the whole list), and people are inventing new types of GANs as we speak. So, GANs do work, and many people are researching GANs.

9 References