An Explanation of In-context Learning as Implicit Bayesian Inference (@ ICLR 2022)
Sang Michael Xie, Aditi Raghunathan, Percy Liang, Tengyu Ma
This paper provides an explanation of in-context learning in language models
through the claim that models learn to implicitly perform Bayesian inference to infer
a latent concept from their prompts.
One observation that makes in-context learning intriguing is that language models are trained on
sequences that qualitatively differ from the prompts we use at test-time.
In particular, sequences of raw input-output examples are quite unusual for
the Web, but LMs can complete them with correct (or at least on-topic) outputs nonetheless.
The paper's setup is the following (I sketch a toy version in code after this list):
- They first assume that documents are generated in a hierarchical process where first a "concept" $\theta$ is sampled from a prior distribution $p(\theta)$, then the sequence of tokens $t_1, \cdots, t_k$ is sampled given the concept. Here, $p(t_1, \cdots, t_k | \theta)$ is modelled by an HMM.
- Then, they assume that prompts follow a certain structure of input-output pairs separated by certain delimiter tokens (Eq 3), and crucially that the input/output pair distribution at any point in the sequence is close to the distribution at the start of the prompt (their Assumption 3).
- Finally, they propose a Bayesian predictor for $p(y|x)$ (Eq 4), which marginalizes over hidden concepts.
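To make this setup concrete, here is a minimal NumPy sketch of the hierarchical generative process, the prompt format, and the Bayesian predictor. Everything here (vocabulary size, number of concepts, the uniform prior, the single delimiter token, the input/output lengths) is invented for illustration and is not the paper's actual construction; the point is just the structure: sample $\theta$, sample tokens from an HMM given $\theta$, and predict by marginalizing over $\theta$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "pretraining" world: a small family of concepts, each an HMM over a
# shared token vocabulary. All sizes and parameters below are made up for
# illustration; they are not the paper's actual synthetic construction.
VOCAB = 8        # token vocabulary size
HIDDEN = 4       # hidden states per HMM
N_CONCEPTS = 3   # size of the concept family

def random_hmm(rng):
    """One concept theta = (start dist, transition matrix, emission matrix)."""
    start = rng.dirichlet(np.ones(HIDDEN))
    trans = rng.dirichlet(np.ones(HIDDEN), size=HIDDEN)   # HIDDEN x HIDDEN
    emit = rng.dirichlet(np.ones(VOCAB), size=HIDDEN)     # HIDDEN x VOCAB
    return start, trans, emit

concepts = [random_hmm(rng) for _ in range(N_CONCEPTS)]
prior = np.full(N_CONCEPTS, 1.0 / N_CONCEPTS)             # p(theta), uniform here

def sample_document(theta, length, rng):
    """Hierarchical generative process: t_1..t_k ~ HMM(theta)."""
    start, trans, emit = theta
    h = rng.choice(HIDDEN, p=start)
    tokens = []
    for _ in range(length):
        tokens.append(int(rng.choice(VOCAB, p=emit[h])))
        h = rng.choice(HIDDEN, p=trans[h])
    return tokens

def log_likelihood(theta, tokens):
    """log p(t_1..t_k | theta) via the HMM forward algorithm."""
    start, trans, emit = theta
    alpha = start * emit[:, tokens[0]]
    log_p = 0.0
    for t in tokens[1:]:
        norm = alpha.sum()
        log_p += np.log(norm)
        alpha = (alpha / norm) @ trans * emit[:, t]
    return log_p + np.log(alpha.sum())

# Prompt format: input/output examples separated by a delimiter token.
# Reserving one vocabulary id as the delimiter is a simplification.
DELIM = VOCAB - 1
theta_star = concepts[0]   # the concept that generates the prompt

def make_prompt(n_examples, x_len, rng):
    prompt = []
    for _ in range(n_examples):
        x = sample_document(theta_star, x_len, rng)   # "input" tokens
        y = sample_document(theta_star, 1, rng)       # "output" token
        prompt += x + y + [DELIM]
    return prompt

def bayes_predict(prompt, x_test):
    """Bayesian predictor (cf. Eq 4):
    p(y | prompt, x_test) ∝ sum_theta p(prompt, x_test, y | theta) p(theta)."""
    scores = np.zeros(VOCAB)
    for y in range(VOCAB):
        lls = np.array([log_likelihood(th, prompt + x_test + [y]) for th in concepts])
        scores[y] = lls.max() + np.log(prior @ np.exp(lls - lls.max()))
    probs = np.exp(scores - scores.max())
    return probs / probs.sum()

prompt = make_prompt(n_examples=8, x_len=3, rng=rng)
x_test = sample_document(theta_star, 3, rng)
print(bayes_predict(prompt, x_test))   # distribution over the next (output) token
```

With enough examples in the prompt, the posterior weight inside `bayes_predict` concentrates on whichever concept generated the prompt, which is the mechanism the distinguishable case below formalizes.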
Their analysis essentially shows that one of two things happens:
- *Distinguishable case:* If the latent concept is sufficiently different from the other concepts in the hypothesis space, then even though the prompt is an unlikely sequence under the pretraining distribution, it is sufficiently less likely under every other concept, so the Bayesian predictor effectively "picks out" the right concept as the number of in-context examples grows. In this case, the likelihood ratio of the prompt under $\theta^*$ versus any other concept $\theta$ goes to infinity (equivalently, the log-likelihood ratio of $\theta$ relative to $\theta^*$ goes to $-\infty$); see the sketch after this list.
- *Non-distinguishable case:* If other concepts also put comparable probability mass on the prompt, the posterior cannot single out $\theta^*$; but precisely because those concepts explain the prompt similarly well, the predictions made under them also incur low error.
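As a rough sketch of why the distinguishable case works out (my paraphrase, ignoring the HMM dependence structure within examples and the prompt-format mismatch that the paper's assumptions are designed to handle): for a prompt of $n$ examples $O_1, \dots, O_n$ drawn under $\theta^*$ and any competing concept $\theta \neq \theta^*$, the averaged log-likelihood ratio concentrates around a negative KL-type quantity,

$$
\frac{1}{n} \log \frac{p(O_1, \dots, O_n \mid \theta)}{p(O_1, \dots, O_n \mid \theta^*)} \;\longrightarrow\; -\,\mathrm{KL}\!\left(p_{\theta^*} \,\Vert\, p_{\theta}\right) < 0,
$$

so the total log-likelihood ratio drifts to $-\infty$ as $n$ grows, the posterior weight on every $\theta \neq \theta^*$ vanishes, and the marginalization in Eq 4 effectively conditions on $\theta^*$.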
They then construct a synthetic dataset in which concepts and documents follow the HMM-induced distribution from their theory, and show that both Transformers and LSTMs can perform in-context learning, and that increasing model size improves in-context accuracy even when the pretraining loss stays the same. This last bit is an interesting observation even in this small-scale setting.
One thing I missed from the paper's argument is the link between Transformers trained via maximum likelihood and their Bayesian predictor. As far as I followed, the argument is that the Bayesian predictor performs in-context learning and is optimal in $0-1$ loss. But that is not conceptually the same as saying that any model which is $0-1$ loss-optimal is equivalent to the Bayesian predictor, even though I do in general buy their overall take on what in-context learning even is.