GAN Tutorial
Last updated
This topic describes what a Generative Adversarial Network (GAN) is and how to create one in PerceptiLabs.
A GAN is a machine learning model composed of two neural networks that feed into each other and compete against each other. Through this competition, it learns both to generate realistic data and to classify data as real or generated. The architecture of this model is shown in Figure 1:
The generator takes random noise as input and learns to generate realistic data. It does this by sending its output to the second neural network in the GAN, the discriminator, which gradually improves its ability to distinguish the generator's output from real (training) data. The generator, in turn, receives feedback from the discriminator that helps it learn to generate increasingly realistic data. From an implementation standpoint, the generator and discriminator each have their own loss function, and both losses are computed from the discriminator's output. The generator's loss measures how well its output fools the discriminator, and this is the feedback it receives: during training, the gradient of that loss is backpropagated through the discriminator into the generator to update the generator's weights, while the discriminator's weights are updated from its own classification loss.
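To make the two losses and the feedback path concrete, here is a minimal toy sketch in plain Python. It is not PerceptiLabs code: it is a one-dimensional "GAN" whose generator is G(z) = a*z + b and whose discriminator is a logistic regression D(x) = sigmoid(w*x + c), with gradients written out by hand. The function names, the assumed real-data distribution (a normal distribution with mean 4), and the hyperparameters are all hypothetical, and the generator uses the common "non-saturating" loss -log D(G(z)), which may differ from what the PerceptiLabs template uses.

```python
import math
import random


def sigmoid(s: float) -> float:
    # Numerically stable logistic function.
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def discriminator_loss(d_real: float, d_fake: float, eps: float = 1e-12) -> float:
    # Binary cross-entropy with real samples labeled 1 and generated samples labeled 0.
    return -(math.log(d_real + eps) + math.log(1.0 - d_fake + eps))


def generator_loss(d_fake: float, eps: float = 1e-12) -> float:
    # Non-saturating generator loss: small when the discriminator calls fakes "real".
    return -math.log(d_fake + eps)


def train_toy_gan(steps: int = 2000, lr: float = 0.05, seed: int = 0):
    """Hypothetical toy: one discriminator update, then one generator update, per step."""
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator parameters: G(z) = a*z + b
    w, c = 0.1, 0.0  # discriminator parameters: D(x) = sigmoid(w*x + c)
    history = []
    for _ in range(steps):
        z = rng.gauss(0.0, 1.0)       # random noise fed to the generator
        x_real = rng.gauss(4.0, 1.0)  # assumed "real" training sample
        x_fake = a * z + b            # generator output

        # Discriminator step: gradient of discriminator_loss w.r.t. w and c.
        d_real = sigmoid(w * x_real + c)
        d_fake = sigmoid(w * x_fake + c)
        grad_w = -(1.0 - d_real) * x_real + d_fake * x_fake
        grad_c = -(1.0 - d_real) + d_fake
        w -= lr * grad_w
        c -= lr * grad_c
        d_loss = discriminator_loss(d_real, d_fake)

        # Generator step: the discriminator's output is the feedback. Its gradient is
        # backpropagated through D (with w and c held fixed) into a and b.
        d_fake = sigmoid(w * x_fake + c)
        grad_s = -(1.0 - d_fake)  # d(generator_loss)/ds, where s = w*x_fake + c
        a -= lr * grad_s * w * z
        b -= lr * grad_s * w
        history.append((d_loss, generator_loss(d_fake)))
    return (a, b, w, c), history
```

Calling `params, history = train_toy_gan()` returns the final parameters and the per-step (discriminator, generator) losses. Because each network's update depends on the other's current state, these losses may oscillate rather than decrease steadily.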
For additional information about GANs, check out our blog: Exploring Generative Adversarial Networks (GANs).
A good starting point for experimenting with GANs is to build one which learns to generate and classify images of handwritten digits.
As Figure 1 above illustrates, the generator for such a GAN learns to generate increasingly realistic images while the discriminator increases its ability to classify those images as real or fake.
PerceptiLabs provides a template model that does exactly this, using the MNIST database of 28x28-pixel grayscale images of handwritten digits, a copy of which is included with the template.
Follow the steps below to generate the model in PerceptiLabs:
Navigate to File > New.
Select GAN on the New Model popup.
Enter a Name.
(Optional) Specify a location for the model in the Model Path field.
Click Create. A new GAN model is created, complete with sample data, that is ready to train.
Note
Because a GAN's neural networks feed into each other, training a GAN can take a long time, from a few hours to a few days. Training time also depends on how the model's parameters are initialized and on the learning rates. Long training times can also occur when the generator and discriminator compete so closely that they get "stuck" at certain levels for long periods.
The generated GAN model appears as follows in PerceptiLabs:
The GAN model can be logically broken down into the elements highlighted in Figure 2:
1. Random Data: Input starts with a Random noise component which is fed into the generator (2). This component is only used on the first epoch as a way to seed the generator.
2. Generator Layers: A dense neural network consisting of two layers representing the generator. The first layer contains 128 activations and the second contains 784 activations (pixel values) for a generated image. A Reshape component then transforms the result into a 28x28 array representing the generated 2D grayscale image.
Note
The noise vector has 100 elements, but a generated image must have 784 (28x28) pixel values. The first layer, a dense layer of size 128, sits between the two and gives the generator a hidden representation in which to learn the features needed to generate realistic-looking data.
3. Real Data: Input starts with the 784 (28x28) pixel values of each grayscale image from the MNIST dataset, which a Reshape component transforms into a 28x28 array representing the real 2D grayscale image to use for comparison.
4. Switch Component: Used by the GAN component for switching between the generated and real data images during each epoch.
5. Discriminator Layers: A dense neural network consisting of two layers representing the discriminator. The first layer contains 128 activations, which feed into a second layer with a single activation. This neuron outputs the classification of whether a given image coming in from the Switch component is real or generated.
6. GAN Component: The GAN component ties the whole model together and encapsulates the GAN algorithm. This includes invoking the Switch component during classification to take in either generated or real data, and passing the outputs of the generator and discriminator between the two neural networks. You can view the logic by inspecting the GAN component's run() method in the code viewer.
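The components listed above can be summarized as a minimal forward-pass sketch in plain Python that uses the template's layer sizes (100 → 128 → 784 → 28x28 for the generator, 784 → 128 → 1 for the discriminator). It is not the code behind the PerceptiLabs GAN component: the class and function names are hypothetical, the weights are random and untrained, this sketch draws fresh noise on every call, and the activation functions (ReLU for hidden layers, sigmoid for outputs) are assumptions because this topic does not specify them.

```python
import math
import random

NOISE_DIM = 100                   # size of the random noise vector
HIDDEN = 128                      # units in each hidden dense layer
IMG_SIDE = 28
IMG_PIXELS = IMG_SIDE * IMG_SIDE  # 784


def relu(v: float) -> float:
    return max(0.0, v)


def sigmoid(v: float) -> float:
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def init_dense(n_in, n_out, rng):
    # Weights stored as n_out rows of n_in values, plus one bias per output (assumed).
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def dense(x, layer, activation):
    weights, biases = layer
    return [activation(sum(w * v for w, v in zip(row, x)) + b)
            for row, b in zip(weights, biases)]


def reshape(flat, side):
    # Plays the role of the Reshape components: 784 values -> 28 rows of 28 pixels.
    return [flat[r * side:(r + 1) * side] for r in range(side)]


def flatten(image):
    return [p for row in image for p in row]


def switch(use_real, real_image, fake_image):
    # Hypothetical stand-in for the Switch component: choose what the discriminator sees.
    return real_image if use_real else fake_image


def count_params(*layers):
    # Weights (n_out * n_in) plus biases (n_out) for each dense layer.
    return sum(len(w) * len(w[0]) + len(b) for w, b in layers)


class TinyGAN:
    def __init__(self, seed=0):
        self.rng = random.Random(seed)
        # Generator: 100 -> 128 -> 784
        self.g1 = init_dense(NOISE_DIM, HIDDEN, self.rng)
        self.g2 = init_dense(HIDDEN, IMG_PIXELS, self.rng)
        # Discriminator: 784 -> 128 -> 1
        self.d1 = init_dense(IMG_PIXELS, HIDDEN, self.rng)
        self.d2 = init_dense(HIDDEN, 1, self.rng)

    def generate(self):
        z = [self.rng.gauss(0.0, 1.0) for _ in range(NOISE_DIM)]
        hidden = dense(z, self.g1, relu)
        pixels = dense(hidden, self.g2, sigmoid)  # 784 values in (0, 1)
        return reshape(pixels, IMG_SIDE)

    def discriminate(self, image):
        hidden = dense(flatten(image), self.d1, relu)
        return dense(hidden, self.d2, sigmoid)[0]  # estimated probability of "real"
```

For example, with `gan = TinyGAN()`, the call `gan.discriminate(switch(False, real, gan.generate()))` scores a generated image, where `real` would be a 28x28 MNIST image. Assuming every dense layer has a bias term, `count_params` gives 114,064 parameters for the generator (`gan.g1`, `gan.g2`) and 100,609 for the discriminator (`gan.d1`, `gan.d2`).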