Considerations with LLM Inference - A cached story!
If you thought having a GPU gives you unlimited power, think again
Overview
Recently, my team and I were deploying our own LLM service built on a pretrained, quantized large language model. We had GPU memory constraints, so it became important to understand the whole decoding ecosystem well enough to set the relevant parameter values for our use case.
High-Level View of Decoder Inference
Text generation in LLM inference happens in two phases: the first is the initiation (prefill) phase and the second is the generation (decoding) phase.
Above is a high-level diagram of a simple decoder-only architecture. The input is a sequence of token IDs, which are mapped to their token embeddings. The decoder produces logits, from which a decoding strategy selects the next token; the strategy's behavior is controlled by parameters such as temperature, top_p, top_k and frequency penalty. Familiar decoding algorithms such as beam search and contrastive decoding are applied at this step. "Token generation strategy", "decoding strategy" and "search strategy" are all synonyms.
General Steps in Inference Decoding
Text generation using language models like GPT involves tokenizing the input prompt and passing it through the model to generate the initial token.
This first step of generating the initial token is known as the initiation or pre-fill phase.
After initiation, the model enters the generation or decoding phase, where it iteratively appends the previously generated token to the input sequence and uses it to predict the next token.
This process continues until a stop sequence is encountered or a maximum length is reached, resulting in the complete generated text.
The computations involved in both phases are similar, with the model performing a forward pass on an incrementally growing sequence.
However, these computations are expensive: attention scales quadratically with sequence length, and re-running the forward pass over the whole sequence at every step repeats calculations that were already done in earlier steps.
To avoid this, an optimization called KV caching caches and reuses these computations, introducing a critical difference between the initiation and decoding phases.
Advanced techniques like speculative sampling and lookahead decoding may deviate from this basic algorithm for improved efficiency and lower latency.
Understanding these phases and optimizations is crucial for efficient text generation using large language models.
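The steps above can be sketched as a plain loop. This is a minimal sketch, not a real inference engine: `next_token` is a hypothetical stand-in for a model forward pass plus a decoding strategy, and `toy_model` is a made-up rule chosen so the output is easy to trace. Note that without a KV cache, every call receives the whole, ever-growing sequence.

```python
from typing import Callable, List

# Hypothetical model interface: given the full sequence so far, return the next token id.
NextTokenFn = Callable[[List[int]], int]


def generate(next_token: NextTokenFn, prompt_ids: List[int],
             max_new_tokens: int, stop_id: int) -> List[int]:
    """Naive two-phase generation loop (no KV cache)."""
    seq = list(prompt_ids)
    generated: List[int] = []
    # Initiation / prefill: the whole prompt goes through the model to get the first token.
    token = next_token(seq)
    while True:
        generated.append(token)
        seq.append(token)  # the new token becomes part of the input
        if token == stop_id or len(generated) >= max_new_tokens:
            break
        # Generation / decoding: a forward pass over the grown sequence gives the next token.
        token = next_token(seq)
    return generated


def toy_model(seq: List[int]) -> int:
    """Made-up 'model' that predicts (last token + 1) mod 10."""
    return (seq[-1] + 1) % 10


print(generate(toy_model, [3, 4], max_new_tokens=5, stop_id=8))  # [5, 6, 7, 8]
```

Generation stops either on the stop token (8 here) or when `max_new_tokens` is reached, matching the two stopping conditions described above.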
KV Caching
Alright, let's do a quick recap of the Transformer's attention layer, which we will need in this section.
Multi-Head Attention (MHA)
Multi-Head Attention is a key component of the Transformer architecture, which powers many state-of-the-art language models like BERT, GPT, and others. It is a mechanism that allows the model to attend to different parts of the input sequence in parallel, capturing long-range dependencies and different representations simultaneously.
The main idea behind MHA is to project the input queries, keys, and values into multiple subspaces (or heads) and perform the attention computation in each of these subspaces. The outputs from all the heads are then concatenated and linearly transformed to produce the final attended representation.
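To make the split, attend and concatenate idea concrete, here is a minimal pure-Python sketch of causal multi-head self-attention. It is an illustration under simplifying assumptions (random weights, no biases, no batching, a single layer); the function names are hypothetical and not taken from any library.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain (n, p) @ (p, m) matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # exp(-inf) == 0.0 removes masked positions
    total = sum(exps)
    return [e / total for e in exps]


def causal_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention where position i only attends to positions <= i."""
    scale = 1 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]  # K^T, shape (d_head, t)
    scores = matmul(q, k_t)               # Q @ K^T, shape (t, t)
    weights = [softmax([s * scale if j <= i else -math.inf for j, s in enumerate(row)])
               for i, row in enumerate(scores)]
    return matmul(weights, v)


def multi_head_attention(x: Matrix, wq: Matrix, wk: Matrix, wv: Matrix,
                         wo: Matrix, n_head: int) -> Matrix:
    d_head = len(x[0]) // n_head
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    heads = []
    for h in range(n_head):
        cols = slice(h * d_head, (h + 1) * d_head)  # this head's subspace
        heads.append(causal_attention([r[cols] for r in q], [r[cols] for r in k],
                                      [r[cols] for r in v]))
    # Concatenate the heads along the feature axis, then apply the output projection.
    concat = [[val for head in heads for val in head[i]] for i in range(len(x))]
    return matmul(concat, wo)


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
t, d_model, n_head = 3, 4, 2
x = rand_matrix(t, d_model)
wq, wk, wv, wo = (rand_matrix(d_model, d_model) for _ in range(4))
out = multi_head_attention(x, wq, wk, wv, wo, n_head)
print(len(out), len(out[0]))  # 3 4
```

Each head only sees its own `d_head`-wide slice of the projected queries, keys and values, which is what lets the heads learn different attention patterns.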
This multi-head approach offers several advantages over single-head attention:
1. It enables the model to attend to different positional relationships and representations within the input sequence simultaneously.
2. It provides a richer and more expressive attention mechanism, as each head can potentially learn different attention patterns.
3. It allows for more efficient parallel computation, as the attention computations across different heads can be executed in parallel.
The multi-head attention mechanism is highly parameterized and learns to focus on different parts of the input through training on large datasets, enabling the model to capture intricate relationships and dependencies within the data.
There are many resources online that explain how attention scores are computed. I assume the reader is already familiar with this; if not, have a look at this resource.
Let's look at the compute complexity and how quickly it blows up:
Let's consider a single head and a single sequence in a batch of size b, with a total length t (including the prompt and generated completions). The query tensor has a shape of (t, d_head), and the key tensor has a shape of (t, d_head).
To compute the attention scores, we perform a matrix multiplication between the query tensor and the transposed key tensor:
Attention_scores = Q * K^T
The shapes of the matrices being multiplied are (t, d_head) and (d_head, t).
Now, the number of FLOPs (floating-point operations) required for a matrix multiplication between two matrices of shapes (n, p) and (p, m) is approximately 2 * n * p * m.
In our case, where the shapes are (t, d_head) and (d_head, t), the number of FLOPs required for computing the attention scores for a single head and a single sequence is:
FLOPs = 2 * t * d_head * t = 2 * d_head * t^2
This operation needs to be repeated for each head (n_head) and each layer (n_layers) in the Transformer model. Additionally, we need to consider the batch size b.
Therefore, the total number of FLOPs required for computing attention scores across all heads, layers, and sequences in a batch is:
Total FLOPs = 2 * b * n_layers * n_head * d_head * t^2
= 2 * b * n_layers * d_model * t^2
Here, d_model is the dimensionality of the model, which is equal to n_head * d_head.
As you can see, the computation scales quadratically with the total sequence length t, which includes both the prompt and the generated completions.
Example:
Let's consider a batch size b=2, with a Transformer model having n_layers=12, n_head=12, and d_head=64 (resulting in d_model=768). If the total sequence length t=512, the total number of FLOPs required for computing attention scores would be:
Total FLOPs = 2 * 2 * 12 * 12 * 64 * 512^2
= 9,663,676,416 (about 9.7 GFLOPs)
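This arithmetic is easy to get wrong by hand, so a small helper is handy. It is a sketch that simply encodes the formula above and, like the text, counts only the Q * K^T products; the `kv_cache` flag is a hypothetical addition that switches to the single-query case used once keys and values are cached.

```python
def attention_score_flops(b: int, n_layers: int, n_head: int, d_head: int, t: int,
                          kv_cache: bool = False) -> int:
    """Approximate FLOPs for the Q @ K^T products at one generation step.

    Without a cache the query covers all t positions; with a KV cache it covers
    only the newest token, so one factor of t drops out.
    """
    query_len = 1 if kv_cache else t
    return 2 * b * n_layers * n_head * d_head * query_len * t


print(f"{attention_score_flops(b=2, n_layers=12, n_head=12, d_head=64, t=512):,}")
# 9,663,676,416

# Without a cache, the step producing the 1,001st token (t=1000) costs 100x the 101st (t=100).
ratio = attention_score_flops(1, 1, 1, 1, 1000) // attention_score_flops(1, 1, 1, 1, 100)
print(ratio)  # 100
```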
This quadratic scaling with sequence length highlights the importance of techniques like KV caching, which can help reduce redundant computations and improve the efficiency of the attention mechanism, especially for long sequences.
To understand the severity of quadratic scaling, let's look at an example. Without caching, generating the 1,001st token requires 100x more attention-score FLOPs than generating the 101st token: the sequence is 10x longer, and (1000/100)^2 = 100. This quadratic growth in compute quickly becomes prohibitive. Fortunately for us, and thanks to causal masking, many computations can actually be reused between steps.
Let's Understand the Intuition Behind KV Caching
Let's consider an example to explain the concept of KV caching and the difference between the initiation and decoding phases in the context of text generation with a Transformer-based language model.
Suppose we have a prompt: "The quick brown fox" and we want to generate the next word using the language model. Here's how the process would unfold:
1. Initiation Phase:
The prompt "The quick brown fox" is tokenized and passed through the Transformer model. The model generates the output representations for each token in the prompt, considering the context of all the previous tokens.
The output representation for the last token "fox" is used to predict the next token, let's say "jumps".
2. Decoding Phase:
In the next iteration, the input sequence becomes "The quick brown fox jumps". Instead of processing the entire sequence again, the model relies on KV caching. The key (K) and value (V) vectors for the tokens "The", "quick", "brown" and "fox" are already cached from the previous step, so only the new token "jumps" passes through the model: its query (Q), key and value vectors are computed, and its key and value are appended to the cache. The attention mechanism combines the new Q vector for "jumps" with all the cached K and V vectors to compute the output representation for "jumps", which is then used to predict the next token in the sequence.
The key advantage of KV caching is that it avoids redundant computation by reusing previously computed key and value vectors. During the decoding phase, the model processes only the new token and reads the keys and values of all earlier tokens from the cache, instead of processing the entire input sequence again.
This optimization becomes increasingly important as the sequence length grows, because the attention mechanism has quadratic complexity with respect to the sequence length. By caching and reusing the key and value vectors, the computational cost is significantly reduced, especially for long sequences.
It's important to note that during the initiation phase there are no previously computed key and value vectors to reuse, so the model has to process the entire prompt, computing the keys and values of every prompt token and storing them in the cache. In the decoding phase, however, KV caching introduces a fundamental difference in how the computations are performed, leading to improved efficiency and reduced redundancy.
At each decoding step, only the newest token's query is used in the attention computation, while the keys and values of all previous time steps are read from the cache. This reduces the number of dot products and matrix multiplications, and because it is the keys and values that are cached, the technique is called the KV cache.
The transposed key tensor still has shape (d_head, t). However, the query tensor now has shape (1, d_head), since only the newest token's query is needed. Single-head, single-sequence attention score computation therefore requires 2 * d_head * t FLOPs, and overall, attention score computations require 2 * b * n_layers * d_model * t FLOPs per generated token. Attention now scales linearly with the total sequence length!
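The claim that reusing cached keys and values gives the same result as recomputing everything can be checked with a small single-head sketch. Assumptions: pure Python, one head, one sequence, random weights, and tokens fed one at a time rather than the prompt going through a single prefill pass; `KVCache`, `decode_step` and `full_recompute` are hypothetical names.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def project(x: Vector, w: Matrix) -> Vector:
    """x @ w for a single token vector x and a (d_in, d_out) weight matrix w."""
    return [sum(xi * wij for xi, wij in zip(x, col)) for col in zip(*w)]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """One query against t keys: the scores cost ~2 * d_head * t FLOPs."""
    scale = 1 / math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


class KVCache:
    """Keys and values of all tokens seen so far (one head, one sequence)."""

    def __init__(self) -> None:
        self.keys: List[Vector] = []
        self.values: List[Vector] = []


def decode_step(x: Vector, cache: KVCache, wq: Matrix, wk: Matrix, wv: Matrix) -> Vector:
    # Only the new token is projected; its key and value are appended to the cache.
    cache.keys.append(project(x, wk))
    cache.values.append(project(x, wv))
    return attend(project(x, wq), cache.keys, cache.values)


def full_recompute(xs: List[Vector], wq: Matrix, wk: Matrix, wv: Matrix) -> Vector:
    # No cache: every token's key and value are recomputed at every step.
    keys = [project(x, wk) for x in xs]
    values = [project(x, wv) for x in xs]
    return attend(project(xs[-1], wq), keys, values)


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_head = 4
wq, wk, wv = (rand_matrix(d_head, d_head) for _ in range(3))
tokens = rand_matrix(6, d_head)  # six made-up token embeddings
cache = KVCache()
for step, x in enumerate(tokens, start=1):
    cached = decode_step(x, cache, wq, wk, wv)
    reference = full_recompute(tokens[:step], wq, wk, wv)
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached, reference))
print("cached decoding matches full recomputation")
```

With the cache, each step projects one token and runs one query against t cached keys; the uncached path re-projects all t tokens every time, which is exactly the redundant work KV caching removes.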
Memory Size of KV Cache
Consider again generating text after the prompt "The quick brown fox". During the initiation phase, the model processes the entire prompt and stores the key and value tensors of each token in the cache. In the decoding phase, when the newly generated token (e.g., "jumps") is fed back in, the model computes the query, key and value tensors for "jumps" only, appends its key and value to the cache, and fetches the cached key and value tensors of the previous tokens ("The", "quick", "brown", "fox"). This avoids recomputing these tensors and reduces the computational cost, but every cached tensor occupies GPU memory.
The KV cache size grows linearly with the batch size and total sequence length (prompt + generated text). For a batch size b, sequence length t, n_layers layers, n_heads attention heads, d_head head dimension, and precision p_a, the total KV cache size (in bytes) is:
b * t * n_layers * n_heads * (2 * d_head * p_a)
While KV caching improves efficiency, it trades memory for computation. The cache size grows with sequence length, posing memory management challenges, especially for long sequences with unknown lengths during inference.
Suppose we have a language model with the following configuration:
Batch size (b) = 4
Total sequence length (t) = 1024 (prompt + generated text)
Number of layers (n_layers) = 24
Number of attention heads per layer (n_heads) = 16
Dimension of each attention head (d_head) = 64
Precision (p_a) = 2 bytes (half-precision, FP16)
Using the formula for the total KV cache size:
Total KV cache size (bytes) = b * t * n_layers * n_heads * (2 * d_head * p_a)
Total KV cache size (bytes) = 4 * 1024 * 24 * 16 * (2 * 64 * 2)
= 1,572,864 * 256
= 402,653,184 bytes
= 384 MiB
To put this into perspective, modern GPUs like the NVIDIA A100 have up to 80 GB of memory. However, the KV cache size grows linearly with the batch size and sequence length, so for larger batches or longer sequences, the memory requirements can quickly become substantial.
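A one-line helper makes it easy to redo this estimate for other configurations. This is a sketch of the formula above, nothing more; the Llama-2-7B numbers (32 layers, 32 attention heads, head dimension 128) come from its published model configuration, and the example assumes an FP16 cache.

```python
def kv_cache_bytes(b: int, t: int, n_layers: int, n_heads: int, d_head: int, p_a: int) -> int:
    """Total KV cache size in bytes; the factor 2 covers one key plus one value."""
    return b * t * n_layers * n_heads * (2 * d_head * p_a)


MIB = 2 ** 20
example = kv_cache_bytes(b=4, t=1024, n_layers=24, n_heads=16, d_head=64, p_a=2)
print(f"{example:,} bytes = {example / MIB:.0f} MiB")  # 402,653,184 bytes = 384 MiB

# Per-token cost for Llama-2-7B (32 layers, 32 heads, d_head=128, FP16).
per_token = kv_cache_bytes(b=1, t=1, n_layers=32, n_heads=32, d_head=128, p_a=2)
print(f"{per_token / MIB} MiB per token")  # 0.5 MiB per token
```

Doubling either `b` or `t` doubles the result, which is the linear growth described above.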
How to Manage Your KV Cache
Let's look at some knobs that can be tweaked and how they affect the KV cache.
Once the input prompt has been processed, i.e. at the end of the prefill phase, we have already consumed both GPU memory (to store the key and value tensors of each input token) and compute (to pass the prompt tokens through the model). Let's look at some real numbers. Assuming the total FLOPs count of the forward pass of a P-parameter model is approximately 2 * P FLOPs/token [5], processing a prompt using Llama-2-7B consumes ~0.5 MB/token of GPU memory (cf. above) and ~14 GFLOPs/token of GPU compute. For a 1000-token prompt (a bit less than a two-pager), that's ~500 MB of memory and 14 TFLOPs of compute, and we have not generated anything yet. - Source reference
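The numbers in this quote follow from two back-of-the-envelope formulas, sketched below. Assumptions: P is rounded to 7e9, the 2 * P FLOPs/token rule is taken at face value (attention-score FLOPs are ignored), and 524,288 bytes per token is what the KV cache size formula gives for Llama-2-7B (32 layers, 32 heads, d_head = 128, FP16). `prefill_cost` is a hypothetical helper name.

```python
def prefill_cost(n_params: float, prompt_tokens: int, kv_bytes_per_token: int):
    """Rough prefill cost: ~2 * P FLOPs per token, plus the KV cache the prompt fills."""
    flops = 2 * n_params * prompt_tokens
    kv_bytes = kv_bytes_per_token * prompt_tokens
    return flops, kv_bytes


# Llama-2-7B, FP16 KV cache: 524,288 bytes (0.5 MiB) per token.
flops, kv_bytes = prefill_cost(n_params=7e9, prompt_tokens=1000, kv_bytes_per_token=524_288)
print(f"{flops / 1e12:.0f} TFLOPs, {kv_bytes / 2**20:.0f} MiB")  # 14 TFLOPs, 500 MiB
```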
Multi-Query Attention vs Multi-Head Attention vs Grouped-Query Attention
Multi-Head Attention (MHA): This is the standard attention mechanism used in Transformers, where multiple attention "heads" are computed in parallel. Each head projects the input into separate query, key, and value spaces, computes the attention scores and output, and the results from all heads are concatenated and projected to form the final output. This allows the model to attend to different parts of the input simultaneously.
Multi-Query Attention (MQA): In MQA, all query heads share the same single set of key and value heads. In other words, all query heads compute their attention scores using the same set of keys, and all head outputs are computed using the same set of values (but with different attention scores). This reduces the memory footprint of the KV cache since there is only one set of keys and values instead of multiple sets for each head.
Grouped-Query Attention (GQA): GQA is a compromise between MHA and MQA. Instead of having all query heads share the same unique set of key and value heads (like MQA), the query heads are split into groups of size `g`. Query heads within the same group share the same unique set of key and value heads. This means that instead of having `n_heads` sets of key and value heads (like MHA) or just one set (like MQA), there are `n_heads/g` sets of key and value heads.
The key differences between these attention mechanisms are:
MHA has the highest memory footprint for the KV cache but preserves the full model capacity.
MQA has the lowest memory footprint for the KV cache but can potentially lead to a significant loss in model capacity, especially for larger models with many attention heads.
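GQA provides a trade-off between memory footprint and model capacity by allowing a variable number of key and value head groups. With `g` as the group size defined above, `g=1` gives every query head its own key and value heads, which is equivalent to MHA, and `g=n_heads` puts all query heads into a single group, which is equivalent to MQA.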
So, in summary, MHA is the standard approach, MQA is the most aggressive in reducing KV cache memory but may significantly impact accuracy, and GQA provides a tunable middle ground between the two extremes.
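The only mechanical difference between the three variants is which key/value head each query head reads. Below is a minimal single-token decoding sketch, assuming `g` is the group size as defined above (so `n_heads // g` KV heads are cached). `grouped_query_decode` is a hypothetical name, and real implementations typically broadcast or repeat the KV heads in batched tensor code rather than looping in Python.

```python
import math
from typing import List

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Single-query scaled dot-product attention over a cached sequence."""
    scale = 1 / math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def grouped_query_decode(queries: List[Vector], keys: List[List[Vector]],
                         values: List[List[Vector]]) -> List[Vector]:
    """queries: one vector per query head; keys/values: one cached (t, d_head) list per KV head."""
    n_heads, n_kv_heads = len(queries), len(keys)
    if n_heads % n_kv_heads:
        raise ValueError("n_heads must be divisible by the number of KV heads")
    g = n_heads // n_kv_heads  # group size: query heads sharing one KV head
    return [attend(q, keys[h // g], values[h // g]) for h, q in enumerate(queries)]


# KV heads (and hence KV cache size) for a 32-head model at different group sizes g.
n_heads = 32
for name, g in [("MHA", 1), ("GQA", 4), ("MQA", n_heads)]:
    print(name, "g =", g, "->", n_heads // g, "KV heads")
```

Because only `n_heads // g` key and value heads are cached, the KV cache size formula from the memory section shrinks by the same factor `g`.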
PagedAttention - Rescuing Wasted Memory
The main inspiration behind PagedAttention was to address the inefficient memory utilization and fragmentation issues that arise when naively allocating memory for the key-value (KV) cache in large language models during inference.
Specifically, the key motivations for developing PagedAttention were:
Internal Memory Fragmentation: When reserving a contiguous memory chunk to store the KV cache for a request, a significant portion of this allocation may never be used if the actual sequence length is shorter than the maximum anticipated length. This unutilized memory is wasted and cannot be used for other requests.
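External Memory Fragmentation: Requests reserve contiguous chunks of different sizes, so as chunks are allocated and freed, free memory gets split into gaps that may be too small for new requests. Moreover, even if the sequence length is known in advance, memory is consumed gradually during generation, so the part of a reservation that has not been filled yet cannot be used by other, shorter requests in the meantime.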
Duplication of KV Entries: In decoding strategies like beam search, where multiple candidate sequences are generated per request, there is often redundancy in the KV cache entries across different candidates, leading to further memory waste.
To overcome these inefficiencies, the key ideas behind PagedAttention were:
Fixed-size Blocks: Instead of allocating a large contiguous chunk, PagedAttention allocates fixed-size and relatively small memory blocks called "pages" to store the KV cache. Each page can contain a fixed number of tokens.
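Non-contiguous, Shared Blocks: A request's blocks do not need to be contiguous in memory; a block table maps the request's logical blocks to physical ones. Internal fragmentation is therefore limited to the unused slots in each sequence's last block, and blocks holding identical KV entries (e.g. a prompt shared by several beam search candidates) can be shared across sequences instead of duplicated.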
On-demand Allocation: Blocks are allocated on-demand as the generation process progresses, eliminating the need for upfront allocation based on maximum sequence length estimates.
By using these techniques, PagedAttention achieves near-zero memory waste (typically less than 4%), allowing more requests to fit in the available memory and thereby increasing throughput. The small block size and on-demand allocation also help mitigate external fragmentation issues.
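A toy block allocator shows the bookkeeping behind fixed-size blocks and on-demand allocation. It is a hypothetical sketch, not vLLM's code: it only tracks which physical block ids each request's block table points to, and it leaves out block sharing, copy-on-write and the attention kernel itself.

```python
from typing import Dict, List


class BlockAllocator:
    """Toy paged KV cache bookkeeping: fixed-size blocks handed out on demand."""

    def __init__(self, num_blocks: int, block_size: int) -> None:
        self.block_size = block_size                  # tokens per block
        self.free: List[int] = list(range(num_blocks))
        self.block_tables: Dict[str, List[int]] = {}  # request id -> physical block ids
        self.lengths: Dict[str, int] = {}             # request id -> tokens stored

    def append_token(self, req: str) -> None:
        n = self.lengths.get(req, 0)
        table = self.block_tables.setdefault(req, [])
        if n % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free:
                raise MemoryError("no free KV cache blocks")
            table.append(self.free.pop())
        self.lengths[req] = n + 1

    def release(self, req: str) -> None:
        """Return a finished request's blocks to the free pool."""
        self.free.extend(self.block_tables.pop(req, []))
        self.lengths.pop(req, None)

    def wasted_slots(self, req: str) -> int:
        """Unused token slots: at most block_size - 1, all in the last block."""
        return len(self.block_tables[req]) * self.block_size - self.lengths[req]


alloc = BlockAllocator(num_blocks=8, block_size=4)
for _ in range(6):
    alloc.append_token("a")
for _ in range(3):
    alloc.append_token("b")
print(alloc.block_tables)  # {'a': [7, 6], 'b': [5]}
print(alloc.wasted_slots("a"), alloc.wasted_slots("b"))  # 2 1
```

Blocks freed by one request go straight back to the pool, and since a request's blocks need not be adjacent, any free block can serve any request.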
PagedAttention was first implemented by the vLLM inference system but is now supported by all the major inference frameworks (e.g. HuggingFace TGI, NVIDIA TensorRT-LLM, LMDeploy TurboMind, etc.).
Add Some Salt with RadixAttention
RadixAttention: RadixAttention is a technique that enables efficient reuse of the key-value (KV) cache across different inference requests, particularly in scenarios where multiple requests share a common prefix, such as in multi-turn conversations or when using prompt templates.
The main idea behind RadixAttention is to store the KV cache tensors in GPU memory after completing a request and map them to the corresponding token sequence using a data structure called a radix tree. When a new request arrives, the scheduler checks for prefix matches in the radix tree and, if found, reuses the cached KV tensors instead of recomputing them from scratch.
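The prefix lookup at the heart of this idea can be sketched with a token-level trie. Simplifying assumptions: a real radix tree compresses chains of single-child nodes into one edge, and a production cache also needs eviction and reference counting, none of which is modeled here; the strings standing in for KV tensors and all class names are hypothetical.

```python
from typing import Dict, List, Optional, Tuple


class Node:
    def __init__(self) -> None:
        self.children: Dict[int, "Node"] = {}
        self.kv: Optional[str] = None  # placeholder for this token's cached K/V tensors


class PrefixCache:
    """Token-level trie mapping token prefixes to cached KV entries."""

    def __init__(self) -> None:
        self.root = Node()

    def insert(self, tokens: List[int], kv_entries: List[str]) -> None:
        node = self.root
        for tok, kv in zip(tokens, kv_entries):
            if tok not in node.children:
                node.children[tok] = Node()
            node = node.children[tok]
            node.kv = kv

    def match_prefix(self, tokens: List[int]) -> Tuple[int, List[str]]:
        """Return the longest cached prefix length and the KV entries to reuse."""
        node, reused = self.root, []
        for tok in tokens:
            child = node.children.get(tok)
            if child is None:
                break
            reused.append(child.kv)
            node = child
        return len(reused), reused


cache = PrefixCache()
system_prompt = [1, 2, 3]  # token ids of a shared system prompt / template
first_request = system_prompt + [10, 11]
cache.insert(first_request, [f"kv{t}" for t in first_request])

n, kv = cache.match_prefix(system_prompt + [20, 21])
print(n, kv)  # 3 ['kv1', 'kv2', 'kv3'] -> only tokens 20 and 21 need prefill
```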
RadixAttention and PagedAttention are complementary techniques that operate at different levels:
PagedAttention is a memory management technique that addresses inefficiencies in allocating and utilizing memory for the KV cache within a single request. It focuses on reducing internal and external memory fragmentation by using fixed-size blocks (pages) and on-demand allocation.
RadixAttention, on the other hand, operates at the model server level and focuses on reusing the KV cache across multiple requests, leveraging the common prefixes that often exist in scenarios like multi-turn conversations or prompt templates.
Surface Taps to open and let tokens flow
Temperature Parameter: The temperature parameter controls how peaked or flat the probability distribution over the next token is. Raising it spreads the probability mass more uniformly, allowing less likely tokens to surface as candidates. Temperature is a crucial inference-time (decoding) hyperparameter for large language models (LLMs) like GPT-3: it controls the randomness and creativity of the generated text by adjusting the probabilities associated with each token in the vocabulary. These probabilities are produced in the last layer of the model by the softmax function, which transforms logits (raw output values) into probabilities.
The temperature hyperparameter modifies the logits before the softmax is applied, by dividing them by the temperature T. A lower temperature (T < 1) pushes the probabilities toward extreme values, making the model more deterministic and less random. A higher temperature (T > 1) results in more evenly distributed probabilities, increasing randomness and creativity in the generated text.
By tuning the temperature, researchers and developers can strike a balance between coherence and creativity, so that the LLM produces output that meets specific requirements. Because it changes only how tokens are sampled, not the model's weights, temperature is an essential tool for controlling the characteristics of the generated text.
As the image above shows, as the temperature increases, the resulting probabilities become more uniformly distributed.
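The rule "divide the logits by T before the softmax" takes only a few lines. A minimal sketch with made-up logits for a three-token vocabulary; running it shows the top probability shrinking and the distribution flattening as T grows. `softmax_with_temperature` is a hypothetical helper name.

```python
import math
from typing import List


def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:
    """Divide the logits by T, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0 (T -> 0 approaches greedy argmax)")
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]  # made-up logits for a three-token vocabulary
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```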
Top_p and Top_k: These are selection strategies that restrict sampling to the most probable tokens, based either on cumulative probability (top_p) or on a fixed count (top_k). Setting top_p=0.95 means sampling only from the smallest set of most probable tokens whose cumulative probability reaches 0.95, and discarding the remaining tail of roughly 0.05. Similarly, top_k=3 means sampling only from the 3 most probable tokens.
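Both filters can be sketched in a few lines over an already-computed probability list. Assumptions: top-k is applied before top-p, and top-p is measured on the renormalized top-k mass, which is one common ordering; libraries differ in such details, and the function names here are hypothetical.

```python
import random
from typing import Dict, List, Optional


def top_k_top_p_filter(probs: List[float], top_k: Optional[int] = None,
                       top_p: Optional[float] = None) -> Dict[int, float]:
    """Return the renormalized {token_id: prob} that survives top-k, then top-p."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]  # keep the k most probable tokens
    mass = sum(probs[i] for i in ranked)
    if top_p is not None:
        kept, cumulative = [], 0.0
        for i in ranked:
            kept.append(i)
            cumulative += probs[i] / mass
            if cumulative >= top_p:  # smallest set whose mass reaches top_p
                break
        ranked = kept
    total = sum(probs[i] for i in ranked)
    return {i: probs[i] / total for i in ranked}


def sample(dist: Dict[int, float], rng: random.Random) -> int:
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


probs = [0.5, 0.2, 0.15, 0.1, 0.05]  # made-up next-token distribution
print(top_k_top_p_filter(probs, top_k=3))    # keys 0, 1, 2 survive
print(top_k_top_p_filter(probs, top_p=0.9))  # keys 0..3 survive (0.5+0.2+0.15+0.1 >= 0.9)
print(sample(top_k_top_p_filter(probs, top_k=3, top_p=0.9), random.Random(0)))
```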
Frequency penalty: The frequency penalty hyperparameter lowers the probability of a token based on how often it has already appeared in the generated output. The higher the frequency penalty, the less likely a word is to reappear, and the more times a token has appeared, the more it is penalized. This is useful when you want variety while still staying on a specific topic: a higher frequency penalty tends to produce fewer repeated words and greater diversity. In the OpenAI API, for example, the frequency penalty ranges from -2 to 2, with a default of 0 (negative values encourage repetition), and reasonable values are around 0.1 to 1.
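One way to implement this is to subtract a penalty proportional to each token's count from its logit before sampling, which is how OpenAI documents its `frequency_penalty` parameter (alongside a separate presence penalty, not shown here). A minimal sketch with a hypothetical helper name; other frameworks may implement repetition penalties differently.

```python
from collections import Counter
from typing import List


def apply_frequency_penalty(logits: List[float], generated: List[int],
                            penalty: float) -> List[float]:
    """Lower each token's logit by penalty * (times it already appeared in the output)."""
    counts = Counter(generated)
    return [logit - penalty * counts[token] for token, logit in enumerate(logits)]


logits = [2.0, 1.5, 1.0]
# Token 0 appeared twice and token 1 once, so they are pushed down; token 2 is untouched.
print(apply_frequency_penalty(logits, generated=[0, 0, 1], penalty=0.5))  # [1.0, 1.0, 1.0]
```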
vLLM vs llama.cpp - Which Ecosystem to Choose
https://www.reddit.com/r/LocalLLaMA/comments/18g21af/vllm_vs_llamacpp/
Ending Note
When making memory-aware decisions for LLM inference, it is crucial to consider the trade-offs between memory footprint, computational efficiency and model accuracy. Techniques like Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) can significantly reduce the memory footprint of the key-value (KV) cache, but they may come at the cost of reduced model capacity and potential accuracy loss, especially for larger models with many attention heads. Quantization techniques, such as LLM.int8() or SmoothQuant, can further reduce memory usage, but they introduce additional overhead and potential accuracy degradation. Memory management techniques like PagedAttention and RadixAttention can mitigate memory fragmentation and enable KV cache reuse across requests, leading to improved throughput and latency. Ultimately, the choice of techniques should be driven by the specific requirements of the use case, balancing memory constraints, performance needs and the desired level of model accuracy.
Resources
https://medium.com/@plienhar/llm-inference-series-2-the-two-phase-process-behind-llms-responses-1ff1ff021cd5
https://www.reddit.com/r/LocalLLaMA/comments/18g21af/vllm_vs_llamacpp/
https://medium.com/@daniel.puenteviejo/the-science-of-control-how-temperature-top-p-and-top-k-shape-large-language-models-853cb0480dae
https://arxiv.org/abs/2310.01801
https://arxiv.org/abs/1911.02150