Looped Transformers for Length Generalization
Ying Fan, Yilun Du, Kannan Ramchandran, Kangwook Lee
UW-Madison, MIT, UC Berkeley
[Paper]
Recent work has shown that Transformers trained from scratch can successfully solve various arithmetic and algorithmic tasks, such as adding numbers and computing parity. While these Transformers generalize well on unseen inputs of the same length, they struggle with length generalization, i.e., handling inputs of unseen lengths. In this work, we demonstrate that looped Transformers with an adaptive number of steps significantly improve length generalization. We focus on tasks with a known iterative solution, involving multiple iterations of a RASP-L operation—a length-generalizable operation that can be expressed by a finite-sized Transformer. We train looped Transformers using our proposed learning algorithm and observe that they learn highly length-generalizable solutions for various tasks.
Visualization of the n-RASP-L solutions for Copy, Parity, and Addition with n = 2. Copy is implemented by n iterations of shifting; Parity is implemented by n iterations of shifting and XOR; Addition is implemented by n + 1 iterations of shifted XOR and AND.
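To make the iterative structure concrete, here is a small pure-Python illustration of the Parity case (the function and variable names are my own, not the paper's RASP-L program): a single fixed "shift, then XOR with the input" step, repeated once per input position, leaves the prefix parity at every position, so the last position holds the parity of the whole string.

```python
def shift_right(seq, fill=0):
    """Shift a sequence one position to the right (a causal, offset-by-one read)."""
    return [fill] + seq[:-1]

def parity_by_looping(bits):
    """Illustrative sketch: compute parity with len(bits) iterations of one fixed,
    length-independent step ('shift the state, then XOR with the input'),
    in the spirit of an n-RASP-L solution."""
    state = [0] * len(bits)
    for _ in range(len(bits)):            # number of loop steps grows with input length
        state = [b ^ s for b, s in zip(bits, shift_right(state))]
    return state[-1]                      # parity of the whole input

assert parity_by_looping([1, 0, 1, 1]) == 1
assert parity_by_looping([1, 0, 1, 1, 1, 0]) == 0
```

The key property is that the per-step operation does not depend on the input length; only the number of loop iterations grows with n, which is what the looped Transformer is trained to exploit.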
Figure 1: Method Overview
During training, we apply the same decoder block iteratively and supervise the model's output against the target only after the number of steps required for the input, which helps the model learn reusable intermediate steps that can handle inputs of arbitrary lengths. All grey blocks share the same parameters.
During inference, we can choose the number of steps adaptively based on maximum confidence, or run a predetermined number of steps (a code sketch of both phases follows the caption).
Examples are from the Copy task with n symbols. "#" indicates EOS, "*" indicates ignored output, and ">" indicates the end of the query (EOQ).
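Below is a minimal PyTorch-style sketch of this training and inference recipe, under assumed shapes and names (LoopedTransformer, training_step, adaptive_inference, n_steps, and max_steps are illustrative, and the average per-token max log-probability is a stand-in for the maximum-confidence criterion); it is not the authors' implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoopedTransformer(nn.Module):
    """One weight-tied causal block, applied repeatedly (a sketch of Figure 1)."""
    def __init__(self, vocab_size, d_model=128, nhead=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # A single causal self-attention block, reused at every loop step.
        self.block = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens, n_steps):
        h = self.embed(tokens)                                        # (B, L, D)
        mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1)).to(h.device)
        for _ in range(n_steps):                                      # same block, looped
            h = self.block(h, src_mask=mask)
        return self.head(h)                                           # logits (B, L, V)

def training_step(model, tokens, targets, n_steps):
    """Supervise only the output produced after the required number of steps."""
    logits = model(tokens, n_steps)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           targets.reshape(-1),
                           ignore_index=-100)   # -100 marks ignored ("*") positions

@torch.no_grad()
def adaptive_inference(model, tokens, max_steps):
    """Run the loop for up to max_steps and keep the most confident step's output."""
    h = model.embed(tokens)
    mask = nn.Transformer.generate_square_subsequent_mask(tokens.size(1)).to(h.device)
    best_conf, best_pred = -float("inf"), None
    for _ in range(max_steps):
        h = model.block(h, src_mask=mask)
        logits = model.head(h)
        conf = logits.log_softmax(-1).max(-1).values.mean()           # avg per-token confidence
        if conf > best_conf:
            best_conf, best_pred = conf, logits.argmax(-1)
    return best_pred
```

In this sketch, n_steps during training would come from the known number of iterations needed for the given input length, while at inference max_steps bounds the loop and the most confident intermediate output is returned (or a predetermined step count can be used instead).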
Looped Transformers with an adaptive number of steps significantly improve length generalization over the baselines. (NTP indicates vanilla next-token prediction; NTP-Pause indicates next-token prediction with pause tokens; NTP-Loop indicates next-token prediction with a fixed number of weight-tied layers.)
@article{fan2024looped,
  title={Looped Transformers for Length Generalization},
  author={Ying Fan and Yilun Du and Kannan Ramchandran and Kangwook Lee},
  journal={arXiv preprint arXiv:2409.15647},
  year={2024},
}