This paper presents a way to scale the attention mechanism up to very long documents. Naïvely growing the attention window is expensive, even with approximations that make the complexity sub-quadratic (e.g., Longformer, Big Bird). The paper offers an interesting alternative: while processing the document in small blocks, it explicitly memorizes hidden states (attention keys and values) in a non-differentiable module, then adds an attention head that retrieves from that module with an approximate k-NN query. Although gradients do not flow into the module, its stored keys come from a possibly outdated but similar distribution to the keys produced in the local attention window, since every stored key was itself local at some point in the past.
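To make the mechanism concrete, here is a minimal sketch under several assumptions that are mine, not the paper's: the class name `KNNMemory` is hypothetical, retrieval is an *exact* top-k dot-product search standing in for the approximate k-NN index, and memory capacity limits and the way the retrieved result is combined with local attention are omitted. The memory holds plain NumPy copies, which is what makes it non-differentiable; each block first reads from memory and only afterwards writes its own keys and values.

```python
import numpy as np


class KNNMemory:
    """Hypothetical external memory of (key, value) pairs from past blocks.

    Entries are plain NumPy copies, so no gradients flow into them.
    Retrieval is exact top-k by dot product; an approximate k-NN index
    would replace it at scale (assumption: not the paper's exact setup).
    """

    def __init__(self, dim: int):
        self.keys = np.empty((0, dim))
        self.values = np.empty((0, dim))

    def add(self, keys: np.ndarray, values: np.ndarray) -> None:
        # Append a finished block's keys/values; older entries are never updated.
        self.keys = np.concatenate([self.keys, keys], axis=0)
        self.values = np.concatenate([self.values, values], axis=0)

    def attend(self, queries: np.ndarray, k: int) -> np.ndarray:
        # With nothing stored yet, return zeros of the output shape.
        if len(self.keys) == 0:
            return np.zeros((len(queries), self.values.shape[1]))
        k = min(k, len(self.keys))
        scores = queries @ self.keys.T                          # (n_q, n_mem)
        top = np.argpartition(-scores, k - 1, axis=1)[:, :k]    # k best per query
        top_scores = np.take_along_axis(scores, top, axis=1)    # (n_q, k)
        # Softmax over only the k retrieved entries.
        weights = np.exp(top_scores - top_scores.max(axis=1, keepdims=True))
        weights /= weights.sum(axis=1, keepdims=True)
        return np.einsum("qk,qkd->qd", weights, self.values[top])


# Stream a (random) document block by block: read from memory, then write.
rng = np.random.default_rng(0)
dim, block_len = 8, 16
memory = KNNMemory(dim)
for _ in range(4):
    keys = rng.normal(size=(block_len, dim))
    values = rng.normal(size=(block_len, dim))
    queries = rng.normal(size=(block_len, dim))
    retrieved = memory.attend(queries, k=4)  # sees earlier blocks only
    memory.add(keys, values)
print(retrieved.shape)  # (16, 8)
```

Because each stored key was produced by the same projection that produces the current local keys, a query can score stored and local keys on a comparable scale; that is the intuition behind the "outdated but similar distribution" point above.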
This is an algorithmically simple and appealing idea. The authors show examples of the model clearly memorizing and attending to sections that are tens of thousands of tokens apart, which is cool. I suspect the mechanism needs a certain model scale before it yields meaningful gains, since current models still struggle with predictions that could be decided from the local attention window alone (so global coherence is often not the immediate bottleneck). Still, it is a nice mechanism to have available whenever language models are applied to inputs of effectively unbounded length, such as source code.