Gabriel Poesia

How Does Batch Normalization Help Optimization? (@ NeurIPS 2018)

Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, Aleksander Madry

Link

This paper investigates why BatchNorm improves neural network training so drastically. The reason given in the original BatchNorm paper was that it stabilizes the distribution of inputs to each layer while previous layers are being updated (i.e., it reduces internal covariate shift). This turns out to be wrong on two counts: (1) BatchNorm does not reduce internal covariate shift; it either leaves it unchanged or makes it worse, and (2) internal covariate shift is unrelated to the observed improvements. This paper shows these two facts in a simple experimental setting. Then, it goes on to explain why BatchNorm works (since it does). It turns out that BatchNorm makes both the loss and the gradients more stable: both get smaller Lipschitz constants with BatchNorm than without. This means that the change in the gradient after taking one optimization step is more limited; thus, the gradient is a more reliable guide and optimization becomes easier. Both properties are shown empirically and then formally (it helps a lot that BatchNorm is a very simple change to analyze algebraically - it's just a shift and scale; thus, the main proofs are mostly algebraic and don't really rely on any heavy tools).
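To make the "just a shift and scale" remark concrete, here is a minimal sketch of the BatchNorm forward pass for a single feature, in plain Python. It is not code from the paper: the function name `batch_norm`, the default values for `gamma`, `beta` and `eps`, and the toy batch are all assumptions chosen for illustration.

```python
import math

def batch_norm(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a batch, then apply a scale and a shift.

    gamma and beta stand in for the learned parameters; eps is the usual
    small constant added for numerical stability (values assumed here).
    """
    n = len(xs)
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n  # biased (population) variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [gamma * (x - mean) * inv_std + beta for x in xs]

# Toy batch of activations for one feature (made-up numbers).
batch = [2.0, 4.0, 6.0, 8.0]
normalized = batch_norm(batch)

# Rescaling and shifting the incoming activations leaves the output
# almost unchanged; eps is the only source of the small difference.
rescaled = batch_norm([10.0 * x + 3.0 for x in batch])
```

The whole operation is a subtraction, a division and an affine map, which is presumably why the analysis can stay mostly algebraic. The second call also illustrates that the output does not depend on the scale or offset of the incoming activations, up to the effect of `eps`.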

This is pretty cool and illuminating. It also goes to show that our intuitions about optimizing large models in high dimensions are most often wrong, even if they're expressed with fancy-sounding terminology (/internal covariate shift/) and tell a nice story.