Retrieval-Based Speculative Decoding (REST): A Plug-and-Play Method for Accelerating Language Model Generation without Additional Training

Zhenyu He*, Zexuan Zhong*, Tianle Cai*,  Jason D. Lee, Di He (* Equal contribution)

Recent advances in accelerating the generation process of Large Language Models (LLMs), such as speculative decoding, blockwise parallel decoding, and Medusa, have brought impressive speed improvements. Typically, these methods pair the large base model with a lightweight draft model. The draft model predicts multiple tokens per decoding step with lower latency and lets the base model verify them in parallel, reducing the number of forward passes needed from the slower base model.
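To make the draft-and-verify idea concrete, here is a minimal sketch of one greedy speculative decoding step. The `draft_model` and `base_model` callables are hypothetical stand-ins that return greedy next-token predictions; this illustrates the general scheme, not the exact implementation of any of the methods above.

```python
def speculative_step(base_model, draft_model, context, num_draft=4):
    # 1. The cheap draft model proposes several future tokens, one at a time.
    draft = []
    ctx = list(context)
    for _ in range(num_draft):
        tok = draft_model(ctx)      # greedy next-token prediction (hypothetical callable)
        draft.append(tok)
        ctx.append(tok)

    # 2. The base model scores context + draft in a single forward pass;
    #    preds[j] is its greedy prediction for the token following position j.
    preds = base_model(list(context) + draft)

    # 3. Accept draft tokens while they agree with the base model; at the first
    #    disagreement keep the base model's own token and stop.
    accepted = []
    for i, tok in enumerate(draft):
        base_tok = preds[len(context) + i - 1]
        if tok == base_tok:
            accepted.append(tok)
        else:
            accepted.append(base_tok)
            break
    else:
        accepted.append(preds[-1])  # every draft token accepted: take one bonus token

    return list(context) + accepted
```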

However, obtaining a high-quality draft model remains an art: it must balance small size against strong predictive power while matching the vocabulary of the base model, and it should be easy to integrate into a distributed serving system. To tackle these challenges, Medusa introduced an efficient fine-tuning procedure that creates draft models in the form of additional language model heads. However, the need for extra fine-tuning remains a common pain point.

This raises the question: can we design an acceleration method that is plug-and-play out of the box? One that delivers swift generation without the need to train or fine-tune new models?

To meet this goal, we introduce REST, a retrieval-based speculative decoding approach. Instead of a learned draft model, REST uses a datastore to retrieve draft tokens based on partial inputs. We build a suffix array over the datastore to enable fast exact-match retrieval on CPUs -- this also means REST adds no extra GPU load. Since it requires no joint training or fine-tuning, REST can accelerate any pre-trained language model instantly.
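To illustrate how a suffix array supports this kind of retrieval on CPU, here is a toy sketch over a flat list of token ids. The function names, the naive suffix-array construction, and the longest-suffix search loop are illustrative assumptions, not the actual REST code.

```python
import bisect

def build_suffix_array(corpus):
    """Sort all suffix start positions of the token corpus.
    (Naive O(n^2 log n) construction; real implementations are far more efficient.)"""
    return sorted(range(len(corpus)), key=lambda i: corpus[i:])

def exact_match(corpus, suffix_array, query, max_continuation=10):
    """Return the continuations that follow every exact occurrence of `query`."""
    suffixes = [corpus[i:] for i in suffix_array]  # toy; real code compares in place
    lo = bisect.bisect_left(suffixes, query)
    hi = bisect.bisect_right(suffixes, query + [float("inf")])
    return [corpus[pos + len(query): pos + len(query) + max_continuation]
            for pos in suffix_array[lo:hi]]

def retrieve_continuations(corpus, suffix_array, context, max_suffix_len=16):
    """Query with the longest suffix of the input context that still has a match."""
    for n in range(min(max_suffix_len, len(context)), 0, -1):
        hits = exact_match(corpus, suffix_array, context[-n:])
        if hits:
            return hits
    return []
```

In a real datastore the suffix array would be built offline, and the binary search would compare tokens in place rather than materializing every suffix as this toy version does.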

In this post, we'll explore REST's inspiration and inner workings, revealing how it achieves impressive speedups without training any additional models. Clear explanations and concrete examples will demonstrate REST's capabilities.

The implementation is available at this repository for hands-on interaction. You can also read our paper for additional technical details.

Building the "Code Completion Plugin" for LLMs

To construct a training-free draft model, we found inspiration in classic code completion tools. These aids boost programmer productivity by predicting common phrases and patterns, typically using techniques like static analysis, symbol tables, and retrieval. Even before the rise of LLM-based assistants like GitHub Copilot, code completion significantly accelerated development.

For REST, we adapted the most general code completion technique - retrieval - to build a "code completion plugin" for LLMs. By finding common patterns in a datastore, retrieval-based completion helps generate repetitive code snippets; it can likewise predict common expressions in everyday text. We present an overview of the REST pipeline below.

REST: Retrieval-Based Speculative Decoding

Overview of REST (Retrieval-Based Speculative Decoding). During inference, the input context is used as the query to retrieve docs from the datastore that match the longest suffix of the input, using a suffix array. A Trie is then constructed from the continuations in the retrieved docs. We prune the low-frequency (low-weight) branches, and the remaining subtree is used as draft tokens in speculative decoding. The draft tokens are fed into the LLM with a tree attention mask for verification. Draft tokens that match the LLM's own predictions are accepted from the start of the draft, and all draft tokens after the first mismatch are rejected.
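As a rough sketch of the Trie step described above, assuming the retrieved continuations are lists of token ids (the class, helper names, and pruning threshold below are illustrative, not the paper's exact choices):

```python
class TrieNode:
    def __init__(self):
        self.children = {}  # token id -> TrieNode
        self.weight = 0     # number of retrieved continuations passing through this node

def build_trie(continuations):
    """Insert every retrieved continuation into a Trie, counting frequencies."""
    root = TrieNode()
    for seq in continuations:
        node = root
        for tok in seq:
            node = node.children.setdefault(tok, TrieNode())
            node.weight += 1
    return root

def prune(node, min_weight=2):
    """Drop branches whose weight (frequency) falls below the threshold."""
    node.children = {tok: child for tok, child in node.children.items()
                     if child.weight >= min_weight}
    for child in node.children.values():
        prune(child, min_weight)

def draft_sequences(node, prefix=()):
    """Enumerate root-to-leaf paths of the pruned Trie; these become the draft
    tokens fed to the LLM for parallel verification."""
    if not node.children:
        return [list(prefix)] if prefix else []
    paths = []
    for tok, child in node.children.items():
        paths.extend(draft_sequences(child, prefix + (tok,)))
    return paths
```

Because the resulting draft sequences share prefixes, they can be packed into a single token tree and verified in one forward pass with a tree attention mask, so shared prefixes are checked only once.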

REST speeds up code and text generation

We conduct extensive experiments to test the efficiency and effectiveness of REST in different scenarios. For the code domain, we use a portion (2.7M samples) of Python pretraining code from The Stack as the datastore (27GB) and accelerate CodeLlama 7B and 13B. The results show that on HumanEval, REST achieves a 2.12× to 2.36× speedup. For the general domain, we construct a datastore from UltraChat, which contains around 774k conversations (12GB). On MT-Bench, REST accelerates Vicuna 7B and 13B by 1.62× to 1.77×.

Analysis

Effect of the size of the datastore

Increasing the size of the datastore is an effective strategy for enhancing the accuracy of the draft tokens retrieved into the Trie, which in turn can significantly boost generation speed. We assume that in industry applications, there will be ample disk storage to build a large datastore. This figure shows that as the datastore size increases, both the Mean Generated Length (the average number of tokens generated per forward pass) and the Mean Token Time (the average time to generate one token) improve correspondingly. From the trend of scaling the retrieval datastore size, we expect that there is still potential for further speedups with an even larger datastore.
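For clarity, here is a toy calculation of the two metrics quoted above, assuming we logged, for each forward pass of the base model, how many tokens were accepted and how long the pass took; the numbers are made up for illustration.

```python
accepted_per_step = [3, 1, 4, 2, 5]             # tokens accepted in each forward pass (illustrative)
time_per_step_s   = [0.05, 0.04, 0.05, 0.04, 0.05]

# Mean Generated Length: tokens produced per forward pass of the base model.
mean_generated_length = sum(accepted_per_step) / len(accepted_per_step)

# Mean Token Time: average wall-clock time per generated token.
mean_token_time = sum(time_per_step_s) / sum(accepted_per_step)

print(mean_generated_length, mean_token_time)
```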

Effect of draft token selecting strategies

We compare selecting draft tokens from nodes of the pruned Trie with randomly sampling retrieved continuations as draft tokens. This table shows that selecting nodes from the Trie, rather than random sampling, improves performance.

A plug-and-play method for any LLM!

REST is a plug-and-play method that can be easily integrated with language models of any size, vocabulary, or architecture! Below are more results of applying REST to deepseek-coder, Llama2, and Llama2-Chat!