
Extending Context Length in Large Language Models

How to turn your Llama into a Giraffe

Image by the author. (AI generated Llamas)

Context length refers to the maximum number of tokens the model can remember when generating text. A longer context window allows the model to understand long-range dependencies in text better. Models with longer contexts can build connections between ideas far apart in the text, generating more globally coherent outputs.

During training, the model processes the text data in chunks or fixed-length windows. Models need to be trained on lengthy texts to actually leverage long contexts. Training sequences must contain documents, books, articles, etc., with thousands of tokens. The length of training data sets a limit on usable context length.

So, why don’t we train models on longer sequences?

Not so fast.

Increasing context length increases the number of possible token combinations the model must learn to predict accurately. This enables more robust long-range modeling, but it also requires more memory and processing power, leading to higher training costs.

Without any optimization, computation scales quadratically with context length – meaning that a 4096 token model will need 64 times more computation than a 512 token model.
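To make the arithmetic explicit, the 64x figure comes from squaring the ratio of the sequence lengths:

```python
# Quadratic scaling: relative cost is (new_length / old_length) squared.
old_len, new_len = 512, 4096
print((new_len / old_len) ** 2)  # 8^2 = 64.0, i.e. 64x more attention compute
```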

You can use sparse or approximate attention methods to reduce the computation cost, but they may also affect the model’s accuracy.

Training and using large context language models presents three main challenges:

  • Fitting long contexts into the model.
  • Accelerating inference and training so they don’t take forever.
  • Ensuring high-quality inference that maintains awareness of the full context.

Attention is a complex operation

The attention mechanism is the core component of transformer models. It relates different positions of a sequence to compute its representation, allowing models to focus on relevant parts of the text and understand it better. Scaling transformers to longer sequences faces challenges due to the quadratic complexity of full attention.

There are two matrix multiplications involved in self-attention. Image based on the original paper. [1]

Stacked self-attention layers allow modeling long-range dependencies in text. The standard attention mechanism used in Transformers, which computes the attention weights for all possible pairs of input tokens, has a complexity of O(n²). It means that the computation and memory requirements grow quadratically with the input sequence length, limiting Transformers’ scalability and efficiency. When generating text, the model has to compute the attention matrix first. With a 100K context and quadratic attention, it can take minutes before the model starts generating text.
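To see where the O(n²) term comes from, here is a minimal NumPy sketch of standard scaled dot-product attention for a single head; the (n, n) score matrix is exactly the object that grows quadratically. Names and shapes are illustrative, not tied to any particular library.

```python
import numpy as np

def full_attention(Q, K, V):
    """Standard scaled dot-product attention for a single head.
    Q, K, V: (n, d) arrays. The score matrix is (n, n), hence O(n^2)."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                    # (n, n): quadratic in sequence length
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over keys
    return weights @ V                               # (n, d)

n, d = 1024, 64
Q, K, V = (np.random.randn(n, d) for _ in range(3))
out = full_attention(Q, K, V)
```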

Let’s explore methods to improve attention efficiency, from approximations to hardware optimizations.

Improving Attention Efficiency

Reducing the quadratic cost has become an active research area. Proposed methods can be grouped into two main categories: approximating attention and exact attention using hardware-aware optimizations.

Approximation techniques constrain interactions between sequence positions. Sparse attention limits the number of non-zero attention weights per attention head, while local attention restricts interactions to a sliding window. These approximations reduce computational cost but may degrade accuracy on complex tasks. [2]

Recent work has focused on optimizing attention to leverage GPU architectures.

Sparse attention approximates attention by only computing the attention weights for a subset of the input tokens instead of all possible pairs, thus saving time and memory. There are different ways to implement sparse attention, such as using fixed or static patterns (e.g., local, strided, or block attention) or dynamic or adaptive patterns that depend on the input sequence (e.g., entmax or dynamic sparse attention).

Quadratic attention (left) computes every possible combination between input tokens. Sparse attention (right) limits the computation only to nearby tokens. [2]

Sparse attention can improve the efficiency and scalability of Transformers, especially for long sequences, but it may also sacrifice some representation power and accuracy. Quadratic attention can achieve high performance and quality, but it may also be computationally expensive and impractical for large-scale applications. Therefore, there is a trade-off between sparsity and complexity in attention mechanisms.
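To make the fixed-pattern (local) variant concrete, here is a hedged sketch that simply masks the score matrix to a sliding window before the softmax. Real sparse kernels avoid building the full matrix in the first place, which is where the actual savings come from; this version only illustrates which pairs of tokens interact.

```python
import numpy as np

def local_attention(Q, K, V, window=128):
    """Sliding-window attention: each query only attends to keys within
    `window` positions of it. Illustrative only; a real sparse kernel
    never materializes the full (n, n) matrix."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) > window   # True = blocked pair
    scores = np.where(mask, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                               # blocked pairs become 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V
```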

Flash Attention

The fundamental intuition is to avoid materializing the large N x N attention matrix, which requires quadratic reading/writing in the sequence length N.

FlashAttention applies two techniques – tiling and recomputation. Tiling splits the input into blocks, loaded into fast GPU on-chip SRAM. Attention is computed block-by-block to avoid materializing the entire matrix. Recomputation stores just enough information to reconstruct the attention matrix on-chip during backpropagation, avoiding storing the large intermediate. [3]
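The following is a simplified single-head NumPy sketch of the tiling idea (not the actual fused CUDA kernel): keys and values are processed block by block, and a running maximum plus a running softmax denominator let the result be assembled without ever materializing the full N x N matrix.

```python
import numpy as np

def tiled_attention(Q, K, V, block=256):
    """Block-wise attention with an online softmax, in the spirit of
    FlashAttention's tiling. Illustrative NumPy only; the real kernel
    processes these blocks in fast on-chip SRAM."""
    n, d = Q.shape
    out = np.zeros_like(Q)
    row_max = np.full(n, -np.inf)   # running max of scores per query
    row_sum = np.zeros(n)           # running softmax denominator per query

    for start in range(0, n, block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        scores = Q @ Kb.T / np.sqrt(d)             # (n, block), never (n, n)
        new_max = np.maximum(row_max, scores.max(axis=-1))
        correction = np.exp(row_max - new_max)     # rescale earlier partial results
        p = np.exp(scores - new_max[:, None])
        out = out * correction[:, None] + p @ Vb
        row_sum = row_sum * correction + p.sum(axis=-1)
        row_max = new_max

    return out / row_sum[:, None]
```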

The authors analyze the IO complexity, proving FlashAttention requires O(N²/M) memory accesses versus O(N²) for standard attention, where M is the SRAM size. This IO-awareness allows FlashAttention to run faster despite increased FLOPs from recomputation.

Experiments validate the speedups: FlashAttention trains BERT 15% faster than the MLPerf record, trains GPT-2 3x faster, and achieves a 2.4x speedup on the Long Range Arena benchmark.

Instead of computing the whole attention matrix on the slower HBM, FlashAttention copies blocks to the SRAM. [3]

This idea was further developed in FlashAttention-2. The improvements focus on enhancing parallelism across sequence blocks and optimizing work partitioning between thread blocks and warps on GPUs. Key techniques include reducing non-matrix-multiply operations, partitioning attention computation across threads to increase occupancy, and distributing work between warps to reduce shared memory traffic. Empirical validation shows FlashAttention-2 achieves around a 2x speedup over FlashAttention, reaching up to 73% of theoretical peak FLOPs on A100 GPUs. When used to train GPT models end-to-end, training throughput reaches 225 TFLOPs/s per A100, translating to 1.3x faster training than FlashAttention. [4]

The improvements promise to enable training models on much longer sequences than before at a similar cost. Accelerating attention speeds up inference and training, but fitting text into the model while maintaining high output quality remains an issue.

Let’s see what to do about it.

Models are pre-trained on fixed-length sequences

Efficient training and inference are not enough for a high-quality model. There are two main paradigms of context length extension: fine-tuned extrapolation, where the LLM further updates its weights on longer contexts, and zero-shot extrapolation, where the model is evaluated on long contexts with no change to the weights learned during short-context training. [5] To extend context, most approaches focus on modifying the positional encoding system used in the transformer attention mechanism to indicate where tokens are located in the input sequence. The idea is that representing longer input sequences in the positional encoding will allow the LLM to attend to those longer sequences.

Positional encoding is used to make your model understand the order in the sentence. Image based on the original paper. [1]


Positional embeddings are added to the input token embeddings before feeding them into the model to enable the model to use the order of the sequence. They map the discrete positional IDs to continuous embedding vectors.

Usually, positional embeddings are defined algorithmically based on the position IDs. In the original Transformers paper, they used a trigonometric function, where each dimension of the positional embedding vector follows a sinusoidal pattern. [1]
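For reference, here is a minimal sketch of that sinusoidal scheme, where even dimensions follow a sine and odd dimensions a cosine at geometrically spaced frequencies (assuming an even embedding dimension):

```python
import numpy as np

def sinusoidal_positions(n_positions, d_model):
    """Sinusoidal positional embeddings as in 'Attention Is All You Need':
    PE[pos, 2i]   = sin(pos / 10000^(2i/d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))"""
    pos = np.arange(n_positions)[:, None]          # (n, 1)
    i = np.arange(0, d_model, 2)[None, :]          # (1, d/2)
    angles = pos / (10000 ** (i / d_model))
    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

# These vectors are simply added to the token embeddings before the first layer.
pe = sinusoidal_positions(2048, 512)
```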

In LLaMA, Rotary Position Embeddings (RoPE) are used: instead of adding a positional vector, positional information is injected on the fly by rotating pairs of query and key dimensions with trigonometric functions, where the rotation amounts are determined by the position ID.
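A simplified sketch of the rotation is shown below: consecutive dimensions of a query or key vector are paired up, and each pair is rotated by an angle proportional to the position ID. The actual LLaMA implementation differs in layout details, so treat this as illustrative only.

```python
import numpy as np

def rotary_embed(x, positions, base=10000):
    """Apply rotary position embeddings to queries or keys.
    x: (n, d) with d even; positions: (n,) array of position IDs.
    Consecutive dimension pairs are rotated by position-dependent angles,
    so relative offsets show up directly in query-key dot products."""
    n, d = x.shape
    freqs = 1.0 / (base ** (np.arange(0, d, 2) / d))   # (d/2,) rotation frequencies
    angles = positions[:, None] * freqs[None, :]       # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]                    # pair up dimensions
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin                 # 2D rotation per pair
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

q = np.random.randn(16, 64)
q_rot = rotary_embed(q, np.arange(16))
```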

Regardless of how positional embeddings are generated, models struggle to generalize to sequences longer than those seen during pretraining (context extrapolation). Sinusoidal position embeddings have limited extrapolation ability, allowing only a few dozen extra tokens at inference before performance degrades. [6]

Newer approaches like linear scaling and position interpolation have been introduced to address this limitation.

Linear scaling

With linear scaling, position indices are rescaled to adapt the model to longer sequence lengths. If the pre-trained model covers positions up to length L, then for inference on a sequence of length N, each position index is multiplied by L/N, squeezing the longer sequence into the positional range seen during pretraining. This cheaply approximates embeddings for the new length while retaining the pre-trained embedding properties. Linear scaling improves performance on long sequences significantly. However, the model still underperforms on sequences much longer than the pretrained length, because interpolation packs positions closer together and blurs fine-grained positional distinctions.

In position interpolation, the context range is not extended. Instead, there are more intermediate positions. [7]
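In code, this amounts to rescaling the position IDs before the rotary angles are computed. A minimal sketch, reusing the illustrative rotary_embed helper from earlier:

```python
import numpy as np

# Position interpolation: to run a model pretrained on L positions at
# length N > L, rescale position IDs by L / N so every position falls
# inside the range the model has already seen.
pretrained_len = 2048          # L: context length used at pretraining
target_len = 8192              # N: context length wanted at inference

positions = np.arange(target_len) * (pretrained_len / target_len)
# positions now run from 0 to just under 2048, with fractional intermediate values.
# q_rot = rotary_embed(q, positions)   # apply RoPE as before, just with scaled positions
```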

Linear scaling/interpolation performs best overall for extending context, with promise also shown in the truncated basis method. Further gains are achieved by using a longer scaling factor at evaluation time. This method was researched concurrently by kaiokendev and Meta. [7] YaRN later built on it to reach a 128k context length. [8] Giraffe followed with three fine-tuned models offering 8k, 16k, and 32k contexts. [5]

Attention with Linear Biases

ALiBi introduces a simpler approach that eliminates positional embeddings. Instead, it negatively biases attention scores between queries and keys with a penalty proportional to their distance. This inductive bias towards recent contexts allows extrapolation at low computational cost. Experiments show a 1.3B parameter model trained on 1024 tokens with ALiBi achieving the same perplexity as a sinusoidal model trained on 2048 tokens when tested on 2048 tokens, training faster and using less memory. [6]
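A hedged sketch of the bias itself: a fixed, head-specific slope multiplied by the query-key distance is subtracted from the attention scores, and no positional embeddings are used anywhere else. The slope schedule below follows the geometric sequence described in the paper for power-of-two head counts.

```python
import numpy as np

def alibi_bias(n, num_heads):
    """ALiBi: penalize attention scores by -m * (i - j), where m is a fixed
    per-head slope and (i - j) is how far key j lies behind query i.
    Returns a (num_heads, n, n) bias to add to the raw attention scores."""
    # Geometric slopes, e.g. 1/2, 1/4, ..., 1/256 for 8 heads.
    slopes = np.array([2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)])
    i = np.arange(n)
    distance = i[:, None] - i[None, :]               # distance of each key behind the query
    distance = np.where(distance < 0, 0, distance)   # future positions are causally masked anyway
    return -slopes[:, None, None] * distance[None, :, :]

# Usage sketch: scores = Q @ K.T / sqrt(d) + alibi_bias(n, num_heads)[head]
bias = alibi_bias(1024, 8)
```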

MosaicML’s MPT-7B model leverages the ALiBi architecture to enable extrapolation to extreme context lengths up to 65k tokens, far surpassing the limits of other open-source models. By replacing positional embeddings with ALiBi, the model gains the ability to handle inputs of arbitrary length during inference without being constrained to a fixed context window. This was demonstrated through fine-tuning MPT-7B into MPT-7B-StoryWriter-65k+ using 65k token excerpts from fiction books, allowing it to generate coherent continuations from the full text of The Great Gatsby at 68k tokens.

The choice of positional embedding approach depends on considerations like model size, expected sequence lengths, and how important generalization is for the problem.

There remains significant room for improvement, as all methods degrade in accuracy with increasing length, even when perplexity is reasonable.

Bonus model: RWKV

Another way to improve attention would be to not use it.

The Receptance Weighted Key Value (RWKV) model introduced by Peng et al. aims to reconcile the trade-off between computational efficiency and model performance in sequence processing tasks. RWKV combines aspects of both Transformers and RNNs into a novel architecture that achieves linear scaling. [9] A key innovation is the reformulation of the attention mechanism to use scalar interactions rather than dot products, eliminating the quadratic bottleneck.

RWKV architecture for language modeling. [9]

RWKV implements a variant of linear attention without approximation for improved efficiency. The model parallelizes computations during training like a Transformer but behaves as an RNN decoder at inference, yielding constant speed and memory with unlimited context. Experiments demonstrate RWKV is competitive with Transformers on language tasks while requiring lower computational cost. However, given its recurrent nature, prompts may need to be adapted carefully, and the model may struggle to keep track of fine-grained details over extremely long sequences.
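To give a flavor of the recurrent inference mode, here is a deliberately simplified, numerically naive toy of the WKV recurrence (not the full RWKV layer, which adds time-mixing, channel-mixing, and a numerically stable state update): each step keeps only a running numerator and denominator, so per-token cost and memory stay constant no matter how long the context is.

```python
import numpy as np

def wkv_recurrence(k, v, w, u):
    """Toy, numerically naive WKV recurrence for one channel group.
    k, v: (T, d) key and value sequences; w, u: (d,) learned decay and
    'current token' bonus. The state is just a running numerator and
    denominator, giving O(1) memory per generated token."""
    T, d = k.shape
    num = np.zeros(d)   # decayed running sum of exp(k_i) * v_i
    den = np.zeros(d)   # decayed running sum of exp(k_i)
    outputs = np.zeros((T, d))
    for t in range(T):
        e_k = np.exp(k[t])
        bonus = np.exp(u + k[t])
        # Current token gets an extra bonus u before mixing with the decayed past.
        outputs[t] = (num + bonus * v[t]) / (den + bonus)
        # Update the state: decay the past by exp(-w), then add the current token.
        num = np.exp(-w) * num + e_k * v[t]
        den = np.exp(-w) * den + e_k
    return outputs

out = wkv_recurrence(np.random.randn(32, 8), np.random.randn(32, 8),
                     np.ones(8) * 0.5, np.zeros(8))
```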

Raven is an example of an RWKV model. This model competes with the smallest Llama-based models. In my tests, Raven showed a good understanding of grammar and semantic meaning, but tended to hallucinate often.

Conclusion

Language models benefit from longer contexts. However, longer contexts increase training costs quadratically due to the standard attention mechanism. Recent research focuses on approximating attention to improve efficiency. Methods like sparse attention and linear attention help. Optimizing hardware efficiency also works, as FlashAttention shows by leveraging GPU memory hierarchies.

Pretrained models still struggle with contexts longer than those seen during training. Techniques like linear scaling of positional embeddings and ALiBi enable longer contexts. Fine-tuning on longer contexts further adapts models.

State-of-the-art models push context lengths far beyond previous limits. YaRN and Giraffe use position interpolation. MPT-7B-StoryWriter-65k+ uses ALiBi for 65,000-token contexts. RWKV proposes linear-scaling attention, allowing unlimited context at inference.

Longer contexts empower models to process full documents and books, but challenges remain in maintaining output quality over extended lengths.


If you enjoyed this article, join Text Generation – our newsletter has two weekly posts with the latest insights on Generative AI and Large Language Models.

Also, you can find me on LinkedIn.


References

[1] [1706.03762] Attention Is All You Need (arxiv.org)
[2] [1904.10509] Generating Long Sequences with Sparse Transformers (arxiv.org)
[3] [2205.14135] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (arxiv.org)
[4] [2307.08691] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (arxiv.org)
[5] [2308.10882] Giraffe: Adventures in Expanding Context Lengths in LLMs (arxiv.org)
[6] [2108.12409] Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation (arxiv.org)
[7] [2306.15595] Extending Context Window of Large Language Models via Positional Interpolation (arxiv.org)
[8] [2309.00071] YaRN: Efficient Context Window Extension of Large Language Models (arxiv.org)
[9] [2305.13048] RWKV: Reinventing RNNs for the Transformer Era (arxiv.org)

