Categorical Reparameterization with Gumbel-Softmax (@ ICLR 2017)
Eric Jang, Shixiang Gu, Ben Poole
This paper introduces a now very popular technique to train models
using gradients through discrete choices. It's an alternative
to the REINFORCE estimator (which I already knew about, but the first actual
derivation I saw was in this paper; it's super simple).
Suppose, for instance, you're training a model to choose between two
discrete actions: either turn left or right. Given the action, some
environment gives you a reward. Gumbel-Softmax gives you a way to
backpropagate that reward to a neural network that produces the action.
The idea is based on the Gumbel-Max trick. Let's say we
have a vector of positive integers, $x = (x_1, x_2, x_3)$. Say we want a way to draw
a vector of 3 numbers, such that the probability of each index
holding the maximum of that drawn vector is proportional to the corresponding value.
A little more precisely, calling the drawn vector $z$, I want
$P(\arg\max_i z_i = 1) = \frac{x_1}{x_1 + x_2 + x_3}$,
$P(\arg\max_i z_i = 2) = \frac{x_2}{x_1 + x_2 + x_3}$,
$P(\arg\max_i z_i = 3) = \frac{x_3}{x_1 + x_2 + x_3}$.
One natural idea one might have is to draw each $z_i$ uniformly
in the range $[0, x_i]$ and take the argmax. That's almost right, but it's a bit biased:
the probability of the largest entry, say $x_k$, winning is not exactly
its share $\frac{x_k}{x_1 + x_2 + x_3}$, but actually a bit higher.
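Here's a quick numerical sketch of that bias (the concrete values in `x` and the sample count are made up for this illustration, they're not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0])    # illustrative values, picked only for this sketch
n = 200_000

# Naive idea: draw z_i ~ Uniform(0, x_i) for each entry and take the argmax.
z = rng.uniform(low=0.0, high=x, size=(n, len(x)))
freq = np.bincount(z.argmax(axis=1), minlength=len(x)) / n

print("proportional target:", x / x.sum())  # [0.167, 0.333, 0.5]
print("empirical argmax   :", freq)         # last entry comes out well above 0.5
```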
The Gumbel-Max trick is a way to fix that: take the log of each value, add
independent noise drawn from the Gumbel distribution to each log, and take the
argmax, i.e. draw $k = \arg\max_i (\log x_i + g_i)$ with $g_i \sim \text{Gumbel}(0, 1)$;
the index $k$ then has exactly the distribution we want.
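Same check with the Gumbel-Max trick; Gumbel(0, 1) noise is easy to generate from uniform samples as $-\log(-\log u)$ (again just a sketch with the same made-up values):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([1.0, 2.0, 3.0])    # same illustrative values as above
n = 200_000

# Gumbel-Max: add Gumbel(0, 1) noise to the log of each value, take the argmax.
u = rng.uniform(size=(n, len(x)))
gumbel_noise = -np.log(-np.log(u))
samples = (np.log(x) + gumbel_noise).argmax(axis=1)

freq = np.bincount(samples, minlength=len(x)) / n
print("proportional target:", x / x.sum())  # [0.167, 0.333, 0.5]
print("empirical argmax   :", freq)         # matches the target up to sampling noise
```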
On top of this, Gumbel-Softmax simply uses a softmax (with a temperature
parameter) as a continuous relaxation of the argmax operation. In the forward
pass, you apply the softmax and then discretize the output to a one-hot vector.
In the backward pass, you just pretend the gradient for the discrete output is
exactly the gradient of the continuous softmax prediction, and backpropagate
as usual (the paper calls this the straight-through variant). In practice, the softmax
predictions of the model will tend to be almost one-hot, so this makes sense
and works well. Apparently, it also has a lower variance than REINFORCE.
The paper has an experiment demonstrating that. It also seems important
to anneal the softmax temperature during training, i.e. to start with a high
temperature so the samples are "less one-hot", and gradually lower it so they
become more and more "one-hot".