
Inference-Friendly Models with MixAttention | Databricks Blog


Transformer models, the backbone of modern language AI, rely on the attention mechanism to process context when generating output. During inference, the attention mechanism works by computing the key and value vectors for each token seen so far, and using these vectors to update the internal representation of the next token to be output. Because the same key and value vectors of the past tokens get reused every time the model outputs a new token, it is standard practice to cache them in a data structure called the Key-Value (KV) cache. Since the KV cache grows proportionally to the number of tokens seen so far, KV cache size is a major factor in determining both the maximum context length (i.e., the maximum number of tokens) and the maximum number of concurrent requests that can be supported for inference on modern language models. Particularly for long inputs, LLM inference is also dominated by the I/O cost of moving the KV cache from High Bandwidth Memory (HBM) to the GPU's shared memory. Therefore, reducing the KV cache size has the potential to be a powerful way to speed up and reduce the cost of inference on modern language models. In this post, we explore ideas recently proposed by Character.AI for reducing KV cache size by replacing most of the layers in the network with sliding window attention (a form of local attention that only uses the key and value vectors of a small number of the most recent tokens) and by sharing the KV cache among layers. We call this architecture MixAttention; our experiments with different variants of this architecture show that it maintains both short and long context model quality while improving inference speed and memory footprint.
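For a rough sense of scale (a back-of-the-envelope estimate of ours, not a figure from the experiments below), the KV cache of a standard attention model grows linearly with the number of tokens processed:

$$\text{KV cache size} \approx 2 \times n_{\text{layers}} \times n_{\text{kv heads}} \times d_{\text{head}} \times n_{\text{tokens}} \times b \ \text{bytes},$$

where the leading factor of 2 accounts for storing both keys and values and $b$ is the bytes per element (2 in 16-bit precision). The techniques discussed in this post attack the $n_{\text{layers}}$ and $n_{\text{tokens}}$ factors.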

Figure 1: Speed and accuracy of MixAttention model variants. (Model variants shown in Figure 2.) Top: We see that MixAttention models are faster and use less memory during inference at 32K context length. Bottom: MixAttention models maintain quality – they match the standard attention model on most evals. The models are all Mixture of Experts with 2B active and 5B total parameters.

We found that KV cache sharing between layers and adding sliding window layers can speed up inference and reduce inference memory usage while maintaining model quality, although some eval metrics show some degradation. In addition, our ablation experiments showed the following:

 

  • Having a few standard attention layers is crucial for the model's long context abilities. In particular, having the standard KV cache computed in the deeper layers is more important for long context abilities than the standard KV cache of the first few layers.
  • The KV cache of standard attention layers can be shared between non-consecutive layers without any observed degradation in long context abilities.
  • Increasing the KV cache sharing between sliding window layers too much also hurts long context abilities.

We have provided a guide to configuring and training MixAttention models using LLM Foundry in the appendix of this blog post.

 

Figure 2: (Left) A standard transformer model where all layers are standard attention layers. (Middle) Inference-friendly models with MixAttention. Green bars represent sliding window attention, and the lines connecting bars represent KV cache sharing. (Right) A model where all layers are sliding window attention.

MixAttention Architecture Overview

Standard transformer models use global attention in each layer. To create inference-friendly model architectures, we used a combination of sliding window attention layers, standard attention layers, and KV cache reuse layers. Below is a brief discussion of each component:

 

  • Sliding Window Attention Layers: In Sliding Window Attention (or Local Attention) with window size s, the query only attends to the last s keys instead of all the keys preceding it. This means that during inference, the KV cache only needs to store the KV tensors for the past s tokens instead of storing the KV tensors for all the preceding tokens. In our experiments, we set a window size of s=1024 tokens.
  • Standard Attention Layers: We found that even though standard attention layers lead to bigger KV caches and slower attention computation compared to sliding window attention, having a few standard attention layers is crucial for the model's long context abilities.
  • KV cache reuse: This refers to a layer in the transformer network that reuses the KV cache computed by a previous layer. Hence, if every l layers share KV tensors, the size of the KV cache is reduced by a factor of 1/l. (A small worked example of these savings follows this list.)
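To make the savings concrete, here is a simple illustrative calculation (ours, not a measured result): at a 32k context length, a sliding window layer with $s = 1024$ stores

$$\frac{1024}{32768} = \frac{1}{32} \approx 3\%$$

of the KV tensors that a standard attention layer would store, and a group of $l = 2$ layers sharing a single KV cache needs only half the cache that two independent layers would.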

 

We experimented with different combinations of the components above to ablate the effect of each of them. (More combinations are described in the appendices.) We found that not only does each of the above components play an important role in long context abilities, inference speed, and memory consumption, but their relative positions and counts also have significant effects on those metrics.

 

The models we trained are 24-layer Mixture of Experts (MoE) models with 1.64B active and 5.21B total parameters. We used RoPE positional embeddings and increased the RoPE base theta as we increased the context length during training. We used Grouped Query Attention with 12 attention heads and 3 KV heads.
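Note that Grouped Query Attention already shrinks the KV cache on its own: with 12 attention heads but only 3 KV heads, each layer stores keys and values for 3 heads rather than 12, i.e. (as a simple ratio, assuming the same head dimension)

$$\frac{3}{12} = \frac{1}{4}$$

of what full multi-head attention would require. The MixAttention techniques below stack on top of this saving.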

 

Training

We used LLM Foundry to train MixAttention models. Similar to prior work on training long context models, we followed a multi-stage training procedure to impart long context abilities to the models.

 

  1. We pretrained the models with a RoPE theta of 0.5M on 101B tokens, where each sequence was truncated to 4k token length.
  2. To increase the context length, we then trained the model on 9B tokens from a mix of natural language and code data, where the sequences were truncated to 32k tokens. We increased the RoPE theta to 8M for this stage. When training at 32k context length, we trained only the attention weights and froze the rest of the network. We found that this delivered better results than full network training. (A sketch of these context-length and RoPE settings appears after this list.)
  3. Finally, we trained the model on a 32k-length, synthetic, long-context QA dataset.
    • To create the dataset, we took natural language documents and chunked them into 1k-token chunks. Each chunk was then fed to a pretrained instruction model, which was prompted to generate a question-answer pair based on the chunk. We then concatenated chunks from different documents together to serve as the "long context." At the end of this long context, the question-answer pairs for each of the chunks were added. The loss gradients were computed only on the answer parts of these sequences.
    • This phase of training was carried out on 500M tokens (this number includes the tokens from the context, questions, and answers). The RoPE theta was kept at 8M for this stage.
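For reference, here is a minimal sketch (ours, assuming current LLM Foundry key names such as max_seq_len, rope, and rope_theta; the full training configs from these runs are not reproduced) of how the stage-wise context length and RoPE theta could be expressed in the model section of the training YAML:

```yaml
# Stage 1: pretraining at 4k context with RoPE theta 0.5M
model:
  max_seq_len: 4096          # "4k token length" (exact value assumed)
  attn_config:
    rope: true
    rope_theta: 500000
---
# Stages 2 and 3: 32k context with RoPE theta 8M
model:
  max_seq_len: 32768         # "32k tokens" (exact value assumed)
  attn_config:
    rope: true
    rope_theta: 8000000
```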

Evaluation

The models were evaluated on the Mosaic Evaluation Gauntlet to measure model quality across various metrics including reading comprehension, commonsense reasoning, world knowledge, symbolic problem solving, and language understanding. To evaluate the models' long context abilities, we used RULER at a context length of 32000 tokens. RULER is a composite benchmark consisting of 13 individual evals of the following types:

 

  • Needle-in-a-haystack (NIAH): These evals hide one or more keys and values in a long text, and the model is evaluated on its ability to retrieve the correct value(s) from the long context for the given key(s).
  • Variable Tracking (VT): This eval provides the model with a long context containing variable assignment statements, and the model is tasked with determining which variables have a particular value by the end of all the variable assignments.
  • Common and Frequent Word Extraction (CWE and FWE): These tasks ask the model to extract the most common or frequent words from the text.
  • Question Answering (QA): Given a long context, the model is asked a question whose answer is located somewhere in the context and is evaluated on whether it can answer that question correctly.

 

We used SGLang to deploy our models on 1 NVIDIA H100 GPU to run RULER and to obtain inference speed and memory consumption metrics.

Results

Position and Count of Standard Attention KV Caches

To measure the effect of the position and count of the standard attention KV caches, we tried four variants. All of the configurations are variants of the configuration proposed in Character.AI's blog post.

Figure 3: KV cache positions and counts. To measure the effect of the position and count of the standard attention KV caches on MixAttention's long context abilities, we trained and evaluated the four models shown above.
  1. MA: This variant has a single standard attention KV cache, which is the KV cache of the first layer. All the other standard attention layers share this KV cache.
  2. MA-EndSlide: This variant is the same as MA, but the last layer is a sliding window attention layer. This was done to measure how much having standard attention in the last layer affects long-context abilities.
  3. MA-Offset: This variant is similar to MA, but the first standard attention layer is offset to a later layer to allow the model to process the local context for a few layers before the standard attention layer is used to look at longer contexts.
  4. MA-Pairs: This variant computes two standard attention KV caches (at the first and thirteenth layers), each of which is then shared with another standard attention layer.

We compared these models to a transformer model with standard attention in all layers and a transformer model with sliding window attention in all layers.

Fig. 4 and 5: Effect of standard attention layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. MA and MA-EndSlide perform poorly on long context tasks, while MA-Offset and MA-Pairs perform well. This indicates that having a standard attention KV cache that is computed in later layers is important for long context abilities. We also found that the loss on the long context QA dataset correlates well with the model's long context abilities.

While the loss curves in Stages 1 and 2 of training were close for all the models, we found that in Stage 3 (training on the long context QA dataset) there was a clear bifurcation in the loss curves. In particular, we see that the MA and MA-EndSlide configurations show much worse loss than the others. These results are consistent with the long context RULER evals, where we found that MA and MA-EndSlide performed much worse than the others. Their performance was similar to that of the network with only sliding window attention in all layers. We think the loss in Stage 3 correlates well with RULER evals because, unlike Stages 1 and 2, which were next-word prediction tasks where the local context was sufficient to predict the next word most of the time, in Stage 3 the model needed to retrieve the correct information from potentially long-distance context to answer the questions.

 

As we see from the RULER evals, MA-Offset and MA-Pairs have better long-context abilities than MA and MA-EndSlide across all the categories. Both MA and MA-EndSlide have only one standard attention KV cache, which is computed in the first layer, whereas both MA-Offset and MA-Pairs have at least one standard attention KV cache that is computed in deeper layers. Hence, this indicates that having at least one standard attention KV cache computed in the deeper layers of a transformer model is necessary for good long-context abilities.

KV cache sharing in sliding window layers

Fig. 6: Increasing KV cache sharing in sliding window layers. To measure the effect of KV cache sharing in the sliding window layers, we compared the architectures shown in the figure above.

Fig. 7 and 8: Effect of increasing KV cache sharing in sliding window layers. (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that increasing the KV cache sharing in sliding window layers worsened the long context abilities of MixAttention models.

We found that increasing the sharing between sliding window layers degraded the model's long context performance: MA-Offset-SlideShare was worse than MA-Offset, and MA-Pairs-SlideShare was worse than MA-Pairs. This shows that the KV cache sharing pattern among the sliding window layers is also important for long context abilities.

 

We have provided the results of some additional ablation experiments in the appendices.

Gauntlet Evals

Using the Mosaic Eval Gauntlet v0.3.0, we also measured the performance of MixAttention models on standard tasks like MMLU, HellaSwag, etc., to verify that they maintain good shorter context abilities. All of the tasks in this eval suite have context lengths of less than a few thousand tokens.

Fig. 9: Performance of MixAttention models on the Eval Gauntlet. We found that MixAttention models have eval metrics similar to the baseline model on commonsense reasoning, language understanding, and world knowledge. However, we see that they perform worse on reading comprehension.

We found that MixAttention models have eval metrics similar to the baseline model on commonsense reasoning, language understanding, and world knowledge; however, they performed worse on reading comprehension. An interesting open question is whether reading comprehension abilities could be improved with a different MixAttention configuration or by training MixAttention models longer.

Inference Speed and Memory Consumption

Fig. 10 and 11: (Top) MixAttention models have significantly faster inference than standard transformers. (Bottom) MixAttention models can support more tokens, and thus larger batch sizes, during inference.

We benchmarked the inference speed and memory consumption of MixAttention models by deploying them on a single NVIDIA H100 GPU using SGLang and querying them with 300 prompts, with an input length of 31000 and an output length of 1000. In the figure, we show that the inference speed of MixAttention models is much faster than that of standard attention models. We also show that with MixAttention, we can support a much larger inference batch size in terms of the total number of tokens.

 

We found that the current implementation of sliding window attention in SGLang does not optimize its memory consumption; hence, sliding window attention has the same maximum number of tokens as the standard attention model. Optimizing the memory consumption of sliding window attention should further increase the maximum number of tokens that MixAttention can support during inference.

Conclusion

We found that MixAttention models are competitive with standard attention models on both long- and short-context abilities while being faster during inference and supporting larger batch sizes. We also observed that on some long context tasks like Variable Tracking and Common Word Extraction, neither MixAttention nor standard attention models performed well. We believe this was because our models were not trained long enough or because the models need a different kind of long context data to be trained for such tasks. More research needs to be done to measure the impact of MixAttention architectures on these metrics.

 

We encourage others to explore more MixAttention architectures to learn more about them. Below are a few observations to help with further research:

 

  • Adding a standard attention layer in the initial layers by itself does not seem to help long context abilities (for example, see MA-NoShare-1 in the appendix), even when the KV cache from that layer is reused in layers deeper into the network (MA and MA-EndSlide). Hence we recommend placing the first standard attention layer deeper in the network (like MA-Offset) or having multiple standard attention layers, at least one of which is computed at a deeper layer (like MA-Pairs).
  • Sliding window layers also contribute to the model's long context abilities. Increasing the KV cache sharing among the sliding window layers worsened long context abilities (MA-Offset-SlideShare and MA-Pairs-SlideShare). For that reason, we think the 2-3 sharing pattern in sliding window layers strikes a good balance.
  • Sharing full attention KV caches between consecutive layers gave mixed results, with slightly worse accuracy on long context QA tasks (see the appendix).
  • In our experiments, MA-Offset and MA-Pairs showed great speedup and memory savings during inference while also maintaining long and short context abilities. Hence, MA-Offset and MA-Pairs might be good configurations for further research.
  • MixAttention models can be trained with LLM Foundry. Please see the appendix for guidelines.

 

In general, there is a large hyperparameter space to explore, and we look forward to seeing a variety of new strategies for reducing the cost of inference via combinations of sliding window attention and KV cache reuse.

Appendix: Using LLM Foundry to train MixAttention models

The way to configure MixAttention models with LLM Foundry is to use its block_overrides feature. The block_overrides definition consists of two sections: order and overrides. The order key defines the ordering and the names of the layers in the network, while the overrides key contains the custom configuration of each named layer.

 

For example, to create a 5-layer network with the first two layers being standard attention layers, the next two being sliding window layers, and the last one being a standard attention layer, we use the following YAML:

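Below is a minimal sketch of such a configuration, using the order, overrides, and attn_config keys described in this appendix (the exact snippet may differ; block_overrides sits inside the model section of the training config):

```yaml
block_overrides:
  order:
    - name: default               # layers 1-2: standard attention
      repeat: 2
    - name: sliding_window_layer  # layers 3-4: sliding window attention
      repeat: 2
    - name: default               # layer 5: standard attention
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```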

Here, the order section says that the first two layers are of type 'default', the next two are of type 'sliding_window_layer', and the last is of type 'default' again. The definitions of each of these types are given in the overrides section using the names defined in the order section; it specifies that 'sliding_window_layer' should have a sliding_window_size of 1024. Note that 'default' is a special type that does not need a definition in the overrides section because it simply refers to the default layer (in this case, a standard attention layer). Also note that 'sliding_window_layer' is just a custom name and can be replaced with any other arbitrary name, as long as that name is correspondingly defined in the overrides section.

 

The model configuration is printed in the logs, which can be used to verify that the model is configured correctly; the layer structure produced by the YAML above can be checked there.


We can also configure the two sliding window layers to have different sliding window sizes, as follows:

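A sketch of this variant, assuming the same keys as above and writing the repeat keyword out explicitly:

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: sliding_window_layer
      repeat: 1
    - name: sliding_window_layer_2
      repeat: 1
    - name: default
      repeat: 1
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_2:
      attn_config:
        sliding_window_size: 512
```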

The above results in the third layer having a sliding window size of 1024 and the fourth layer having a sliding window size of 512. Note that the repeat keyword defaults to 1, so the above YAML can also be written as:

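That is, with the repeat: 1 entries omitted (one plausible equivalent form of the snippet):

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: sliding_window_layer
    - name: sliding_window_layer_2
    - name: default
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_layer_2:
      attn_config:
        sliding_window_size: 512
```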

The repeat keyword is also applicable to the order keyword. So, if we want to create a 4-layer network with alternating standard and sliding window attention layers like the following,

[Figure: a 4-layer network alternating standard attention and sliding window attention layers]

then we can use the following YAML:

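A sketch using a nested order block together with repeat (layer names as before):

```yaml
block_overrides:
  order:
    - repeat: 2
      order:
        - name: default               # standard attention layer
        - name: sliding_window_layer  # sliding window attention layer
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
```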

To make a layer reuse the KV cache of a previous layer, we use reuse_kv_layer_idx in the attn_config of the override definition. The key reuse_kv_layer_idx contains the relative index of the layer whose KV cache we want this layer to reuse. To make a two-layer network where the second layer reuses the first layer's KV cache, we can use the following YAML:

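A sketch of this two-layer configuration, using kv_reuse_layer as the custom name described below:

```yaml
block_overrides:
  order:
    - name: default          # layer 1: standard attention, computes the KV cache
    - name: kv_reuse_layer   # layer 2: reuses layer 1's KV cache
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1   # relative index: the layer immediately before this one
```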

The value -1 means that the layer named kv_reuse_layer reuses the KV cache of the layer that is one layer before it. To create a 5-layer network with the following configuration,

[Figure: a 5-layer network with chained KV cache reuse across layers]

we can use the following YAML:

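Since the exact five-layer layout is given in the figure, the following is one plausible sketch consistent with the description, in which two consecutive layers each reuse the KV cache of the layer immediately before them:

```yaml
block_overrides:
  order:
    - name: default
      repeat: 2
    - name: kv_reuse_layer
      repeat: 2          # each of these reuses the layer directly before it
    - name: default
  overrides:
    kv_reuse_layer:
      attn_config:
        reuse_kv_layer_idx: -1
```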

Note that in the above configuration, layer #4 reuses the KV cache of layer #3, which in turn reuses the KV cache of layer #2. Hence, layer #4 ends up reusing the KV cache of layer #2.

 

Finally, note that order can be defined recursively; that is, an order can contain another order sub-block. For example, MA-Offset-SlideShare

[Figure: the MA-Offset-SlideShare configuration]

can be defined as follows:

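The exact MA-Offset-SlideShare layout is given in the figure above. Purely as an illustration of the recursive order mechanism (the layer counts below are ours and do not reproduce the actual 24-layer model), a pattern of an offset standard attention layer plus repeated sliding-window blocks with KV cache sharing could be written as:

```yaml
block_overrides:
  order:
    - name: sliding_window_layer     # a few local-attention layers before the first standard attention layer
      repeat: 3
    - name: default                  # the offset standard attention layer
    - repeat: 5                      # repeated sub-block: one fresh sliding window layer plus two that reuse its KV cache
      order:
        - name: sliding_window_layer
        - name: sliding_window_reuse
          repeat: 2
  overrides:
    sliding_window_layer:
      attn_config:
        sliding_window_size: 1024
    sliding_window_reuse:
      attn_config:
        sliding_window_size: 1024
        reuse_kv_layer_idx: -1       # chains back to the fresh sliding window layer in this sub-block
```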

Appendix: Other Ablation Experiments

Sharing Standard Attention KV Caches between Consecutive Layers

Since the transformer layers progressively update the latent representation of a token as it passes through the network, the Query, Key, and Value tensors might have significantly different representations in layers that are far apart. Hence, it might make more sense to share KV caches between consecutive layers. To test this, we compared four such configurations (MA-Successive-1, MA-Successive-2, MA-Successive-3, and MA-Successive-4) against MA-Pairs. These configurations vary the positions of the standard KV attention layers and the distance between the consecutive pairs of standard KV attention layers.

KV cache sharing between consecutive layers: To measure the effect of KV cache sharing between consecutive layers, we tried the four configurations shown above.

 


Effect of KV cache sharing between consecutive layers: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that KV cache sharing between consecutive layers does not consistently improve long context abilities across all evals. However, for tasks like SQuAD QA and HotpotQA, which may be indicative of long context RAG abilities, the performance was slightly worse when sharing the KV cache between consecutive layers.

We found that all the models have similar loss curves and similar performance on the NIAH single 1, 2, and 3 tasks, which we consider to be the easiest long context tasks. However, we did not see a consistent pattern across the other NIAH tasks. For long context QA tasks, we found that MA-Pairs was slightly better than the others. These results indicate that sharing a standard attention KV cache between layers that are further apart does not lead to any significant degradation in long context abilities compared to sharing a standard attention KV cache between consecutive layers.

Effect of Sharing Standard Attention KV Cache

No standard attention KV cache sharing: To measure the effect of KV cache sharing between standard attention layers, we compare the architectures shown in the figure above.

Effect of no standard attention KV cache sharing: (Top) Loss curves of the models when fine-tuning on the long context QA dataset. (Bottom) RULER evals for the models. We found that both MA-NoShare-2 and MA-NoShare-3 were comparable to MA-Offset.

 

To test the effect of sharing the KV cache between standard attention layers, we tried three configurations: MA-NoShare-1, MA-NoShare-2, and MA-NoShare-3. We found that MA-NoShare-1 performed very badly on RULER, indicating its lack of long context abilities. However, MA-NoShare-2 and MA-NoShare-3 were comparable to MA-Offset on long context tasks. Hence, we think further research is needed to determine the effects of sharing standard attention KV caches.
