Gabriel Poesia

Prioritized Training on Points that are Learnable, Worth Learning, and Not Yet Learnt (@ ICML 2022)

Sören Mindermann, Jan Brauner, Muhammed Razzak, Mrinank Sharma, Andreas Kirsch, Winnie Xu, Benedikt Höltgen, Aidan N. Gomez, Adrien Morisot, Sebastian Farquhar, Yarin Gal

Link

This paper describes an example selection strategy for training machine learning models. The authors try to estimate how much training on an additional data point $(x_i, y_i)$ would reduce the loss on a hold-out set $\mathcal{D}_{ho}$, without actually training on this data point and then evaluating on the hold-out set.

Their approximation ends up being simple (Eq 2): they use a Bayesian argument to show that the loss on the hold-out set, given this additional training point, differs only by a constant from the difference between two terms, one coming from the numerator and one from the denominator of the Bayesian expression. That difference is the loss on $(x_i, y_i)$ if you had trained on the hold-out set in addition to the dataset, minus the loss on $(x_i, y_i)$ after training on just the original dataset. (My explanation uses "loss", which has the opposite sign of the log-probabilities in their Eq 2.)
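Concretely, writing $L[y \mid x;\, \mathcal{D}]$ for the loss (negative log-likelihood) on $(x, y)$ under a model trained on $\mathcal{D}$, my rendering of their decomposition in loss notation is:

$$L[y_{ho} \mid x_{ho};\, \mathcal{D} \cup \{(x_i, y_i)\}] = L[y_i \mid x_i;\, \mathcal{D} \cup \mathcal{D}_{ho}] + L[y_{ho} \mid x_{ho};\, \mathcal{D}] - L[y_i \mid x_i;\, \mathcal{D}]$$

The middle term does not depend on the candidate point, so minimizing hold-out loss over candidates amounts to maximizing $L[y_i \mid x_i;\, \mathcal{D}] - L[y_i \mid x_i;\, \mathcal{D} \cup \mathcal{D}_{ho}]$.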

This is still expensive to compute, so they make a series of approximations. First, instead of actually conditioning on the dataset (the ideal Bayesian interpretation of training; $p(\theta | \mathcal{D}) \propto p(\mathcal{D}|\theta) p(\theta)$), they simply train their models on it with SGD. Second, instead of a model trained on the hold-out set in addition to the dataset, they use a model trained on just the hold-out set. Finally, they swap out the actual model for a much smaller one trained only on $\mathcal{D}_{ho}$.
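In symbols, the net effect of the last two approximations (my notation, not the paper's) is to replace the expensive term with the loss of a small model trained only on the hold-out set:

$$L[y_i \mid x_i;\, \mathcal{D} \cup \mathcal{D}_{ho}] \;\approx\; L[y_i \mid x_i;\, \mathcal{D}_{ho}] \;\approx\; L_{\text{small}}[y_i \mid x_i;\, \mathcal{D}_{ho}]$$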

They call the loss you would get on $(x_i, y_i)$ after training additionally on $\mathcal{D}_{ho}$ the irreducible loss. The reducible loss is then the difference between the model's training loss on $(x_i, y_i)$ and the irreducible loss. This is their Equation 3, which defines the RHO loss (for Reducible Hold-Out loss).
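Spelled out, with the small hold-out-trained model standing in for the ideal one:

$$\mathrm{RHO}(x_i, y_i) = \underbrace{L[y_i \mid x_i;\, \mathcal{D}]}_{\text{training loss}} - \underbrace{L[y_i \mid x_i;\, \mathcal{D}_{ho}]}_{\text{irreducible loss}}$$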

Intuitively, the effect of this is simple and appealing: it should prioritize points with high training loss but low irreducible loss. What are those? It's easier to reason about the alternatives. If most of a point's loss is irreducible (so its training loss is close to its irreducible loss, and its RHO loss is low), additional training cannot reduce the loss on this point by much; it's likely noisy or an outlier, and therefore either not worth learning or not learnable. If the point has low training loss, it will also have low irreducible loss, and it is thus a point that has /already been learned/. The negation of these three bad cases is the remaining case (high training loss and low irreducible loss), hence the title of the paper.

The experiments show that selecting examples by their RHO loss lets models train in many fewer steps (up to 18x fewer), sometimes reaching slightly better accuracies. The interesting catch, though, is that computing the training loss requires running the full model on the examples first; then they run the smaller irreducible-loss (IL) model to estimate the irreducible loss, and finally run a full forward-backward pass on the selected examples. Thus, they still have to run at least a forward pass on all the examples.
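To make the procedure concrete, here is a minimal PyTorch-style sketch of one selection step. The function and argument names are mine, not the paper's, and the paper precomputes and caches the IL losses once per example rather than recomputing them as done here:

```python
import torch
import torch.nn.functional as F

def rho_loss_step(model, il_model, optimizer, xb, yb, k):
    """One training step with RHO-loss selection over a large candidate batch.

    model:    the full model being trained
    il_model: a small model pre-trained on the hold-out set (kept frozen)
    xb, yb:   a large candidate batch; only the top-k examples are trained on
    """
    with torch.no_grad():
        # The catch discussed above: selection needs a full-model forward
        # pass over every candidate example just to get per-example losses.
        train_loss = F.cross_entropy(model(xb), yb, reduction="none")
        # Irreducible loss from the small hold-out-trained model (in the
        # paper, computed once per example and cached, not re-run).
        irreducible_loss = F.cross_entropy(il_model(xb), yb, reduction="none")

    # Reducible hold-out loss: prioritize high training loss AND low
    # irreducible loss (learnable, worth learning, not yet learnt).
    rho = train_loss - irreducible_loss
    selected = torch.topk(rho, k).indices

    # Standard forward-backward pass, but only on the selected examples.
    optimizer.zero_grad()
    loss = F.cross_entropy(model(xb[selected]), yb[selected])
    loss.backward()
    optimizer.step()
    return loss.item()
```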

The experiments measure speed-up in terms of epochs or training steps, but each of their steps is potentially much more expensive. They say this is mostly due to implementation details and hardware considerations outside the scope of the paper, but I think it's important to at least report. If you take 10x fewer steps but each step takes 10x more time, is it fair to report only the former? Even if their method did not win in wall-clock time, I'd still think it's an interesting investigation, but the fact that there isn't a single time measurement in the entire paper makes me wary of how much these "speed-up" claims would matter in whatever I end up doing in practice.