Speculative Decoding

This repository is a PyTorch implementation of Speculative Decoding / Speculative Sampling (Leviathan et al., 2023; Chen et al., 2023). It contains the code for three generation strategies: classic auto-regressive decoding, beam search decoding (with length penalty), and speculative decoding. Auto-regressive decoding and speculative decoding can be used in a greedy or nucleus sampling (temperature, top-k and top-p) setting.

Example of generation.

What is Speculative Decoding?

Speculative Decoding is a decoding strategy for transformers that generates sequences faster than classic auto-regressive decoding, without changing the output distribution or requiring further fine-tuning. It uses a smaller, more efficient approximation model (called a "drafter") to generate speculative token prefixes. These prefixes are then evaluated in parallel by the larger target model, reducing the number of serial decoding steps required and leading to inference speedups.

The core process relies on a property of the Transformer model: a single forward pass yields the next-token probability distribution at every position of the sequence fed in. The target model can therefore score all drafted tokens at once, and these distributions are used to verify the drafts generated by the drafter model.
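
As a rough illustration of the verification step, here is a minimal pure-Python sketch of the acceptance rule described in the cited papers: a drafted token `x` is accepted with probability `min(1, p(x) / q(x))`, and on rejection a replacement is sampled from the normalized residual `max(0, p - q)`. The function name `verify_drafts` and its list-based arguments are assumptions made for this example; this is not the repository's code, which works on model logits.

```python
import random

def verify_drafts(draft_tokens, q_dists, p_dists, rng=random):
    """Accept or reject drafted tokens with the speculative sampling rule.

    draft_tokens: the gamma token ids proposed by the drafter
    q_dists[i]:   drafter distribution (list of probabilities) that produced draft_tokens[i]
    p_dists[i]:   target distribution at the same position; p_dists has gamma + 1
                  entries, the last one being the distribution after all drafts
    Returns the tokens kept for this step (accepted drafts plus one extra token).
    """
    kept = []
    for i, token in enumerate(draft_tokens):
        p, q = p_dists[i], q_dists[i]
        # Accept the draft with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[token] / q[token]):
            kept.append(token)
            continue
        # Rejected: resample from the normalized residual max(0, p - q),
        # then discard every later draft.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        kept.append(rng.choices(range(len(p)), weights=residual)[0])
        return kept
    # Every draft was accepted: sample one more token from the target model.
    last = p_dists[len(draft_tokens)]
    kept.append(rng.choices(range(len(last)), weights=last)[0])
    return kept
```

When the drafter and target distributions are identical, every draft is accepted and each step yields `gamma + 1` tokens; the more they differ, the earlier a step tends to stop.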

How to use

0. Installation

This project requires Python 3.8 or later (the pinned torch and transformers releases do not support Python 3.7) and the following dependencies:

rich
tqdm
termcolor
tokenizers==0.19.1
torch==2.3.0
transformers==4.41.1
accelerate==0.30.1
bitsandbytes==0.43.1

Simply fork this repository and install the dependencies.

1. Generate text using Speculative Decoding

a. Load the target and drafter model

The target model is the transformer model we want to accelerate, while the drafter model is the smaller model that generates drafts for the target model to verify.

Here are some requirements to make speculative decoding work:

  • The target model must be a transformer model (decoder-only or encoder-decoder).
  • The drafter model must share the same tokenizer as the target model.
  • The target model and the drafter model should output logits of the same shape.
  • The target model should be large enough to benefit from the acceleration (i.e. its inference should be bottlenecked by memory).
  • The drafter model should be small enough to be faster than the target model.

from transformers import AutoTokenizer, AutoModelForCausalLM

# We will use the Google Gemma 2 27B Instruct as the model we want to accelerate (27B parameters)
target_model_name = "google/gemma-2-27b-it"
target = AutoModelForCausalLM.from_pretrained(target_model_name)

# We will use the Google Gemma 2 9B Instruct as the drafter model (9B parameters)
drafter_model_name = "google/gemma-2-9b-it"
drafter = AutoModelForCausalLM.from_pretrained(drafter_model_name)

# Don't forget to load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(target_model_name)
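
The tokenizer and logits-shape requirements listed above can be checked before generating anything. The helper below is a small sketch; `check_compatibility` is a hypothetical name and is not part of this repository. It takes plain values so that it can be run without downloading any model.

```python
def check_compatibility(target_vocab, drafter_vocab, target_logits_dim, drafter_logits_dim):
    """Return a list of problems; an empty list means the pair looks compatible.

    target_vocab / drafter_vocab: token -> id mappings of the two tokenizers
    target_logits_dim / drafter_logits_dim: size of the last logits dimension
    """
    problems = []
    if target_vocab != drafter_vocab:
        problems.append("tokenizers map tokens to different ids")
    if target_logits_dim != drafter_logits_dim:
        problems.append(
            f"logits dimensions differ: {target_logits_dim} vs {drafter_logits_dim}"
        )
    return problems

# Possible usage with the objects loaded above (assumes both model configs
# expose `vocab_size`, which is the case for the Gemma 2 models):
# drafter_tokenizer = AutoTokenizer.from_pretrained(drafter_model_name)
# problems = check_compatibility(
#     tokenizer.get_vocab(), drafter_tokenizer.get_vocab(),
#     target.config.vocab_size, drafter.config.vocab_size,
# )
```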

b. Prepare the input

Before generating text, we need to prepare the input. The input should be tokenized and encoded using the tokenizer.

prefix = "Translate to English: Je m'appelle Romain. N'hésitez pas à contribuer à mon projet !"

chat_templated = f"<bos><start_of_turn>user\n{prefix}<end_of_turn>\n<start_of_turn>model\n" # Gemma chat template
# The template already starts with <bos>, so the tokenizer must not prepend a second one
input_ids = tokenizer(chat_templated, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids[0].tolist() # Generation methods require a list of ids
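
For multi-turn prompts, the same template can be assembled programmatically. The sketch below only reproduces the string format used above; `build_gemma_prompt` is a hypothetical helper, not a function of this repository or of the transformers library.

```python
def build_gemma_prompt(turns):
    """Build a Gemma-style prompt from (role, text) pairs.

    Each role must be "user" or "model". The returned string ends with an
    open model turn so that generation continues as the model.
    """
    parts = ["<bos>"]
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"unexpected role: {role!r}")
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    parts.append("<start_of_turn>model\n")
    return "".join(parts)
```

Calling it with a single user turn gives the same string as the hand-written template in this section.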

c. Generate text

Speculative Decoding uses one hyperparameter: $\gamma$, the number of draft tokens generated by the drafter model at each step.

Increasing $\gamma$ does not always lead to faster generation, because more of the drafts may be rejected, and rejected drafts are wasted work. The acceptance rate $\alpha$ is the number of drafts accepted by the target model divided by the number of drafts generated. The higher the acceptance rate, the faster the generation, so the goal is to choose $\gamma$ according to the acceptance rate in order to get the fastest generation.
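
To get a feel for this trade-off, the sketch below uses the analysis of Leviathan et al. (2023): if each draft is accepted independently with probability $\alpha$, a step produces $(1 - \alpha^{\gamma+1}) / (1 - \alpha)$ tokens on average, and the expected speedup divides this by $\gamma c + 1$, where $c$ is the cost of one drafter pass relative to one target pass. The function names and the `cost_ratio` parameter are assumptions made for this example, and plugging in the measured acceptance rate is only an approximation of the paper's $\alpha$.

```python
def expected_tokens_per_step(alpha, gamma):
    """Average number of tokens produced by one speculative step."""
    if alpha >= 1.0:
        return float(gamma + 1)
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)

def expected_speedup(alpha, gamma, cost_ratio):
    """Expected speedup over auto-regressive decoding with the target model.

    cost_ratio: time of one drafter forward pass divided by the time of one
    target forward pass (an assumed, measured-by-you quantity).
    """
    return expected_tokens_per_step(alpha, gamma) / (gamma * cost_ratio + 1.0)

def best_gamma(alpha, cost_ratio, max_gamma=16):
    """Value of gamma in [1, max_gamma] with the highest expected speedup."""
    return max(
        range(1, max_gamma + 1),
        key=lambda g: expected_speedup(alpha, g, cost_ratio),
    )
```

A low acceptance rate or an expensive drafter pushes the best $\gamma$ towards 1, while a cheap drafter with a high acceptance rate favors longer drafts.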

from sampling import speculative_generate, autoregressive_generate
# from sampling import speculative_generate_encoder_decoder, autoregressive_generate_encoder_decoder
from utils.logits_processors import NucleusProcessor

# Parameters
gen_len = 100       # Maximum number of tokens generated (may be slightly exceeded when using speculative decoding)
gamma = 4           # Number of drafts generated by the drafter model at each step
logits_processor = NucleusProcessor(temperature=.6, top_p=.9) # Nucleus sampling with p=0.9 and T=0.6

# Generate text using the classic auto-regressive decoding (slow)
output_ids_ar = autoregressive_generate( # or autoregressive_generate_encoder_decoder for encoder-decoder models
                input_ids,
                target,
                logits_processor=logits_processor,
                max_gen_len=gen_len,
                end_tokens_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
output_ar = tokenizer.decode(output_ids_ar, skip_special_tokens=True)

# Generate text using the speculative decoding (faster)
output_ids_sd, alpha = speculative_generate( # or speculative_generate_encoder_decoder for encoder-decoder models
                input_ids,
                drafter,
                target,
                logits_processor=logits_processor,
                gamma=gamma,
                max_gen_len=gen_len,
                end_tokens_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
output_sd = tokenizer.decode(output_ids_sd, skip_special_tokens=True)

print("Auto-regressive decoding:", output_ar)
print("Speculative decoding:", output_sd)
print("Acceptance rate:", alpha) # Number of drafts accepted by the target model divided by the number of drafts generated
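
The `NucleusProcessor` used above is configured with a temperature and a `top_p` value. As a general illustration of what these two parameters do (a pure-Python sketch of the standard technique, not the repository's implementation; `nucleus_filter` is a name chosen for this example), logits are divided by the temperature, converted to probabilities, and only the smallest set of most likely tokens whose cumulative probability reaches `top_p` is kept.

```python
import math

def nucleus_filter(logits, temperature=1.0, top_p=1.0):
    """Return a probability distribution restricted to the top-p nucleus."""
    # Temperature scaling followed by a numerically stable softmax.
    scaled = [value / temperature for value in logits]
    highest = max(scaled)
    exps = [math.exp(value - highest) for value in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most likely tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the kept tokens; all other tokens get probability 0.
    mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / mass
    return filtered
```

Sampling from the returned distribution gives nucleus sampling; a lower temperature sharpens the distribution before the cut is applied.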

To use Beam Search Decoding, you can use the beam_search_generate function. The beam_search_generate function requires top_k (number of tokens to evaluate at each branch), num_beams (number of beams that run in parallel), min_length and alpha (for length penalty) hyperparameters.

from sampling import beam_search_generate # Beam Search Decoding is not compatible with encoder-decoder models yet.

output_ids_bs = beam_search_generate(
                input_ids,
                target,
                max_gen_len=gen_len,
                end_tokens_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                top_k=3,
                num_beams=5,
                min_length=5,
                alpha=1.2,
            )
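
To make the role of these hyperparameters concrete, here is a toy beam search in pure Python. It is a sketch under stated assumptions, not the code behind `beam_search_generate`: the `next_log_probs` callback stands in for the model, and the length penalty is assumed to divide the cumulative log-probability by the generated length raised to `alpha`. The repository's actual formula may differ.

```python
def beam_search(next_log_probs, start, eos, max_gen_len,
                num_beams=5, top_k=3, min_length=0, alpha=1.0):
    """Toy beam search; next_log_probs(seq) returns log-probs over the vocabulary."""
    def score(seq, log_prob):
        # Assumed length penalty: normalize by (generated length) ** alpha.
        generated = max(1, len(seq) - len(start))
        return log_prob / generated ** alpha

    beams = [(list(start), 0.0)]
    finished = []
    for _ in range(max_gen_len):
        candidates = []
        for seq, log_prob in beams:
            log_probs = next_log_probs(seq)
            ranked = sorted(range(len(log_probs)),
                            key=lambda t: log_probs[t], reverse=True)
            # Forbid the end token until min_length tokens were generated.
            if len(seq) - len(start) < min_length:
                ranked = [t for t in ranked if t != eos]
            # Expand each beam with its top_k most likely tokens.
            for token in ranked[:top_k]:
                candidates.append((seq + [token], log_prob + log_probs[token]))
        # Keep the num_beams best candidates; set finished ones aside.
        candidates.sort(key=lambda c: score(*c), reverse=True)
        beams = []
        for seq, log_prob in candidates[:num_beams]:
            (finished if seq[-1] == eos else beams).append((seq, log_prob))
        if not beams:
            break
    return max(finished + beams, key=lambda c: score(*c))[0]
```

With log-probabilities being negative, a larger `alpha` in this sketch favors longer sequences, while `alpha=0` ranks candidates by raw log-probability.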

2. Run inference from the console interface

You can run infer.py in your console to generate text using the console interface. You can easily change the hyperparameters of the generation, compare target and speculative generation, enable drafter generation and much more.

python infer.py

To change the models used, edit target_model_name and drafter_model_name in the infer.py file. If you use encoder-decoder models, make sure to switch the generation calls to their encoder-decoder variants.

Did you find a bug?

Please open an issue or submit a pull request if you find a bug. Contributions are welcome!

References

[1] Leviathan, Y., Kalman, M., & Matias, Y. (2023). Fast Inference from Transformers via Speculative Decoding. Proceedings of the 40th International Conference on Machine Learning, in Proceedings of Machine Learning Research 202:19274-19286. Available from https://proceedings.mlr.press/v202/leviathan23a.html.

[2] Chen, C., Borgeaud, S., Irving, G., Lespiau, J. B., Sifre, L., & Jumper, J. (2023). Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318.
