This paper introduces a now very popular technique for backpropagating gradients through discrete choices. It’s an alternative to the REINFORCE estimator (which I knew, but the first actual derivation I saw was in this paper; it’s super simple). Suppose, for instance, you’re training a model to choose between two discrete actions: either turn left or right. Given the action, some environment gives you a reward. Gumbel-Softmax gives you a way to backpropagate that reward to a neural network that produces the action.

The idea is based on the Gumbel-Max trick. Let’s say we have a vector of positive values: $v = [1, 2, 3]$. We want a way to draw a random vector $v'$ of 3 values, such that the probability of each index being the argmax of that vector is proportional to the value at that index. A little more precisely, I want $p(\arg\max_i v'_i = 1) = \frac{1}{6}$, $p(\arg\max_i v'_i = 2) = \frac{2}{6}$, $p(\arg\max_i v'_i = 3) = \frac{3}{6}$. One natural idea is to draw each $v'_i$ uniformly in the range $[0, v_i]$. That’s almost right, but it’s biased: the probability of the last element being the maximum is not exactly $\frac{1}{2}$ in this example, but actually a bit higher. The Gumbel-Max trick is a way to get it exactly right: set $v'_i = \log v_i + g_i$, where each $g_i$ is i.i.d. noise drawn from the standard Gumbel distribution. Then $\arg\max_i v'_i$ has exactly the distribution we want.
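The trick is easy to check empirically. A minimal numpy sketch (the names here are my own): standard Gumbel noise can be drawn via inverse-CDF sampling as $-\log(-\log U)$ with $U \sim \mathrm{Uniform}(0, 1)$, and the empirical argmax frequencies should come out close to $[\frac{1}{6}, \frac{2}{6}, \frac{3}{6}]$.

```python
import numpy as np

rng = np.random.default_rng(0)
v = np.array([1.0, 2.0, 3.0])
n = 100_000

# Standard Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(U)).
g = -np.log(-np.log(rng.random((n, len(v)))))

# Gumbel-Max trick: argmax of (log v + Gumbel noise) is a sample
# from the categorical distribution with probabilities v / sum(v).
samples = np.argmax(np.log(v) + g, axis=1)

freqs = np.bincount(samples, minlength=len(v)) / n
print(freqs)  # close to [1/6, 2/6, 3/6]
```

The same sketch with uniform draws in $[0, v_i]$ instead of the Gumbel perturbation gives a last-element frequency visibly above $\frac{1}{2}$, which is the bias mentioned above.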

On top of this, Gumbel-Softmax simply uses a softmax (with a temperature parameter) as a continuous relaxation of the $\arg\max$ operation. In the straight-through variant, the forward pass discretizes the softmax output into a one-hot vector, while the backward pass pretends the discretization was the identity: the gradient for the discrete output is passed directly to the continuous softmax prediction, and backpropagation continues normally. In practice, the softmax predictions of the model tend to be almost one-hot, so this approximation makes sense and works well. Apparently, it also has lower variance than REINFORCE; the paper has an experiment demonstrating that. It also seems important to anneal the softmax temperature during training: start with a high temperature, so predictions are “less one-hot” in the beginning, and decrease it so they become increasingly one-hot.
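A sketch of the sampling step in numpy, under my own naming: a Gumbel-Softmax sample is the softmax of the Gumbel-perturbed log-values divided by a temperature $\tau$, so high $\tau$ spreads the sample out and low $\tau$ pushes it toward one-hot.

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_softmax_sample(logits, tau, rng):
    """One soft sample: softmax of (logits + Gumbel noise) / tau."""
    g = -np.log(-np.log(rng.random(logits.shape)))  # standard Gumbel noise
    y = (logits + g) / tau
    y = np.exp(y - y.max())  # numerically stable softmax
    return y / y.sum()

logits = np.log(np.array([1.0, 2.0, 3.0]))  # unnormalized log-probabilities
soft = gumbel_softmax_sample(logits, 5.0, rng)  # high temperature: spread out
hard = gumbel_softmax_sample(logits, 0.1, rng)  # low temperature: near one-hot
# In an autodiff framework, the straight-through forward pass would replace
# the soft sample with one_hot(argmax(y)) while routing gradients through y
# (in PyTorch idiom: y_hard - y_soft.detach() + y_soft).
```

Note that the argmax of the sample is unchanged by the temperature (softmax is monotone), so annealing $\tau$ changes how "soft" the relaxation is, not which category is selected for a given noise draw.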