Exploration of Inference Speed of 🤗 Transformers Auto-Regressive Model

[Header image: Bing Image Creator, "cheetah running down a road in Pittsburgh, Nikon"]

With the 🤗 Transformers library and PyTorch, we have several readily available methods for model inference that can improve efficiency. Interestingly, we'll find that "efficiency" is a blanket term covering both memory use and wall-clock speed: some methods are more memory efficient but have no effect on inference speed, or even slow it down. In this blog we'll explore the inference speed of several popular techniques on the NVIDIA A10G GPU.

Experimental Setup

We'll use Mistral 7B, an auto-regressive model with 7 billion parameters, as our model for inference.

Because the model has not been fine-tuned for a task, it's not specifically trained to output the end-of-sequence token, meaning it will generally continue generating text until we cut it off. We won't concern ourselves with the quality of the generated text, only the inference speed of the model. Auto-regressive models are next-word predictors, and their speed of predicting the next token doesn't change based on the difficulty of the task. Because of this, we set a maximum length of 1000 tokens for all of our experiments: every configuration generates 1000 tokens, and we report performance as tokens generated per second. Inference speed is also affected by the decoding strategy; for this experiment we perform greedy decoding in all scenarios, meaning the model always chooses the token with the highest probability at each step. This is the fastest decoding strategy and produces a reproducible result for each generation.
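Concretely, the generation settings used in every run look like this (the same dictionary appears in the full script at the end of the post):

# Greedy decoding, always generating exactly 1000 new tokens
generation_config = {
    "num_beams": 1,          # no beam search
    "max_new_tokens": 1000,  # fixed generation length
    "do_sample": False,      # always pick the highest-probability token
}
# used as: outputs = model.generate(**inputs, **generation_config)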

Most experiments are performed on a single NVIDIA A10G GPU. We also experiment with a 4 NVIDIA A10G GPU cluster.

Because the computations take place on the GPU, using Python's time.time() as a timer leads to inaccurate results. We follow the steps described in this blog, which help ensure the following (a condensed sketch of the timing loop follows the list):

  • Warm up the GPU before timing inference
  • Use PyTorch CUDA events to synchronize the CPU and GPU when timing, for accurate results
  • Run each model configuration 50 times and take the average inference speed
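Here is the core timing pattern, a condensed version of the run_model function in the Code section; model, inputs, and generation_config are assumed to be loaded as in the full script:

import torch

starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

# Warm up the GPU with a few short generations before timing anything
for _ in range(10):
    _ = model.generate(**inputs, max_length=35)

timings_ms = []
for _ in range(50):
    starter.record()
    _ = model.generate(**inputs, **generation_config)
    ender.record()
    torch.cuda.synchronize()  # wait for the queued GPU work to actually finish
    timings_ms.append(starter.elapsed_time(ender))  # elapsed time in milliseconds

mean_ms = sum(timings_ms) / len(timings_ms)
tokens_per_second = generation_config["max_new_tokens"] / (mean_ms / 1000)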

Configurations

We run the test with the following configurations; in all cases the computation is performed on the GPU. A condensed sketch of how these options are enabled follows the list, and the complete script is in the Code section at the end.

  • Model with no optimizations
  • Model with torch.no_grad context manager, meaning that weight gradients are not tracked, since we are not training the model
  • Model with torch.compile, meaning that the model is JIT-compiled into optimized kernels
  • Model run with Flash Attention 2.0, a new attention mechanism that is more memory efficient (link)
  • Model run with Flash Attention 2.0 and torch.compile
  • Model run with 8-bit quantization via bitsandbytes
  • Model run with 4-bit quantization via bitsandbytes
  • Model run with 1 GPU using the HuggingFace text-generation-inference server
  • Model run with 1 GPU using the HuggingFace text-generation-inference server with custom kernels disabled
  • Model run with 4 GPUs and Tensor Parallelism using the HuggingFace text-generation-inference server
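Most of these options come down to a flag or two when loading the model, plus an optional torch.compile call. This is a condensed sketch of the main switches; the real script loads, benchmarks, and frees each variant in turn rather than keeping them all in memory at once:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "mistralai/Mistral-7B-v0.1"

# Baseline model with the standard ("eager") attention implementation
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", attn_implementation="eager"
).to("cuda")

# JIT-compile the model into optimized kernels
compiled_model = torch.compile(model, mode="max-autotune", backend="inductor")

# Flash Attention 2 (requires the flash-attn package to be installed)
flash_model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", attn_implementation="flash_attention_2"
).to("cuda")

# 4-bit quantization via bitsandbytes (device placement handled automatically)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    ),
)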

Results

Method                                        Tokens per second
Without torch.no_grad (gradients tracked)     25.97
With torch.no_grad context                    26.03
torch.compile                                 26.13
Flash Attention 2                             20.73
Flash Attention 2 + torch.compile             20.8
8-bit quantized (bitsandbytes)                4.84
4-bit quantized (bitsandbytes)                22.61
TGI server, 1 GPU                             29.95
TGI server, 4 GPUs, Tensor Parallelism        56.43

Several interesting results:

Using the torch.no_grad() context manager doesn't speed up inference

I was surprised to find that torch.no_grad doesn't improve the inference speed of the model with a batch size of 1. The no_grad() context disables gradient tracking, which certainly reduces memory overhead, but I expected it to also yield some speed improvement. The inference speed increases very slightly, but well within the margin of error. This PyTorch forum thread (https://discuss.pytorch.org/t/torch-no-grad-doesnt-result-in-speed-boost/72561) confirms the finding: since no backward pass is run during inference, skipping gradient tracking doesn't change the speed.
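For reference, the change amounts to nothing more than a context manager around the generation call; a minimal sketch, assuming model, inputs, and generation_config from the setup above:

import torch

# Both contexts skip autograd bookkeeping, which saves memory but, with no
# backward pass to avoid, does not change generation speed.
with torch.no_grad():
    _ = model.generate(**inputs, **generation_config)

# torch.inference_mode() is the slightly stricter variant used in the full script
with torch.inference_mode():
    _ = model.generate(**inputs, **generation_config)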

Quantized Models aren't faster

I was fascinated to find that quantized models are actually much slower at generation! It turns out they are more memory efficient, but not faster. This is a known limitation of bitsandbytes (issue here), but it is not clearly mentioned in the Transformers documentation. In my results the 4-bit model is about 13% slower than the unquantized baseline, while the 8-bit model is more than five times slower.

PyTorch Compilation isn't faster

Another finding was that PyTorch compilation doesn't speed up inference on the A10G. According to this documentation, the inference speed-up depends on the GPU. From what I found online, most benchmarking seems to be done on NVIDIA A100s, so it's possible that the A10G simply doesn't benefit from compilation.

Flash Attention isn't faster

Similar to torch.compile(), I struggled to find any benchmarking done for the A10G GPU. The benchmarking documented in the Flash Attention Repo only reports results on the H100 and A100. It's possible that the A10G doesn't benefit from Flash Attention in terms of text generation speed. It's almost certainly more memory efficient, but I didn't measure memory usage in this test.

The text-generation-inference server is faster, and Tensor Parallelism is fastest

The text-generation-inference server is maintained by Hugging Face and provides a streamlined Docker image for running inference on a model. Even when running on the same A10G GPU, it's faster than any of the methods that don't use the server.

Another unique feature of the text-generation-inference server is its support for Tensor Parallelism. This design splits the computation of each layer across multiple GPUs in order to increase speed. This differs from the default multi-GPU behavior of transformers (device_map="auto"), which places different layers on different GPUs so that larger models or batches fit in memory, but still executes the layers of each forward pass sequentially. It's a very interesting feature, and it results in an almost 2x speed-up over the single-GPU model!
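Outside of SageMaker, the same server can be started from the published Docker image and queried directly. Here is a minimal sketch using huggingface_hub's InferenceClient, assuming a text-generation-inference container is already serving Mistral 7B locally on port 8080:

from huggingface_hub import InferenceClient

# Assumes a TGI container is already running and listening on localhost:8080
client = InferenceClient("http://localhost:8080")

generated = client.text_generation(
    "The following is an extremely long 24 paragraph essay about the detailed plot of Romeo and Juliet:\n\n",
    max_new_tokens=1000,
    do_sample=False,  # greedy decoding, matching the local experiments
)
print(generated)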

Conclusion

After running these tests, there are a few things I'd like to look at in the future. First, now that we've established that most of these methods don't improve inference speed, do they improve memory usage? (A small sketch of how I'd start measuring that follows this paragraph.) Second, why is the text-generation-inference server so much faster? I need to look more at the internals of that repository to understand what it does that makes it faster than the other methods, even when using only one GPU (i.e., without Tensor Parallelism). Finally, I'd like to look more into Tensor Parallelism and how it works.
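As a starting point on the memory question, PyTorch's allocator statistics make it easy to record peak GPU memory around a single generation; a minimal sketch, assuming a loaded model and the inputs and generation_config from the script below:

import torch

# Reset the peak-memory counter, run one generation, and read the peak back
torch.cuda.reset_peak_memory_stats()
with torch.inference_mode():
    _ = model.generate(**inputs, **generation_config)
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak GPU memory during generation: {peak_gib:.2f} GiB")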

Thanks for following along!

Code

This is the code used to run the tests; I use AWS SageMaker for the text-generation-inference server.

import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import sagemaker
import numpy as np

# turn off transformers tqdm
from transformers.utils import logging
from transformers import BitsAndBytesConfig
import json
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

# Turn off logging and tqdm progress bars
logging.set_verbosity_error()
logging.disable_progress_bar()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_id = "mistralai/Mistral-7B-v0.1"

# LM doesn't have a trained EOS token so it will go to max_length
generation_config = {
    "num_beams": 1,
    "max_new_tokens": 1000,
    "do_sample": False,
}
# This is the text that will start the LM generation
text = "The following is a extremely long 24 paragraph essay about the detailed plot of Romeo and Juliet:\n\n"
# Tokenizer for all models
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
inputs = tokenizer(
    text,
    return_tensors="pt",
    return_attention_mask=False,
).to("cuda")


# Help from https://deci.ai/blog/measure-inference-time-deep-neural-networks/
def run_model(model, track_gradients=False):
    model.eval()
    # CUDA events give accurate timings for work queued on the GPU
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(
        enable_timing=True
    )
    repetitions = 50
    timings = np.zeros((repetitions, 1))
    # GPU-WARM-UP
    for _ in range(10):
        _ = model.generate(**inputs, max_length=35)

    # Track gradients only for the "worst case" baseline; otherwise use inference_mode
    grad_context = torch.enable_grad() if track_gradients else torch.inference_mode()
    with grad_context:
        for rep in range(repetitions):
            starter.record()
            outputs = model.generate(**inputs, **generation_config)
            ender.record()
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            assert len(outputs[0]) == generation_config["max_new_tokens"] + len(
                tokenizer(text)["input_ids"]
            )
            timings[rep] = curr_time

    mean_syn = np.sum(timings) / repetitions
    print(f"Average time: {mean_syn:0.2f} ms")
    tok_per_sec = generation_config["max_new_tokens"] / (mean_syn / 1000)
    print("Generate tok/s: ", round(tok_per_sec, 2))


# Worst case, GPU while tracking gradients
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="eager",
).to("cuda")
print("Running model while tracking gradients")

run_model(model, track_gradients=True)
model = None
torch.cuda.empty_cache()
# Baseline without gradient tracking (inference_mode inside run_model)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="eager",
).to("cuda")
print("Running model without tracking gradients")
run_model(model)
model = None
torch.cuda.empty_cache()


# Compile the model with torch compile
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="eager",
).to("cuda")
model = torch.compile(model, mode="max-autotune", backend="inductor")
print("Running compiled model without tracking gradients")
run_model(model)
model = None
torch.cuda.empty_cache()

# Load model and use flash attention
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
).to("cuda")
print("Running model with flash attention")
run_model(model)
model = None
torch.cuda.empty_cache()

# Compile the flash attention model with torch compile
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
).to("cuda")
model = torch.compile(model, mode="max-autotune", backend="inductor")
print("Running model with compiled flash attention")
run_model(model)
model = None
torch.cuda.empty_cache()

# Load an 8bit quantized model
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

# load model in 8bit
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
)
print("Running model with 8bit quantization")
run_model(model)
model = None
torch.cuda.empty_cache()

# load model in 4bit
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quant_config,
)
print("Running model with 4bit quantization")
run_model(model)
model = None
torch.cuda.empty_cache()

"""
Finished with local generation, now run generation on AWS Sagemaker with Text-Generation-Inference Server.
It's ok to use time.time() here because the server is running on a different machine and we don't have to worry about CPU/GPU sync
"""
# copied and edited from https://www.philschmid.de/sagemaker-deploy-mixtral
sess = sagemaker.Session()
sagemaker_session_bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

# Define Model and Endpoint configuration parameter
config = {
    "HF_MODEL_ID": model_id,  # model_id from hf.co/models
    "SM_NUM_GPUS": json.dumps(1),  # Number of GPU used per replica
    "MAX_INPUT_LENGTH": json.dumps(1000),  # Max length of input text
    "MAX_BATCH_PREFILL_TOKENS": json.dumps(
        2100
    ),  # Number of tokens for the prefill operation.
    "MAX_TOTAL_TOKENS": json.dumps(
        2100
    ),  # Max length of the generation (including input text)
    "MAX_BATCH_TOTAL_TOKENS": json.dumps(
        4000
    ),  # Limits the number of tokens that can be processed in parallel during the generation
}

# Retrieve the TGI container image and create the HuggingFaceModel
llm_image = get_huggingface_llm_image_uri("huggingface", version="1.3.3")
llm_model = HuggingFaceModel(role=role, image_uri=llm_image, env=config)

# Deploy model to an endpoint
# https://sagemaker.readthedocs.io/en/stable/api/inference/model.html#sagemaker.model.Model.deploy
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    container_startup_health_check_timeout=600,
)

# Generation arguments (TGI parameter names; greedy decoding, full text returned
# so the prompt can be stripped off by character length below)
tgi_parameters = {"max_new_tokens": 1000, "do_sample": False, "return_full_text": True}
start = time.time()
chat = llm.predict({"inputs": text, "parameters": tgi_parameters})
end = time.time()
# Count only the newly generated tokens (the prompt is stripped by character length)
generated_text = chat[0]["generated_text"][len(text) :]
num_new_tokens = len(tokenizer(generated_text)["input_ids"])
tokens_per_sec_sagemaker_1 = num_new_tokens / (end - start)
print("Generate tok/s of sagemaker model: ", round(tokens_per_sec_sagemaker_1, 2))
llm.delete_model()
llm.delete_endpoint()

config = {
    "HF_MODEL_ID": model_id,  # model_id from hf.co/models
    "SM_NUM_GPUS": json.dumps(4),  # Number of GPU used per replica
    "MAX_INPUT_LENGTH": json.dumps(1000),  # Max length of input text
    "MAX_BATCH_PREFILL_TOKENS": json.dumps(
        2100
    ),  # Number of tokens for the prefill operation.
    "MAX_TOTAL_TOKENS": json.dumps(
        2100
    ),  # Max length of the generation (including input text)
    "MAX_BATCH_TOTAL_TOKENS": json.dumps(
        4000
    ),  # Limits the number of tokens that can be processed in parallel during the generation
}
llm_model = HuggingFaceModel(role=role, image_uri=llm_image, env=config)
llm = llm_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",
    container_startup_health_check_timeout=600,
)

# Generation arguments (same TGI parameters as the single-GPU run)
start = time.time()
chat = llm.predict({"inputs": text, "parameters": tgi_parameters})
end = time.time()
generated_text = chat[0]["generated_text"][len(text) :]
num_new_tokens = len(tokenizer(generated_text)["input_ids"])
tokens_per_sec_sagemaker_2 = num_new_tokens / (end - start)
print("Generate tok/s of sagemaker model: ", round(tokens_per_sec_sagemaker_2, 2))
llm.delete_model()
llm.delete_endpoint()