Generative Adversarial Networks – An Experiment with Training Improvement Techniques (in Tensorflow)

First, if you're interested in deep learning but would rather not read this post at all, I recommend taking a look at the deep learning nanodegree offered by Udacity. It looks like a well-designed series of courses that covers all major aspects of deep learning.

Introduction

Generative Adversarial Networks (GANs) have become one of the most popular advancements in deep learning. The method is unique in that it pairs a generative model, which produces synthetic data from noise, with a discriminative model that tries to tell the synthetic data apart from real data. For example, a well-trained GAN for image generation can produce images that look far more realistic than those generated by other deep learning models. Beyond images, generative models are also useful for tasks such as predicting rare events, since a GAN can be used to increase the target sample size so that a predictive model yields better performance. This post is not intended as an introduction to GANs, as there are many great articles covering that. For instance, this article uses a simple example to demonstrate the implementation of GANs in TensorFlow. The objective is to train a model that learns to generate a Gaussian distribution like this:

Because a vanilla GAN model is hard to train, the article explored ways to improve training, one highlight being minibatch training (which the article explains well). Since the article was published, training techniques for GANs have developed further. I therefore took a couple of these techniques and applied them to this simple example to see how the new models perform. All the new code is based on the original code here.
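For reference, here is a minimal NumPy sketch of the two losses a vanilla GAN of this kind minimizes. The generator loss follows the original loss quoted later in this post (the mean of -log(D2)). The discriminator loss is the standard cross-entropy form and is my assumption about the original code, and the names gan_losses, d_real and d_fake are made up for illustration.

```python
import numpy as np


def gan_losses(d_real, d_fake, eps=1e-8):
    """Return (discriminator loss, generator loss) for a vanilla GAN.

    d_real: discriminator outputs in (0, 1) for real samples
    d_fake: discriminator outputs in (0, 1) for generated samples
    """
    # clip to avoid log(0)
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1 - eps)
    # the discriminator wants d_real close to 1 and d_fake close to 0
    loss_d = np.mean(-np.log(d_real) - np.log(1 - d_fake))
    # the generator wants the discriminator to score its samples close to 1
    loss_g = np.mean(-np.log(d_fake))
    return loss_d, loss_g


# A discriminator that cannot tell the two apart outputs 0.5 everywhere.
loss_d, loss_g = gan_losses([0.5, 0.5], [0.5, 0.5])
# A confident, correct discriminator: lower loss_d, higher loss_g.
sharp_d, sharp_g = gan_losses([0.9], [0.1])
```

When the discriminator is at chance, loss_d is 2 ln 2 and loss_g is ln 2, which is a handy reference point when reading training logs.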

Original Models

The original article (which I recommend reading first) showed examples of distributions generated by models with and without the minibatch technique. I re-ran both training processes on my laptop:

No minibatch

Minibatch

The results look slightly different from those in the article. Minibatch is supposed to make the model better at generating a similar distribution, but it did not work quite as intended.
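As a rough picture of what a minibatch layer computes, here is a NumPy sketch of minibatch discrimination. It assumes that this is the technique the original article's minibatch layer implements. The function name, the argument names and the random projection matrix T are hypothetical; in a real model T would be learned.

```python
import numpy as np


def minibatch_features(h, T, num_kernels, kernel_dim):
    """Append cross-sample similarity features to a batch of activations.

    h: (N, A) activations for a batch of N samples
    T: (A, num_kernels * kernel_dim) projection matrix (learned in practice)
    """
    n = h.shape[0]
    m = (h @ T).reshape(n, num_kernels, kernel_dim)
    # L1 distance between every pair of samples, per kernel: (N, N, num_kernels)
    dist = np.abs(m[:, None, :, :] - m[None, :, :, :]).sum(axis=3)
    # closeness of each sample to the whole batch (the self term is included)
    closeness = np.exp(-dist).sum(axis=1)
    return np.concatenate([h, closeness], axis=1)


rng = np.random.default_rng(0)
T = rng.normal(size=(2, 5 * 3))
collapsed = np.ones((4, 2))        # every sample identical
spread = rng.normal(size=(4, 2))   # samples differ from each other
out_collapsed = minibatch_features(collapsed, T, num_kernels=5, kernel_dim=3)
out_spread = minibatch_features(spread, T, num_kernels=5, kernel_dim=3)
```

A batch whose samples are all the same gets the maximum closeness value (the batch size) in every extra feature, which gives the discriminator a direct signal that the generator is collapsing onto a narrow range.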

Adding Noises

As explained here and here, adding Gaussian noise (with zero mean and tiny variance) to the input data of the discriminative network, i.e. both the synthetic data points generated by the generative model and the data points sampled from the real Gaussian distribution, can force the generator output and the real distribution to spread out and overlap more, which makes training easier. I tweaked the original code so that the class DataDistribution can be used not only to sample data from the target distribution, but also to sample noise, by setting mu = 0 and sigma = 0.001 (or some other small number):


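import numpy as np


class DataDistribution(object):
    # mu and sigma are the mean and standard deviation of the Gaussian;
    # the defaults describe the target distribution, while mu = 0 with a
    # small sigma turns the same class into a noise source
    def __init__(self, mu=4, sigma=0.5):
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        # draw N values and return them in ascending order
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()
        return samples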

In the train method, we can now add noise to the inputs of the discriminator:


# noise is a DataDistribution instance created with mu = 0 and sigma = 0.001
for step in range(params.num_steps + 1):
    # update discriminator
    x = data.sample(params.batch_size)
    z = gen.sample(params.batch_size)
    # sample one batch of noise for each input
    n_x = noise.sample(params.batch_size)
    n_z = noise.sample(params.batch_size)
    # feed the noisy inputs as column vectors of shape (batch_size, 1)
    loss_d, _ = session.run([model.loss_d, model.opt_d], {
        model.x: np.reshape(x + n_x, (params.batch_size, 1)),
        model.z: np.reshape(z + n_z, (params.batch_size, 1))
    })

The results are as follows:

No minibatch, added noise (std = 0.001)

Minibatch, added noise (std = 0.001)

The model without minibatch mimics the bell shape pretty well, but notice that it also leaves a long tail to the left. The training loss of the generator actually increased compared with the first example. The minibatch model does appear to have improved a lot over the first example: the output distribution is now much less concentrated around the mean.
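The noise step can also be tried in isolation. Below is a small NumPy-only sketch, not the code used for the plots above. The helper add_instance_noise and the stand-in batches are hypothetical. Unlike DataDistribution.sample, the helper does not sort the noise, so each sample gets an independent perturbation; at sigma = 0.001 the difference is probably small, but that is an assumption on my part.

```python
import numpy as np


def add_instance_noise(batch, sigma, rng):
    """Add independent zero-mean Gaussian noise to every sample in a batch."""
    return batch + rng.normal(0.0, sigma, size=batch.shape)


rng = np.random.default_rng(0)
batch_size = 8
# stand-ins for a batch of real samples and a batch of generated samples
real = np.sort(rng.normal(4.0, 0.5, batch_size))
fake = np.sort(rng.normal(0.0, 1.0, batch_size))
# perturb both batches, then reshape to the (batch_size, 1) layout
# that the discriminator placeholders in this post expect
noisy_real = add_instance_noise(real, 0.001, rng).reshape(batch_size, 1)
noisy_fake = add_instance_noise(fake, 0.001, rng).reshape(batch_size, 1)
```

With a standard deviation this small the batches barely move, so a larger sigma may be worth trying if the two distributions start out far apart.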

Feature Matching

This post explains pretty well how feature matching works in training GANs. The basic idea is that, instead of using only the final (sigmoid) output of the discriminator to minimize the loss of the generator, it also uses information from a hidden layer for better optimization. To implement this, we need to expose a hidden layer (h2) of the discriminator:


def discriminator(input, h_dim, minibatch_layer=True):
    # linear and minibatch are helper functions from the original code
    h0 = tf.nn.relu(linear(input, h_dim * 2, scope='d0'))
    h1 = tf.nn.relu(linear(h0, h_dim * 2, scope='d1'))
    print("h0:{}".format(h0.shape))
    print("h1:{}".format(h1.shape))
    # without the minibatch layer, the discriminator needs an additional layer
    # to have enough capacity to separate the two distributions correctly
    if minibatch_layer:
        h2 = minibatch(h1)
    else:
        h2 = tf.nn.relu(linear(h1, h_dim * 2, scope='d2'))

    # h3 is the probability that the input is real
    h3 = tf.sigmoid(linear(h2, 1, scope='d3'))
    print("h3:{}".format(h3.shape))
    # return the hidden layer h2 as well, so the generator loss can use it
    return h3, h2

h2 is then fed into the generator's loss function:


# Original loss function: self.loss_g = tf.reduce_mean(-log(self.D2))
self.loss_g = tf.sqrt(tf.reduce_sum(tf.pow(self.D1_h2 - self.D2_h2, 2)))

Here D1_h2 and D2_h2 are the h2 layers of the discriminator when it takes in real samples and the generator's data, respectively. Here are the results:

No minibatch, added noise (std = 0.001), feature matching

Minibatch, added noise (std = 0.001), feature matching

The model without minibatch improved over the last attempt, as the fat tail has disappeared, though it is not a clear improvement over the vanilla method. In contrast, the model with minibatch and added noise did not perform well.
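One thing worth noting: the loss used above compares the h2 activations row by row, so it depends on how the real and generated samples happen to be paired within a batch. Feature matching as it is usually formulated compares the batch means of the features instead. The NumPy sketch below contrasts the two. The function names are hypothetical, and this variant was not used for the plots in this post.

```python
import numpy as np


def feature_matching_loss(f_real, f_fake):
    """Squared L2 distance between the batch means of two feature matrices.

    f_real, f_fake: (N, F) hidden-layer activations for the real
    and the generated batch
    """
    diff = f_real.mean(axis=0) - f_fake.mean(axis=0)
    return float(np.sum(diff ** 2))


def elementwise_loss(f_real, f_fake):
    """NumPy version of the loss used in this post: sqrt(sum((a - b)^2))."""
    return float(np.sqrt(np.sum((f_real - f_fake) ** 2)))


rng = np.random.default_rng(0)
f_real = rng.normal(size=(16, 4))
f_fake = f_real[::-1].copy()   # same features, rows in reverse order

mean_based = feature_matching_loss(f_real, f_fake)
row_based = elementwise_loss(f_real, f_fake)
shifted = feature_matching_loss(f_real, f_real + 1.0)
```

For two batches that contain the same features in a different order, the mean-based loss is essentially zero while the row-by-row loss is not, which may be one reason the results here are sensitive to batch ordering.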

Conclusion

The experiments yielded mixed results, but this really is just a toy project. The updated code can be found here. If you're interested in learning more about GANs, the articles linked in this post are all really good starting points, provided you have prior knowledge of traditional deep networks:

GANs introduction and example: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
Improvement techniques: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
Adding Noises: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
Feature matching: http://blog.aylien.com/introduction-generative-adversarial-networks-code-tensorflow/
