Yanhao Jin, Krishnakumar Balasubramanian and Lifeng Lai
We investigate the in-context learning capabilities of transformers for the $d$-dimensional mixture of linear regressions model, providing theoretical insights into existence results, generalization bounds, and training dynamics. Specifically,
we prove that, in the high signal-to-noise ratio (SNR) regime, there exists a transformer capable of achieving a prediction error of order $O(\sqrt{d/n})$ with high probability, where $n$ denotes the training prompt length.
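For concreteness, one common formulation of the data-generating process behind such prompts is sketched below; the notation (mixing weights $\pi_k$, component vectors $\beta_k^*$, noise level $\sigma$) is illustrative shorthand rather than the paper's own, and whether the latent label is shared across the prompt or drawn per example varies across works (the shared-label version is shown).

```latex
% Minimal sketch of a d-dimensional mixture of linear regressions prompt:
% the in-context examples (x_i, y_i) are generated by one of K linear components.
\begin{align*}
  z &\sim \mathrm{Categorical}(\pi_1,\dots,\pi_K), \qquad
  x_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, I_d), \\
  y_i &= \langle \beta^*_{z}, x_i \rangle + \varepsilon_i, \qquad
  \varepsilon_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \sigma^2),
  \qquad i = 1, \dots, n+1, \\
  \text{prompt} &= (x_1, y_1, \dots, x_n, y_n, x_{n+1}),
  \qquad \text{task: predict } y_{n+1};
  \qquad \text{SNR} \asymp \|\beta^*_{z}\| / \sigma.
\end{align*}
```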
Moreover, we derive in-context excess risk bounds of order $O(L/\sqrt{B})$ for the case of two mixture components, where $B$ denotes the number of training prompts and $L$ the number of attention layers. The dependence of $L$ on the SNR is explicitly characterized and differs between the low- and high-SNR settings.
We further analyze the training dynamics of transformers with a single linear self-attention layer, demonstrating that, with appropriately initialized parameters, gradient flow on the population mean squared error loss converges to a global optimum.
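As a rough illustration of this setting, the sketch below trains a reduced single linear self-attention parameterization (prediction $\hat y = x_{\mathrm{query}}^\top A \cdot \frac{1}{n}\sum_i y_i x_i$ with a trainable matrix $A$) by gradient descent on a Monte Carlo proxy of the population MSE. This reduced form, the two-component mixture, and all hyperparameters are illustrative assumptions, not the paper's exact construction.

```python
import numpy as np

# Hedged sketch: a single linear self-attention layer on a regression prompt is
# often reduced to y_hat = x_query^T A b_hat, with b_hat = (1/n) sum_i y_i x_i.
rng = np.random.default_rng(0)
d, n, batch, lr, steps, sigma = 5, 50, 256, 0.05, 501, 0.1
betas = np.stack([np.ones(d), -np.ones(d)]) / np.sqrt(d)  # assumed 2-component mixture

def sample_prompt():
    """Draw one prompt (x_1, y_1, ..., x_n, y_n, x_query) from the mixture."""
    z = rng.integers(0, 2)                        # latent mixture label
    X = rng.standard_normal((n + 1, d))
    y = X @ betas[z] + sigma * rng.standard_normal(n + 1)
    return X, y

A = 0.01 * np.eye(d)                              # small identity-like initialization
for step in range(steps):
    grad, loss = np.zeros_like(A), 0.0
    for _ in range(batch):                        # Monte Carlo proxy for population MSE
        X, y = sample_prompt()
        b_hat = X[:-1].T @ y[:-1] / n             # in-context moment (1/n) sum_i y_i x_i
        err = X[-1] @ A @ b_hat - y[-1]           # linear self-attention prediction error
        loss += err**2 / batch
        grad += 2 * err * np.outer(X[-1], b_hat) / batch
    A -= lr * grad                                # gradient step ~ discretized gradient flow
    if step % 100 == 0:
        print(f"step {step:4d}  mse {loss:.4f}")
```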
Extensive simulations suggest that transformers perform well on this task, potentially outperforming classical baselines such as the Expectation-Maximization (EM) algorithm.
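For reference, a minimal sketch of the classical EM baseline for a two-component mixture of linear regressions on a single prompt is given below; the initialization, fixed noise level, iteration count, and ridge term are illustrative assumptions rather than the paper's experimental setup.

```python
import numpy as np

def em_mixture_regression(X, y, sigma=0.1, n_iter=50, seed=0):
    """EM for a two-component mixture of linear regressions (assumed known sigma)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    betas = rng.standard_normal((2, d)) / np.sqrt(d)   # random initialization
    pis = np.array([0.5, 0.5])
    for _ in range(n_iter):
        # E-step: posterior responsibility of each component for each example
        resid = y[None, :] - betas @ X.T               # shape (2, n)
        log_lik = -0.5 * (resid / sigma) ** 2 + np.log(pis + 1e-12)[:, None]
        log_lik -= log_lik.max(axis=0, keepdims=True)  # numerical stability
        gamma = np.exp(log_lik)
        gamma /= gamma.sum(axis=0, keepdims=True)
        # M-step: responsibility-weighted least squares per component
        for k in range(2):
            w = gamma[k]
            G = X.T @ (w[:, None] * X) + 1e-6 * np.eye(d)
            betas[k] = np.linalg.solve(G, X.T @ (w * y))
        pis = gamma.mean(axis=1)
    return betas, pis

# Minimal usage on synthetic data drawn from the assumed two-component model.
rng = np.random.default_rng(1)
d, n = 5, 200
true = np.stack([np.ones(d), -np.ones(d)]) / np.sqrt(d)
z = rng.integers(0, 2, size=n)
X = rng.standard_normal((n, d))
y = np.einsum("nd,nd->n", X, true[z]) + 0.1 * rng.standard_normal(n)
betas_hat, pis_hat = em_mixture_regression(X, y)
```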