GAN — Wasserstein GAN & WGAN-GP
Training GANs is known to be hard. We regularly deal with problems in performance, model stability, convergence and mode collapse. To move forward, we can either study incremental improvements to GAN or take a new path in how the cost function is defined. GAN minimizes the Jensen-Shannon divergence, and many studies examine its gradient behavior to explain its convergence and mode collapse problems. Do cost functions matter in GAN training? This article, part of the GAN series, looks into the Wasserstein GAN (WGAN) and WGAN with gradient penalty (WGAN-GP) in detail. The equation for the Wasserstein GAN looks unapproachable, but we will simplify it with examples. It looks complicated, but it is not!
Earth-Mover (EM) Distance / Wasserstein Metric
Let’s complete a simple exercise on moving boxes. We have 6 boxes, and we want to move them from the left to the locations marked by the dotted squares on the right. For box #1, we move it from location 1 to location 7. The cost is equal to its weight times the distance. For simplicity, we set the weight to 1. Therefore, the cost to move box #1 is 7 − 1 = 6.
The following diagram demonstrates two different ways of moving the boxes. The tables on the right report how many boxes are moved from one position to another. For example, in the first plan, the entry γ(1, 10) equals two: we move 2 boxes from location 1 to location 10. We call γ the transport plan. The total transport cost of either plan is 42.
Not all transport plans have the same cost; the two plans below, for example, cost different amounts. The Wasserstein distance (or EM distance) is the cost of the cheapest transport plan, which is two in this example.
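The idea of the "cheapest plan" can be checked by brute force. The sketch below is a minimal illustration with hypothetical box positions (not the ones in the diagram), assuming every box has weight 1 and every target slot receives exactly one box. Each permutation of the target slots is one transport plan; the EM distance is the minimum over all of them.

```python
from itertools import permutations


def plan_costs(sources, targets):
    """Costs of every one-to-one transport plan for unit-weight boxes.

    Each permutation of `targets` assigns one box to one target slot;
    the cost of a plan is the sum of |source - target| (weight 1 x distance).
    """
    if len(sources) != len(targets):
        raise ValueError("need as many target slots as boxes")
    return {
        sum(abs(s - t) for s, t in zip(sources, perm))
        for perm in permutations(targets)
    }


def em_distance(sources, targets):
    """EM (Wasserstein) distance: the cost of the cheapest plan."""
    return min(plan_costs(sources, targets))


# Hypothetical positions, not taken from the diagram.
print(plan_costs([1, 2], [7, 8]))               # {12}: every plan costs the same
print(em_distance([1, 2, 4, 5], [2, 3, 4, 6]))  # 3
```

When every source lies to the left of every target, all plans cost the same, as with the two plans of cost 42 above. Once positions overlap, the costs differ. On a line, pairing the sorted sources with the sorted targets always gives the cheapest plan, so brute force is only for illustration.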
Let’s throw in some complicated terms before explaining them. For probability distributions, the Wasserstein distance is the minimum cost of transporting mass to convert distribution q into distribution p. The Wasserstein distance between the real data distribution Pr and the generated data distribution Pg is mathematically defined as the greatest lower bound (infimum) over all transport plans:
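Written out in the notation of the WGAN paper, the definition is:

```latex
W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\, \lVert x - y \rVert \,\big]
```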
From the WGAN paper:
Π(Pr, Pg) denotes the set of all joint distributions γ(x, y) whose marginals are respectively Pr and Pg.
Don’t get scared by the mathematical formula. The equation above is the continuous-space counterpart of our discrete example. Π contains all possible transport plans γ. We combine the variables x and y into a joint distribution γ(x, y). For example, in the first table, γ(1, 10) equals 2/6: the probability that a box starts at location 1 and ends at location 10. The boxes that end up at location 9 may come from any source position, but together they must add up to the number of boxes location 9 receives, i.e. ∑ₓ γ(x, 9) = 2 when counting boxes (2/6 as a probability). That is the same as saying γ(x, y) must have marginals Pr and Pg respectively.
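The marginal condition is easy to check numerically. The sketch below uses a hypothetical 6-box plan (not the one in the diagram): dividing the box counts by 6 turns the table into a joint distribution γ(x, y), summing over targets recovers where the boxes start (Pr), and summing over sources recovers where they end up (Pg).

```python
import numpy as np

# Hypothetical 6-box plan, not taken from the diagram.
# Rows: source positions x; columns: target positions y; entries: boxes moved.
xs = np.array([1, 2, 3, 4])
ys = np.array([7, 8, 9, 10])
counts = np.array([
    [1, 0, 0, 2],  # from x=1: one box to 7, two boxes to 10
    [0, 1, 0, 0],  # from x=2: one box to 8
    [0, 0, 1, 0],  # from x=3: one box to 9
    [0, 0, 1, 0],  # from x=4: one box to 9
])

gamma = counts / counts.sum()  # joint distribution gamma(x, y), sums to 1
p_r = gamma.sum(axis=1)        # marginal over y: where the boxes start
p_g = gamma.sum(axis=0)        # marginal over x: where the boxes end up

# Expected cost E_{(x, y) ~ gamma} |x - y|; multiplying by 6 gives the total cost.
distance = np.abs(xs[:, None] - ys[None, :])
expected_cost = float(np.sum(gamma * distance))
print(gamma[0, 3], p_r, p_g, expected_cost * counts.sum())
```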
KL-Divergence and JS-Divergence
Before advocating any new cost functions, let’s look at the two common divergences used in GAN first, namely the KL-Divergence and the JS-Divergence.
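For reference, the standard definitions of the two divergences are:

```latex
D_{KL}(p \,\|\, q) = \int_x p(x) \log \frac{p(x)}{q(x)} \, dx

D_{JS}(p \,\|\, q) = \frac{1}{2} D_{KL}\Big(p \,\Big\|\, \frac{p + q}{2}\Big) + \frac{1}{2} D_{KL}\Big(q \,\Big\|\, \frac{p + q}{2}\Big)
```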
where p is the real data distribution and q is the one estimated by the model. Let’s assume both are Gaussian with variance one but different means. In the diagram below, we plot p and a few q with different means.
Below, we plot the corresponding KL-divergence and JS-divergence between p and q, with the mean of q ranging from 0 to 35. As expected, the divergence is 0 when p and q are the same, and it increases as the mean of q moves away. However, once p and q barely overlap, the JS-divergence saturates at log 2 and its gradient vanishes. The generator is left with close to zero gradient to learn from.
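The curves can be reproduced numerically. The sketch below assumes p = N(0, 1) and q = N(μ, 1) as in the text, uses the closed form KL(p‖q) = (μ_p − μ_q)²/2 for two unit-variance Gaussians, and integrates the JS-divergence on a grid. The KL-divergence keeps growing quadratically, while the JS-divergence flattens out at log 2 ≈ 0.693 once the distributions stop overlapping.

```python
import numpy as np


def gaussian_pdf(x, mean, std=1.0):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def kl_unit_gaussians(mu_p, mu_q):
    """Closed-form KL(p || q) for two Gaussians that both have variance 1."""
    return 0.5 * (mu_p - mu_q) ** 2


def js_unit_gaussians(mu_p, mu_q, n=20001):
    """JS(p || q) for two variance-1 Gaussians, by numerical integration on a grid."""
    x = np.linspace(min(mu_p, mu_q) - 10.0, max(mu_p, mu_q) + 10.0, n)
    dx = x[1] - x[0]
    p, q = gaussian_pdf(x, mu_p), gaussian_pdf(x, mu_q)
    m = 0.5 * (p + q)

    def kl(a, b):
        keep = (a > 0) & (b > 0)  # 0 * log(0 / b) is taken as 0
        return np.sum(a[keep] * np.log(a[keep] / b[keep])) * dx

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


for mu_q in [0, 1, 2, 5, 10, 20, 35]:
    print(f"mean of q = {mu_q:>2}: KL = {kl_unit_gaussians(0, mu_q):7.2f}, "
          f"JS = {js_unit_gaussians(0, mu_q):.4f}")
```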
Criticizing is easy. In practice, GAN can optimize the discriminator more easily than the generator. Minimizing the GAN objective function with an optimal discriminator is equivalent to minimizing the JS-divergence. As illustrated above, GAN barely learns anything if the generator is not performing well (p not close to q). (For a proof that GAN is related to the JS-divergence, please refer to “Why it is so hard to train Generative Adversarial Networks!”.)
What is wrong with the GAN objective function? Arjovsky and Bottou (2017) wrote a paper that analyzes the GAN problem mathematically. It examines the gradient behavior of both generator cost functions proposed in the original GAN paper. In summary,
- The first proposed objective function saturates. An optimal discriminator provides good information for the generator to improve, but if the generator is not doing a good job yet, its gradient diminishes and the generator learns nothing.
- The second proposed objective function has gradients with a large variance, which hurts model stability and convergence.
- In theory, these problems arise even when the ground truth and the model are only slightly misaligned, and
- During training, adding noise to the inputs of the discriminator can stabilize the model.
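The saturation in the first bullet shows up in the gradient each generator loss sends back through the discriminator's logit a, where D(G(z)) = sigmoid(a). The sketch below compares the first objective, log(1 − D(G(z))), with the second, −log D(G(z)). It only illustrates the vanishing-gradient point, not the variance argument.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def grad_first_objective(a):
    """d/da of log(1 - sigmoid(a)), the generator loss of the first formulation."""
    return -sigmoid(a)


def grad_second_objective(a):
    """d/da of -log(sigmoid(a)), the generator loss of the second formulation."""
    return -(1.0 - sigmoid(a))


# A confident discriminator rejects the fake: very negative logit, D(G(z)) near 0.
for a in [0.0, -5.0, -10.0]:
    print(f"D(G(z)) = {sigmoid(a):.2e}: first = {grad_first_objective(a):.2e}, "
          f"second = {grad_second_objective(a):.2e}")
```

As D(G(z)) approaches 0, the first gradient shrinks toward 0 while the second approaches −1.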
For more details, here is another article that summarizes some of the important mathematical claims.
Wasserstein Distance
Wasserstein GAN (WGAN) addresses this problem with a cost function based on the Wasserstein distance, which has a smoother gradient. WGAN keeps learning whether or not the generator is performing well. The diagram below repeats a similar plot of the value of D(X) for both GAN and WGAN. For GAN (the red line), there are regions with vanishing or exploding gradients. For WGAN (the blue line), the gradient is smoother everywhere, so it learns better even when the generator is not producing good images.
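For contrast with the JS-divergence, consider the same shifted Gaussians under the Wasserstein distance. For two unit-variance Gaussians that differ only in their means, W1 equals the distance between the means, so it keeps growing linearly (a constant gradient) however far apart they are. The sketch below estimates W1 from samples, using the fact that in 1-D the optimal plan pairs the sorted samples in order. The sample-based setup is an assumption made for illustration.

```python
import numpy as np


def wasserstein_1d(x, y):
    """W1 between two equally sized 1-D samples with uniform weights.

    In 1-D the optimal transport plan matches the sorted samples in order,
    so W1 is the mean absolute difference of the sorted values.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("samples must have the same size")
    return float(np.mean(np.abs(x - y)))


rng = np.random.default_rng(0)
p_samples = rng.normal(0.0, 1.0, 10_000)  # p ~ N(0, 1)
for shift in [0, 1, 5, 20, 35]:
    q_samples = p_samples + shift         # q ~ N(shift, 1), same sample shape
    print(shift, round(wasserstein_1d(p_samples, q_samples), 6))
```

SciPy's `scipy.stats.wasserstein_distance(u_values, v_values)` computes the same 1-D quantity for general, optionally weighted, samples.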
Wasserstein GAN
The Wasserstein distance as defined above is highly intractable, because the infimum runs over all joint distributions. Using the Kantorovich-Rubinstein duality, we can instead calculate the distance by
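In the WGAN paper's notation, the dual form is:

```latex
W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\lVert f \rVert_L \le 1} \; \mathbb{E}_{x \sim \mathbb{P}_r}\big[f(x)\big] - \mathbb{E}_{x \sim \mathbb{P}_g}\big[f(x)\big]
```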
where sup is the least upper bound (supremum) and f is a 1-Lipschitz function, i.e. it satisfies this constraint:
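Written out, the 1-Lipschitz condition on f is:

```latex
|f(x_1) - f(x_2)| \le |x_1 - x_2| \quad \text{for all } x_1, x_2
```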
One question is how we compute f. In fact, we learn it exactly the same way as the discriminator D in GAN, except that f outputs a score (a scalar value) rather than a probability. The score measures how real the input images are. It is similar to the value function in reinforcement learning, which measures how good a state (the input) is. We rename the discriminator to critic to reflect its new role. Let’s show GAN and WGAN side by side, and you will realize that you already know how to implement it.
[Figure: GAN and WGAN side by side]
The model design is almost the same, except that the deep network outputs a score instead of a probability, and we use slightly different equations for the cost function.
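The change in the cost function can be seen side by side in a small NumPy sketch. It assumes the network's raw outputs are already available as arrays: for GAN they are logits passed through a sigmoid and scored with binary cross-entropy; for WGAN the raw scores are used directly. The function names are hypothetical.

```python
import numpy as np


def gan_discriminator_loss(logits_real, logits_fake):
    """GAN: D(x) = sigmoid(logit) is a probability; loss is binary cross-entropy.

    -log D(x)      = log(1 + exp(-logit)) = logaddexp(0, -logit)
    -log(1 - D(x)) = log(1 + exp(logit))  = logaddexp(0, logit)
    """
    logits_real = np.asarray(logits_real, dtype=float)
    logits_fake = np.asarray(logits_fake, dtype=float)
    return float(np.mean(np.logaddexp(0.0, -logits_real))
                 + np.mean(np.logaddexp(0.0, logits_fake)))


def wgan_critic_loss(scores_real, scores_fake):
    """WGAN: the critic's raw score is used directly (no sigmoid, no log).

    Minimizing this loss maximizes mean f(real) - mean f(fake).
    """
    return float(np.mean(scores_fake) - np.mean(scores_real))


def wgan_generator_loss(scores_fake):
    """The generator tries to raise the critic's score on generated samples."""
    return float(-np.mean(scores_fake))


print(gan_discriminator_loss([2.0, 3.0], [-1.0, 0.5]))
print(wgan_critic_loss([2.0, 4.0], [1.0, 1.0]), wgan_generator_loss([1.0, 3.0]))
```

Architecturally, the main change is dropping the final sigmoid so that the critic can output any real number.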
The cost function for the WGAN’s critic is:
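In the notation of the WGAN paper (critic f_w with weights w, generator g_θ, a batch of m real samples x⁽ⁱ⁾ and m noise samples z⁽ⁱ⁾), the critic maximizes the following objective by gradient ascent on w:

```latex
\max_{w} \; \frac{1}{m} \sum_{i=1}^{m} f_w\big(x^{(i)}\big) - \frac{1}{m} \sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big)
```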
and the gradient to train the generator becomes
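In the same notation (critic f_w, generator g_θ with parameters θ, m noise samples z⁽ⁱ⁾), the generator's gradient is the following, and θ is updated by descending it:

```latex
g_\theta = -\nabla_\theta \, \frac{1}{m} \sum_{i=1}^{m} f_w\big(g_\theta(z^{(i)})\big)
```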
Even so, the difference is very small when we compare them side by side again.
However, one major thing is missing: f has to be a 1-Lipschitz function. To enforce the constraint, WGAN applies a very simple clipping that restricts every weight in f to a fixed range [−c, c] after each update.
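A minimal sketch of the clipping step, assuming the critic's parameters are stored as a list of NumPy arrays (a hypothetical layout; deep learning frameworks clamp their own parameter tensors in the same spirit). It runs after every critic update; c = 0.01 is the default in the WGAN paper.

```python
import numpy as np


def clip_weights(params, c=0.01):
    """WGAN weight clipping: clamp every critic parameter to [-c, c] in place."""
    for p in params:
        np.clip(p, -c, c, out=p)
    return params


critic_params = [np.array([0.5, -0.02, 0.003]), np.array([[-1.0, 0.0]])]
clip_weights(critic_params)
print(critic_params)
```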
Algorithm
Now we can put everything together in the pseudo code below.
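As a runnable illustration of that loop (not the paper's code), the sketch below trains a hypothetical one-parameter generator against a linear critic in 1-D. It follows the structure of the WGAN algorithm: n_critic critic updates with weight clipping for every generator update. Assumptions: plain SGD instead of RMSProp, and a larger clipping value and learning rates than the paper's defaults (c = 0.01, learning rate 0.00005, batch size 64, n_critic = 5), so that the toy converges within a few hundred steps.

```python
import numpy as np


def train_toy_wgan(n_iters=1000, n_critic=5, clip=0.1, lr_critic=0.1,
                   lr_gen=0.5, batch=64, real_mean=3.0, seed=0):
    """Toy 1-D WGAN with weight clipping (hypothetical setup, plain SGD).

    Real data:  x ~ N(real_mean, 1)
    Generator:  g(z) = z + theta, z ~ N(0, 1)  (one parameter: theta)
    Critic:     f(x) = w * x, which is |w|-Lipschitz, so clipping w to
                [-clip, clip] keeps f K-Lipschitz with K = clip.
    """
    rng = np.random.default_rng(seed)
    w, theta = 0.0, 0.0
    for _ in range(n_iters):
        # 1) Update the critic n_critic times for each generator update.
        for _ in range(n_critic):
            x_real = rng.normal(real_mean, 1.0, batch)
            x_fake = rng.normal(0.0, 1.0, batch) + theta
            # Critic maximizes mean f(real) - mean f(fake) = w * (mean_real - mean_fake),
            # so its gradient with respect to w is (mean_real - mean_fake).
            grad_w = x_real.mean() - x_fake.mean()
            w = w + lr_critic * grad_w          # gradient ascent
            w = float(np.clip(w, -clip, clip))  # weight clipping
        # 2) Update the generator: minimize -mean f(g(z)) = -w * mean(z + theta),
        #    whose derivative with respect to theta is -w.
        theta = theta - lr_gen * (-w)
    return theta, w


theta, w = train_toy_wgan()
print(f"theta = {theta:.2f}, w = {w:.3f}")
```

With these settings, θ should climb toward the real mean of 3 and then stay in a small band around it. Because f(x) = w·x can only compare means, this critic is just enough for a generator that shifts its mean; a real critic is a deep network whose weights are clipped the same way.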
Experiment
Correlation between loss metric and image quality
In GAN, the loss measures how well the generator fools the discriminator rather than the image quality. As shown below, the generator loss in GAN does not drop even when the image quality improves, so we cannot tell the progress from its value. We need to save test images and evaluate them visually. In contrast, the WGAN loss reflects the image quality, which is more desirable.
Improve training stability
Two significant contributions of WGAN are
- it shows no sign of mode collapse in experiments, and
- the generator can still learn when the critic is well trained.
WGAN — Issues
Lipschitz constraint
Clipping allows us to enforce the Lipschitz constraint on the critic’s model to calculate the Wasserstein distance.
Quote from the research paper: Weight clipping is a clearly terrible way to enforce a Lipschitz constraint. If the clipping parameter is large, then it can take a long time for any weights to reach their limit, thereby making it harder to train the critic till optimality. If the clipping is small, this can easily lead to vanishing gradients when the number of layers is big, or batch normalization is not used (such as in RNNs) … and we stuck with weight clipping due to its simplicity and already good performance.
While the Wasserstein distance looks promising, the difficulty is calculating its value, or more precisely, enforcing the Lipschitz constraint. Clipping is simple, but it introduces new problems: the model may still produce poor-quality images and may not converge. One of the challenges of WGAN is choosing the clipping hyperparameter c.
The model performance is very sensitive to this hyperparameter. In the diagram below, when batch normalization is off, the critic’s gradients go from vanishing to exploding as c increases from 0.01 to 0.1.
Model capacity
Weight clipping also acts as weight regularization. Clipping the weights of f reduces the capacity of the model, which limits its ability to model complex functions. In the following experiment, we model ground-truth data with multiple modes: an 8-Gaussian mixture, a 25-Gaussian mixture, and the Swiss Roll distribution. The contour plot is the value function estimated by f, and the orange dots are the modes. The first row is created by WGAN. In the second row, we use WGAN with gradient penalty (discussed next) to estimate the value function. The modes (orange dots) should be local maxima of the value function and therefore be surrounded by the contours, just like peaks on a topographic map. In this experiment, WGAN loses the capability to model the complex function that the ground truth requires.
Wasserstein GAN with gradient penalty (WGAN-GP)
Now we introduce WGAN-GP, which uses a gradient penalty instead of weight clipping to enforce the Lipschitz constraint.
Gradient penalty
WGAN-GP uses the fact that a differentiable function is 1-Lipschitz if and only if it has gradients with norm at most 1 everywhere.
Here, we add another term to penalize the critic if the norm of its gradient moves away from 1.
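Written out, the WGAN-GP critic loss is the WGAN critic loss plus the penalty term. The WGAN-GP paper writes the critic as D; f is used here to match the rest of the article. The interpolated point x̂ is sampled as described in the next paragraph, with ε ~ U[0, 1]:

```latex
L = \underbrace{\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[f(\tilde{x})\big] - \mathbb{E}_{x \sim \mathbb{P}_r}\big[f(x)\big]}_{\text{original critic loss}}
  + \underbrace{\lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\Big[\big(\lVert \nabla_{\hat{x}} f(\hat{x}) \rVert_2 - 1\big)^2\Big]}_{\text{gradient penalty}},
  \qquad \hat{x} = \epsilon x + (1 - \epsilon)\,\tilde{x}
```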
To compute the gradient used for the penalty, a sample point is drawn on the straight line between a real image and a generated image. λ is set to 10. Batch normalization is avoided in the critic because it introduces correlations between the samples in a mini-batch, whereas the gradient penalty is computed for each sample independently.
Algorithm
Let’s look into the pseudocode detailing how the sample point is created and how the gradient penalty is computed.
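The same two steps can be sketched in NumPy under stated assumptions. The critic is a hypothetical one-hidden-layer network f(x) = v·tanh(Wx) whose input gradient is written out by hand, because NumPy has no automatic differentiation; a real implementation would get ∇f(x̂) from its framework's autodiff. The sketch draws one ε ~ U[0, 1] per sample, forms x̂ on the line between a real and a generated sample, and adds λ·(‖∇f(x̂)‖₂ − 1)² with λ = 10.

```python
import numpy as np


def critic(x, W, v):
    """Hypothetical tiny critic f(x) = v . tanh(W x) for a batch x of shape (n, d)."""
    return np.tanh(x @ W.T) @ v


def critic_input_grad(x, W, v):
    """Analytic gradient of f with respect to each input row: W^T (v * (1 - tanh^2))."""
    h = np.tanh(x @ W.T)             # (n, hidden)
    return (v * (1.0 - h ** 2)) @ W  # (n, d)


def critic_loss_gp(x_real, x_fake, W, v, lam=10.0, rng=None):
    """WGAN-GP critic loss: E[f(fake)] - E[f(real)] + lam * E[(||grad f(x_hat)|| - 1)^2]."""
    rng = np.random.default_rng() if rng is None else rng
    n = x_real.shape[0]
    # One epsilon ~ U[0, 1] per sample: x_hat lies on the line between x_fake and x_real.
    eps = rng.uniform(0.0, 1.0, size=(n, 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake
    grad_norms = np.linalg.norm(critic_input_grad(x_hat, W, v), axis=1)
    penalty = lam * np.mean((grad_norms - 1.0) ** 2)
    wasserstein_term = np.mean(critic(x_fake, W, v)) - np.mean(critic(x_real, W, v))
    return wasserstein_term + penalty, penalty


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 2))  # hypothetical critic weights (hidden = 8, d = 2)
v = rng.normal(size=8)
x_real = rng.normal(3.0, 1.0, size=(64, 2))
x_fake = rng.normal(0.0, 1.0, size=(64, 2))
loss, penalty = critic_loss_gp(x_real, x_fake, W, v, rng=rng)
print(f"critic loss = {loss:.3f}, gradient penalty = {penalty:.3f}")
```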
WGAN-GP Experiments
WGAN-GP enhances training stability. As shown below, when the model design is less than optimal, WGAN-GP can still create good results while the original GAN cost function fails.
Here are the inception scores of different methods. The experiment from the WGAN-GP paper shows better image quality and convergence compared with WGAN. However, DCGAN achieves slightly better image quality and converges faster. But the inception score of WGAN-GP is more stable once it starts converging.
So what is the benefit of WGAN-GP if it cannot beat DCGAN? Its major advantage is convergence: it makes training more stable and therefore easier. Earlier, we tended to avoid complex models because GANs were already too hard to train, which is likely why more complicated designs or more layers were not tried. Because WGAN-GP helps the model converge better, we can use a more complex model, such as a deep ResNet, for the generator and/or the discriminator. The following are the inception scores using ResNet with WGAN-GP.
These are bedroom images generated at 128x128 resolution, which is higher than the resolution achieved in earlier research.
More thoughts
The mathematical model provides a good framework for discussion, but many questions remain unanswered. The most important one: if the model paints such a negative picture, why does GAN still produce reasonable results? In fact, a paper in early 2018 claims that, with enough training and hyperparameter tuning, the “state-of-the-art” cost functions show no difference in performance.
Does the cost function matter? Is the mathematical model too simple? Or are there issues to be solved before we can see the full benefits? WGAN and WGAN-GP do not deliver the knockout punch of superior image quality, and the debate will continue. However, training GANs is not easy: the stability and mode collapse issues are real. Mathematically, the GAN cost function is suspicious, even though some empirical results suggest otherwise. Hopefully, this article presents the claims in WGAN clearly enough for you to explore further.
Reference
WGAN-GP: Improved Training of Wasserstein GANs
Towards Principled Methods for Training Generative Adversarial Networks