November 2nd, 2017

mixup: Data-Dependent Data Augmentation

By popular demand, here is my post on mixup, a new data augmentation scheme that was shown to improve generalization and stabilize GAN performance.

I have to say, I had not seen this paper before people on Twitter suggested I write a post about it - which was yesterday. So these are all very fresh thoughts, warranty not included.

Summary of this post

  1. I summarize mixup and show that, when the loss is linear in the labels, the objective can be rewritten in a simpler form.
  2. The reformulation reveals that the label of the second datapoint is never used, so mixup can consume unlabelled data.
  3. I visualize the resulting vicinal distribution on the two-moons dataset.
  4. I speculate about why this kind of data augmentation might help generalization, GAN training and robustness to adversarial examples.

Mixup

Let's jump right into the middle: here is how the mixup training loss is defined:

$$\mathcal{L}(\theta) = \mathbb{E}_{x_1,y_1 \sim p_{\mathrm{train}}}\, \mathbb{E}_{x_2,y_2 \sim p_{\mathrm{train}}}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)}\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; \lambda y_1 + (1-\lambda) y_2\big)$$

Very simply, we take pairs of datapoints $(x_1, y_1)$ and $(x_2, y_2)$, choose a random mixing proportion $\lambda$ from a Beta distribution, and create an artificial training example $\big(\lambda x_1 + (1-\lambda) x_2,\; \lambda y_1 + (1-\lambda) y_2\big)$. We train the network by minimizing its loss on mixed-up datapoints like this. This is all.
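In code, a single mixed-up example takes a couple of lines. Here is a minimal numpy sketch with made-up datapoints and one-hot labels (the variable names are mine, not the paper's):

```python
import numpy as np

rng = np.random.default_rng(0)

# two made-up training examples with one-hot labels
x1, y1 = np.array([0.2, 0.7]), np.array([1.0, 0.0])
x2, y2 = np.array([0.9, 0.1]), np.array([0.0, 1.0])

alpha = 0.1
lam = rng.beta(alpha, alpha)             # random mixing proportion

x_mix = lam * x1 + (1 - lam) * x2        # interpolated input
y_mix = lam * y1 + (1 - lam) * y2        # interpolated (soft) label
```

With a small $\alpha$, Beta($\alpha,\alpha$) is bimodal near 0 and 1, so most mixed examples stay close to one of the two original datapoints.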

One intuition behind this is that by linearly interpolating between datapoints, we incentivize the network to act smoothly and kind of interpolate nicely between datapoints - without sharp transitions.

Reformulation

Now, let's assume the loss is linear in its second argument, such that

$$\ell\big(x,\; p y_1 + (1-p) y_2\big) = p\, \ell(x, y_1) + (1-p)\, \ell(x, y_2)$$

This is the case in classification, where the loss is the binary cross-entropy $\ell(x, y) = -y \log p(x;\theta) - (1-y)\log\big(1 - p(x;\theta)\big)$. It also works the same way for one-hot-encoded categorical labels.
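This linearity is easy to verify numerically. A quick sketch (all values are arbitrary):

```python
import math

def bce(q, y):
    # binary cross-entropy between predicted probability q and (soft) label y
    return -y * math.log(q) - (1 - y) * math.log(1 - q)

q, y1, y2, p = 0.3, 1.0, 0.0, 0.25        # arbitrary made-up values
mixed = bce(q, p * y1 + (1 - p) * y2)     # loss on the mixed label
split = p * bce(q, y1) + (1 - p) * bce(q, y2)
print(abs(mixed - split))                 # 0 (up to floating point)
```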

In these cases, we can rewrite the mixup objective as

$$\begin{aligned}
&\mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)}\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; \lambda y_1 + (1-\lambda) y_2\big) \\
&= \mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)} \Big[ \lambda\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_1\big) + (1-\lambda)\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_2\big) \Big] \\
&= \mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)}\, \mathbb{E}_{z \sim \mathrm{Ber}(\lambda)} \Big[ z\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_1\big) + (1-z)\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_2\big) \Big] \\
&= \mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{z \sim \mathrm{Ber}(1/2)}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+z,\, \alpha+1-z)} \Big[ z\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_1\big) + (1-z)\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_2\big) \Big] \\
&= \tfrac{1}{2}\, \mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\,\alpha)}\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_1\big) + \tfrac{1}{2}\, \mathbb{E}_{x_1,y_1 \sim p_D}\, \mathbb{E}_{x_2,y_2 \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\,\alpha+1)}\, \ell\big(\lambda x_1 + (1-\lambda) x_2,\; y_2\big) \\
&= \mathbb{E}_{x,y \sim p_D}\, \mathbb{E}_{x' \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\,\alpha)}\, \ell\big(\lambda x + (1-\lambda) x',\; y\big)
\end{aligned}$$

Line by line, I used the following tricks:

  1. linearity of the loss, as in assumption above
  2. expectation of a Bernoulli(λ) variable plus linearity of expectation
  3. Bayes' rule $p(z\mid\lambda)\,p(\lambda) = p(\lambda\mid z)\,p(z)$ and the fact that the Beta distribution is the conjugate prior of the Bernoulli
  4. expectation of a Bernoulli(0.5) plus linearity of expectation
  5. symmetry of the Beta distribution, in the sense that $\lambda \sim \mathrm{Beta}(a,b)$ implies $1-\lambda \sim \mathrm{Beta}(b,a)$, plus renaming variables in the expectation so that the two terms become the same
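The Bayes-rule step boils down to the identity $\mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha,\alpha)}\big[\lambda f(\lambda)\big] = \tfrac{1}{2}\,\mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\alpha)}\big[f(\lambda)\big]$, which a quick Monte Carlo check confirms (test function and sample size are arbitrary choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, n = 0.1, 1_000_000
f = lambda lam: np.cos(3 * lam)              # arbitrary test function

# left-hand side: E[lambda * f(lambda)] under the Beta(alpha, alpha) prior
lam1 = rng.beta(alpha, alpha, n)
lhs = np.mean(lam1 * f(lam1))

# right-hand side: P(z=1) * E[f(lambda)] under the posterior Beta(alpha+1, alpha)
lam2 = rng.beta(alpha + 1, alpha, n)
rhs = 0.5 * np.mean(f(lam2))

print(lhs, rhs)                              # agree up to Monte Carlo error
```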

So, here is what we are left with:

$$\mathcal{L}(\theta) = \mathbb{E}_{x,y \sim p_D}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\,\alpha)}\, \mathbb{E}_{x' \sim p_D}\, \ell\big(\lambda x + (1-\lambda) x',\; y\big)$$

I think this formulation is much nicer, because the label of the second datapoint $x'$ never appears: $x'$ only contributes its input, not its label. This means the second component can come from a completely unlabelled dataset, turning mixup into a semi-supervised method:

$$\mathcal{L}(\theta) = \mathbb{E}_{x,y \sim p_{\mathrm{labelled}}}\, \mathbb{E}_{\lambda \sim \mathrm{Beta}(\alpha+1,\,\alpha)}\, \mathbb{E}_{x' \sim p_{\mathrm{unlabelled}}}\, \ell\big(\lambda x + (1-\lambda) x',\; y\big)$$

Pytorch code

Here's how you'd modify the PyTorch code from the paper to make this work (I draw a different $\lambda$ per datapoint in the minibatch, while they draw one $\lambda$ per minibatch; I'm not sure which works better):

import numpy
import torch
from torch.autograd import Variable

for (x1, y1), (x2, _) in zip(loader1, loader2):
    # one lambda per datapoint, weighted towards the labelled example x1
    # (assumes 2-D input batches of shape (batchsize, features))
    lam = numpy.random.beta(alpha + 1, alpha, (x1.size(0), 1)).astype('float32')
    lam = torch.from_numpy(lam)   # shape (batchsize, 1) broadcasts over features
    x = Variable(lam * x1 + (1. - lam) * x2)
    y = Variable(y1)              # the second minibatch's labels are never used
    optimizer.zero_grad()
    loss(net(x), y).backward()
    optimizer.step()

Only three changes:

  1. the labels of the second minibatch are discarded (`y2` is never used),
  2. $\lambda$ is drawn from $\mathrm{Beta}(\alpha+1, \alpha)$ rather than $\mathrm{Beta}(\alpha, \alpha)$,
  3. the target is simply `y1`, instead of the mixed-up label $\lambda y_1 + (1-\lambda) y_2$.

This should do roughly the same thing.

Let's visualize this

Let's look at what this data-dependent augmentation looks like for a single datapoint on the two-moons dataset:

The white and black crosses are positive and negative examples, respectively. The mixup data augmentation doesn't care about the labels, just the distribution of the data. I applied mixup to the datapoint at roughly $x = (0.7, 0.6)$. The transparent blue dots show random samples from the vicinal distribution. Each blue point was obtained by picking another training datapoint $x'$ at random, then picking a random $\lambda$ from a $\mathrm{Beta}(1.1, 0.1)$, and then interpolating between $x$ and $x'$ accordingly. Note that $\mathrm{Beta}(1.1, 0.1)$ is not symmetric: roughly 80% of sampled $\lambda$ will be higher than $0.9$, so we tend to end up much closer to $x$ than to $x'$. All these blue dots would be added to the training data with a 'white' label.
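As a quick sanity check of that 80% figure (Monte Carlo; the sample size is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
lam = rng.beta(1.1, 0.1, 1_000_000)   # Beta(alpha + 1, alpha) with alpha = 0.1
frac = (lam > 0.9).mean()
print(frac)                           # roughly 0.8
```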

This is what the picture looks like when we apply the data augmentation to two training examples, one positive and one negative, using 10k unlabelled samples (I am plotting a few of those unlabelled samples for reference):

You can see that the vicinal distribution of mixup does not really follow the manifold - which is what one would hope a good semi-supervised data augmentation scheme would do in this particular example. It does something rather weird. Things don't look that much cleaner when we have more labelled examples:

Finally, with all datapoints labelled:

Why should data augmentation generalize better?

Why should a data augmentation scheme - or vicinal risk minimization (VRM) - generalize better in classification? The generalization gap is the difference between training and validation losses, which exists because the training and test distributions differ (in practice, both are empirical distributions concentrated on a finite number of samples). VRM, or data augmentation, trains by minimizing risk on the augmented training distribution:

$$p_{\mathrm{aug}}(\tilde{x}, \tilde{y}) = \frac{1}{N_{\mathrm{train}}} \sum_{(x,y) \in \mathcal{D}_{\mathrm{train}}} \nu(\tilde{x}, \tilde{y} \mid x, y),$$

for some vicinal distribution $\nu(\tilde{x}, \tilde{y} \mid x, y)$. This should be successful if the augmented distribution $p_{\mathrm{aug}}$ actually ends up closer to $p_{\mathrm{test}}$ than the original training distribution $p_{\mathrm{train}}$. Following the ideas of (Lacoste-Julien et al, 2011), or section 1.3.7 of my PhD thesis, we can even find the divergence measure we should use to quantify the difference between these distributions, based on Bayes decision theory. If we use the binary cross-entropy both for training and testing, the Bayes optimal decision function on the augmented data will be (assuming perfectly balanced classes):

$$q_{\mathrm{aug}}(x) = \frac{p_{\mathrm{aug}}(x \mid y=1)}{p_{\mathrm{aug}}(x \mid y=1) + p_{\mathrm{aug}}(x \mid y=0)}$$

When mixup mixes up the empirical distribution of the two classes, it turns them into continuous distributions with perfectly overlapping support. Therefore, the Bayes optimal decision function is more or less unique, so we should be able to find qaug(x), or something very close to it quite consistently during training.
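To make the formula concrete, here is a toy sketch where the two class-conditionals are stand-in Gaussians with fully overlapping support (the densities and means are made up for illustration; this is not the mixup-induced distribution itself):

```python
import math

def pdf(x, mu, sigma=1.0):
    # Gaussian density; a stand-in for the smoothed class-conditionals
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def q_aug(x):
    # Bayes optimal decision function for balanced classes centred at +1 and -1
    p1, p0 = pdf(x, 1.0), pdf(x, -1.0)
    return p1 / (p1 + p0)

print(q_aug(0.0))   # 0.5, exactly at the decision boundary
print(q_aug(2.0))   # close to 1, deep inside class 1
```

Because both class-conditionals assign positive density everywhere, $q_{\mathrm{aug}}$ is well defined for every $x$, not just at the training points.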

The loss of this optimal training classifier qaug(x) on the test data is as follows:

$$\begin{aligned}
\mathcal{L}_{\mathrm{test}}(q_{\mathrm{aug}}) &= -\mathbb{E}_{y \sim p_{\mathrm{test}}}\, \mathbb{E}_{x \sim p_{\mathrm{test}}(x|y)} \log p_{\mathrm{aug}}(x|y) + \mathbb{E}_{x \sim p_{\mathrm{test}}} \log p_{\mathrm{aug}}(x) + \mathrm{const} \\
&= \mathbb{E}_{y \sim p_{\mathrm{test}}}\, \mathrm{KL}\big[p_{\mathrm{test}}(x|y)\,\big\|\,p_{\mathrm{aug}}(x|y)\big] - \mathrm{KL}\big[p_{\mathrm{test}}(x)\,\big\|\,p_{\mathrm{aug}}(x)\big] + \mathrm{const}
\end{aligned}$$
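To see where these terms come from: with balanced classes, $p_{\mathrm{aug}}(y) = \tfrac{1}{2}$, so the log-posterior of the augmented model decomposes as

$$-\log p_{\mathrm{aug}}(y \mid x) = -\log p_{\mathrm{aug}}(x \mid y) + \log p_{\mathrm{aug}}(x) + \log 2,$$

and each cross-entropy term then splits into a KL divergence plus an entropy of $p_{\mathrm{test}}$, which is constant with respect to $p_{\mathrm{aug}}$ and gets absorbed into the constant.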

Interestingly, if we use $p_{\mathrm{test}}(x)$ in the mixup, we can actually evaluate these quantities exactly, which gives us a lower bound on the training error.

So, let's answer the question: why might data augmentation generalize better?

  1. because the training loss becomes better defined, such that the Bayes-optimal solution is unique and easier to find consistently. In the case of mixup this happens because the class-conditional distributions end up having fully overlapping support.
  2. because data augmentation turns the training distribution into one that is closer to the test distribution. The ideal generalization gap can be seen as a Bregman divergence between $p_{\mathrm{test}}$ and $p_{\mathrm{aug}}$, up to a constant. In the case of mixup, we can actually calculate this divergence if we use $p_{\mathrm{train}}(x)$ in the vicinal distribution.

Why should it work well for GANs?

I suspect reason number 1 above also explains why mixup works well for training GANs. One of the issues with the usual GAN setup is that the training and synthetic distributions have widely different support, and are often concentrated on lower-dimensional manifolds.

Instance noise was initially introduced to alleviate this problem by making the support of the two distributions overlap. See also (Arjovsky and Bottou, 2017). Mixup achieves the same thing, and I would imagine it does so even better in practice.

Why should it work well against adversarial examples?

Resilience to adversarial examples is related to, but crucially different from, generalization. Firstly, in the adversarial setting the test data is generated specifically to fool our classification function - it is no longer independent of the training set or the training algorithm. Secondly, and more importantly, adversarial examples are created on the basis of the gradient of the decision function, whereas the training loss only depends on the values of the decision function at certain training points. For a decision function to achieve minimal training loss, it doesn't even have to be differentiable or continuous; in fact, we could even implement it as a look-up table, which would amount to memorization. So adversarial examples exploit a property of the decision function that the empirical loss is perfectly insensitive to.

The way mixup addresses this is twofold:

  1. due to the vicinal distribution, the training loss function suddenly starts to care about a local neighbourhood of the decision function around a training datapoint - and the behaviour in that local neighbourhood will depend on the gradients.
  2. it is evident from the figures above that mixup tends to produce augmented datapoints that cover the likely locations where adversarial examples might end up.

Summary

Mixup is an interesting technique, but I don't believe it's the final version that we will see in the wild. I can imagine more interesting versions that somehow combine this idea with nearest-neighbour kind of ideas, which work quite well for manifold-like data. For example, one could restrict the pairs considered for mixing to pairs of datapoints within a certain distance. This might not be as effective at countering adversarial examples, but may work much better for generalization.
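A distance-restricted variant of this kind could be sketched as follows (this is entirely my own sketch of the idea, with made-up data and a hypothetical `local_mixup_pair` helper, not anything from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def local_mixup_pair(X, i, alpha=0.1, k=5):
    """Mix datapoint X[i] with one of its k nearest neighbours only."""
    d = np.linalg.norm(X - X[i], axis=1)     # distances to every point
    neighbours = np.argsort(d)[1:k + 1]      # k nearest, skipping the point itself
    j = rng.choice(neighbours)               # random nearby mixing partner
    lam = rng.beta(alpha + 1, alpha)         # mixing weight, biased towards X[i]
    return lam * X[i] + (1 - lam) * X[j]

X = rng.normal(size=(100, 2))                # made-up 2-D dataset
x_mix = local_mixup_pair(X, 0)
```

Restricting partners to nearby points keeps the augmented samples closer to the data manifold, at the cost of the long-range interpolations that may help against adversarial examples.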
