⭐ This blog post describes dynamic speculative decoding, a new method developed by Intel Labs and Hugging Face that speeds up text generation by up to 2.7x, depending on the task. Starting from Transformers🤗 release 4.45.0, this method is the default mode of operation for assisted generation ⭐
Speculative Decoding
Speculative decoding is a popular technique for speeding up the inference of large language models while preserving their accuracy. It works by splitting the generation process into two stages, as shown in the diagram below. In the first stage, a fast but less accurate draft model (also known as the assistant model) autoregressively generates a set of tokens. In the second stage, a larger but more accurate target model performs parallel validation of the generated draft tokens. This process allows the target model to produce multiple tokens in a single forward pass, thus speeding up autoregressive decoding. The success of speculative decoding depends heavily on the speculative lookahead (SL), the number of tokens generated by the draft model in each iteration. In practice, the SL is either a static value or determined by heuristics, neither of which is optimal for maximizing performance during inference.
Iterative speculative decoding.
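To make the two stages concrete, here is a minimal sketch of a greedy draft-and-verify loop. This is not the Transformers implementation: the helper speculative_generate, the fixed lookahead of 5, and the reuse of the Pythia checkpoints from the code section below are illustrative choices, and KV caching is omitted for brevity.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

draft_name = "EleutherAI/pythia-160m-deduped"
target_name = "EleutherAI/pythia-1.4b-deduped"
tokenizer = AutoTokenizer.from_pretrained(target_name)
draft = AutoModelForCausalLM.from_pretrained(draft_name)
target = AutoModelForCausalLM.from_pretrained(target_name)

@torch.no_grad()
def speculative_generate(prompt, max_new_tokens=64, lookahead=5):
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_len = ids.shape[1]
    while ids.shape[1] - prompt_len < max_new_tokens:
        # Stage 1: the draft model proposes `lookahead` tokens autoregressively.
        draft_ids = ids
        for _ in range(lookahead):
            next_draft = draft(draft_ids).logits[:, -1, :].argmax(-1, keepdim=True)
            draft_ids = torch.cat([draft_ids, next_draft], dim=-1)
        # Stage 2: a single target forward pass scores the prompt plus all draft tokens.
        target_pred = target(draft_ids).logits.argmax(-1)  # target's greedy choice at every position
        # Accept draft tokens as long as they agree with the target's own predictions.
        n_accepted = 0
        for i in range(lookahead):
            pos = ids.shape[1] + i
            if draft_ids[0, pos].item() == target_pred[0, pos - 1].item():
                n_accepted += 1
            else:
                break
        # Keep the accepted draft tokens plus one "free" token from the target model.
        bonus = target_pred[:, ids.shape[1] + n_accepted - 1].unsqueeze(-1)
        ids = torch.cat([draft_ids[:, : ids.shape[1] + n_accepted], bonus], dim=-1)
    return tokenizer.decode(ids[0], skip_special_tokens=True)

print(speculative_generate("Alice and Bob"))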
Dynamic Speculative Decoding
Transformers🤗 provides two ways to determine the schedule for adjusting the number of draft (assistant) tokens during inference. A direct method, based on Leviathan et al., uses a static value of the speculative lookahead and generates a constant number of candidate tokens at each speculative iteration. An alternative, heuristic-based approach adjusts the number of candidate tokens for the next iteration based on the acceptance rate of the current iteration: if all speculative tokens are correct, the number of candidate tokens increases; otherwise, it decreases. A simplified sketch of this heuristic update is shown below.
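The helper heuristic_update below is hypothetical, and the exact step sizes used internally by Transformers may differ; it only conveys the grow-on-success, shrink-on-failure idea.

def heuristic_update(num_assistant_tokens: int, num_matches: int) -> int:
    # Adjust the speculative lookahead after one speculative iteration.
    if num_matches == num_assistant_tokens:    # every draft token was accepted
        return num_assistant_tokens + 2        # be more aggressive next iteration
    return max(1, num_assistant_tokens - 1)    # back off after a rejection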
We anticipate that an enhanced optimization strategy for managing the number of generated draft tokens could further reduce latency. To test this premise, we utilize an oracle that determines the optimal speculative lookahead value at each speculative iteration. The oracle uses the draft model to autoregressively generate tokens until a mismatch arises between the predicted tokens of the draft and target models. This process is repeated for every speculative iteration and ultimately identifies the optimal (maximum) number of draft tokens accepted in each iteration. Draft/target token mismatches are identified using the rejection sampling algorithm introduced by Leviathan et al., with zero temperature. This oracle realizes the full potential of speculative decoding by generating the maximum number of valid draft tokens at each step while minimizing the number of calls to both the draft and target models. A simple, if inefficient, way to measure this oracle lookahead under greedy decoding is sketched below.
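The helper oracle_lookahead and the cap max_lookahead are hypothetical, and the sketch deliberately trades efficiency (one target forward pass per draft token, no KV caching) for readability; it is not the measurement code behind the figures below.

import torch

@torch.no_grad()
def oracle_lookahead(draft, target, ids, max_lookahead=64):
    # `ids` holds the prompt plus previously accepted tokens, shape [1, seq_len].
    accepted = 0
    while accepted < max_lookahead:
        draft_token = draft(ids).logits[:, -1, :].argmax(-1, keepdim=True)
        target_token = target(ids).logits[:, -1, :].argmax(-1, keepdim=True)
        if draft_token.item() != target_token.item():
            break  # first mismatch: the oracle stops drafting here
        ids = torch.cat([ids, draft_token], dim=-1)
        accepted += 1
    return accepted  # optimal number of draft tokens for this iteration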
The figure below on the left shows the oracle and static speculative lookahead values across the speculative iterations of a code generation example from the MBPP dataset. The oracle speculative lookahead values (orange bars) show a large dispersion. The static speculative lookahead (blue bars), where the number of generated draft tokens is fixed at 5, performs 38 target forward passes and 192 draft forward passes, while the oracle speculative lookahead performs only 27 target forward passes and 129 draft forward passes, a significant reduction. The figure on the right shows the average oracle speculative lookahead across the Alpaca dataset.
Oracle and static speculative lookahead (SL) values in one MBPP example.
Average oracle speculative lookahead across Alpaca dataset.
Both figures show large variations in oracle speculative lookahead values, suggesting that static speculative lookahead may not be optimal.
To get even closer to the oracle, and to obtain further speedups, we developed a straightforward method to dynamically adjust the speculative lookahead value at each iteration. After generating each draft token, we decide whether the draft model should continue generating the next token or switch to the target model for validation. This decision is based on how confident the assistant model is in its predictions, as estimated by the softmax of the logits. When the assistant model's confidence in its current token prediction falls below a predefined threshold, called assistant_confidence_threshold, token generation for that iteration stops, even if the maximum number of speculative tokens, num_assistant_tokens, has not been reached. Once stopped, the draft tokens generated during the current iteration are sent to the target model for validation.
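To make the stopping rule concrete, here is a minimal sketch assuming greedy drafting. The helper draft_dynamic is hypothetical and is not the Transformers implementation (for instance, whether the low-confidence token itself is kept or dropped is an implementation detail; the sketch simply drops it). The parameter names mirror the generation_config attributes described above.

import torch

@torch.no_grad()
def draft_dynamic(draft, ids, num_assistant_tokens=20, assistant_confidence_threshold=0.4):
    # `ids` holds the prompt plus previously accepted tokens, shape [1, seq_len].
    for _ in range(num_assistant_tokens):
        probs = torch.softmax(draft(ids).logits[:, -1, :], dim=-1)
        confidence, token = probs.max(dim=-1, keepdim=True)
        if confidence.item() < assistant_confidence_threshold:
            break  # the draft model is unsure: stop early and hand over to the target
        ids = torch.cat([ids, token], dim=-1)
    return ids  # candidate sequence sent to the target model for parallel validation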
Benchmarking
We benchmarked the dynamic approach against the heuristic approach across different task and model combinations. The dynamic approach showed superior performance in all tests. In particular, using the dynamic approach with Llama3.2-1B as an assistant to Llama3.1-8B, we observed speedups of up to 1.52x, whereas the heuristic approach showed no significant speedup for the same setup. Another observation is that codegen-6B-mono becomes slower with the heuristic approach, whereas it becomes faster with the dynamic approach.
| Target Model | Draft (Assistant) Model | Task | Speedup – Heuristic | Speedup – Dynamic |
|---|---|---|---|---|
| facebook/opt-6.7b | facebook/opt-125m | Summarization | 1.82x | 2.71x |
| facebook/opt-6.7b | facebook/opt-125m | Open-ended generation | 1.23x | 1.59x |
| Salesforce/codegen-6B-mono | Salesforce/codegen-350M-mono | Code generation (Python) | 0.89x | 1.09x |
| google/flan-t5-xl | google/flan-t5-small | Summarization | 1.18x | 1.31x |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | Summarization | 1.00x | 1.52x |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | Open-ended generation | 1.00x | 1.18x |
| meta-llama/Llama-3.1-8B | meta-llama/Llama-3.2-1B | Code generation (Python) | 1.09x | 1.15x |
Code
Dynamic speculation has been integrated into release 4.45.0 of the Hugging Face Transformers library and now serves as the default mode of operation for assisted decoding. No code changes are needed to use assisted generation with dynamic speculation; just run your code as usual.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to(device)
outputs = model.generate(**inputs, assistant_model=assistant_model)
The default dynamic speculation lookahead parameters reflect optimal values, but they can be tuned to improve performance for a specific model pair or dataset using the following code:
assistant_model.generation_config.assistant_confidence_threshold=0.4
assistant_model.generation_config.num_assistant_tokens_schedule='constant'
assistant_model.generation_config.num_assistant_tokens=20
To revert to the heuristic or the constant (as in Leviathan et al.) approach, set num_assistant_tokens_schedule to 'heuristic' or 'constant' respectively, set assistant_confidence_threshold=0 and num_assistant_tokens=5 as follows:
assistant_model.generation_config.num_assistant_tokens_schedule='heuristic'
assistant_model.generation_config.assistant_confidence_threshold=0
assistant_model.generation_config.num_assistant_tokens=5
What’s next?
We introduced dynamic speculative decoding, a faster strategy for assisted generation. It outperforms heuristic-based methods as well as approaches that draft a fixed number of candidate tokens.
In an upcoming blog post, we will introduce a new method for assisted generation: combining any target model with any assistant model. This opens the door to accelerating countless models on the Hugging Face Hub that lack a small enough assistant variant; for example, Phi 3, Gemma 2, CodeLlama, and more will become eligible for speculative decoding. Stay tuned!
References
Citation
@article{mamou2024accelerated,
  title={Accelerating Speculative Decoding using Dynamic Speculation Length},
  author={Mamou, Jonathan and Pereg, Oren and Korat, Daniel and Berchansky, Moshe and Timor, Nadav and Wasserblat, Moshe and Schwartz, Roy},
  journal={arXiv preprint arXiv:2405.04304},
  year={2024}
}