mni-ml
module 4

LLM Inference

Improving token throughput at scale

Everything we’ve discussed so far has focused on training: how models learn from data. We have yet to discuss another important aspect of machine learning: inference.

Check out our speculative decoding demo here, which uses a smaller 2M-parameter draft model to improve token throughput by up to 3x. To learn more about speculative decoding, keep reading.

Inference refers to the process of running an already trained model, applying what it has learned to generate outputs and make predictions. For an LLM, this involves two stages: prefill and decode.

Prefill

Prefill is the stage of inference that involves the model “reading” the input prompt. This stage determines how long it takes for the first token to appear, a duration referred to as time to first token (TTFT). One of the most important aspects of the prefill stage is the generation of the key-value cache.

Key-Value Cache (KV Cache)

As discussed in the transformers section of this blog, one of the most important parts of the architecture behind LLMs is the attention block. In this part of the model, query, key, and value matrices are obtained by multiplying the input token embeddings with learned weight matrices.

To predict the next token, we need to compute attention over all previously seen tokens, which requires their key and value vectors. If we run the forward pass of a transformer naively, we recompute K and V for every previous token at every decoding step, so the total work spent on these projections grows quadratically with the sequence length.

The key-value cache addresses this by saving previously computed results. With each additional token generated, the key and value matrices stay identical except that each gains one new row for the new token. We can therefore cache the previous key and value matrices and avoid recomputing past states during decoding.
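To make this concrete, here is a minimal NumPy sketch, assuming a single attention head in a single layer with random weights and fixed stand-in hidden states (none of these names come from a real library). It decodes the same sequence twice, once recomputing K and V from scratch at every step and once appending one row per step to a cache, and checks that the outputs match.

```python
import numpy as np

def attention(q, K, V):
    """Single-head scaled dot-product attention for one query vector.
    q: (d,), K and V: (t, d). Returns a (d,) output."""
    scores = K @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V

rng = np.random.default_rng(0)
d, steps = 16, 8
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
xs = rng.standard_normal((steps, d))  # stand-ins for each token's hidden state

# Naive decoding: recompute K and V for every previous token at every step.
naive_outputs = []
for t in range(1, steps + 1):
    K = xs[:t] @ W_k  # t rows recomputed at every step
    V = xs[:t] @ W_v
    naive_outputs.append(attention(xs[t - 1] @ W_q, K, V))

# Cached decoding: project only the newest token and append one row.
K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
cached_outputs = []
for t in range(steps):
    x = xs[t]
    K_cache = np.vstack([K_cache, x @ W_k])
    V_cache = np.vstack([V_cache, x @ W_v])
    cached_outputs.append(attention(x @ W_q, K_cache, V_cache))

print(np.allclose(naive_outputs, cached_outputs))  # True
```

Both loops produce the same attention outputs, but the cached version projects each token once instead of once per remaining step.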

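We store the key and value matrices for every layer in VRAM. The memory requirement of this is given by:

$$\text{KV cache bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv}} \times d_{\text{head}} \times L_{\text{seq}} \times B \times b$$

where the factor of 2 accounts for storing both keys and values, $n_{\text{kv}}$ is the number of key/value heads (equal to the number of attention heads in standard multi-head attention), $d_{\text{head}}$ is the head dimension, $L_{\text{seq}}$ is the sequence length, $B$ is the batch size, and $b$ is the number of bytes per element (2 for FP16).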

The memory requirements of the KV cache and further optimizations that can be made are discussed later.
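As a rough sketch of the arithmetic, the KV cache needs 2 (keys and values) × layers × key/value heads × head dimension × tokens × batch size × bytes per element. The helper below is hypothetical, and the example configuration (32 layers, 32 KV heads of dimension 128, FP16 values) is an assumed 7B-class setup used only for illustration.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """Bytes needed to cache K and V (the factor of 2) for every layer and token."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# Assumed 7B-class config: 32 layers, 32 KV heads of dim 128, FP16 (2 bytes).
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full = kv_cache_bytes(32, 32, 128, seq_len=4096, batch_size=8)
print(per_token / 2**20, "MiB per token")       # 0.5 MiB per token
print(full / 2**30, "GiB for 8 x 4096 tokens")  # 16.0 GiB for 8 x 4096 tokens
```

Under these assumptions, a modest batch of long requests already needs 16 GiB for the cache alone, on top of the model weights.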

For a live demo of the difference the KV cache makes, see here. In that demo, token throughput improves by more than 20 times with the KV cache enabled.

Optimizing Prefill

As mentioned, the prefill stage builds the KV cache that enables faster decoding. Prefill parallelizes well because the full input sequence is available immediately, so all prompt tokens can be processed together. As a result, prefill is typically compute-bound: it is limited by raw GPU compute, i.e. floating point operations per second (FLOPS).

To enable faster inference at scale, numerous optimizations are made during the prefill stage.

Chunked Prefill

Chunked prefill breaks a long input prompt into smaller pieces that are processed as individual chunks. This prevents a large prompt from monopolizing GPU time and delaying other requests, and makes it easier to interleave decode work between chunks.
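The scheduling idea can be sketched as follows. This is a simplified, hypothetical policy rather than any particular engine's scheduler: each engine step has a fixed token budget, every running decode request consumes one token of it, and whatever remains is spent on the next chunk of pending prompts in arrival order. It ignores requests moving from prefill into decode.

```python
from collections import deque

def schedule_steps(prompt_lengths, num_decoding, token_budget):
    """Split pending prompts into per-step prefill chunks under a token budget.
    Running decode requests take one token each per step; the rest goes to prefill."""
    if token_budget <= num_decoding:
        raise ValueError("budget must leave room for prefill tokens")
    pending = deque(prompt_lengths)  # remaining prompt tokens per request, FIFO
    steps = []
    while pending:
        budget = token_budget - num_decoding  # decode tokens are scheduled first
        chunks = []
        while pending and budget > 0:
            take = min(pending[0], budget)  # next chunk of the oldest pending prompt
            chunks.append(take)
            budget -= take
            pending[0] -= take
            if pending[0] == 0:
                pending.popleft()
        steps.append({"decode_tokens": num_decoding, "prefill_chunks": chunks})
    return steps

# A 1000-token prompt arrives while 4 requests are decoding; budget of 256 tokens/step.
for step in schedule_steps([1000], num_decoding=4, token_budget=256):
    print(step)
```

Here the prompt is split into chunks of 252, 252, 252, and 244 tokens, and every step still advances the 4 decoding requests. Processing the prompt in one shot would instead make them wait on a single 1,004-token step.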

Prefix Caching

Prefix caching is another key optimization. If many prompts share an identical prefix, e.g. “You are a helpful AI assistant…”, their key and value matrices across all layers have identical first rows for those tokens. To avoid repeated computation, these results can be cached per prefix, so that when the same prefix appears again, the time to build the KV cache is significantly reduced.
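A minimal sketch of prefix caching at block granularity, assuming prompts are lists of integer token ids and using a string placeholder where a real system would store K/V tensors. Each block is keyed by the entire prefix up to and including it, because a token's keys and values depend on every token before it; a real system would more likely hash these prefixes than store them as tuples. `BLOCK` and `PrefixCache` are illustrative names.

```python
BLOCK = 4  # tokens per cached block (hypothetical value)

class PrefixCache:
    """Maps each full block of a prompt, keyed by the whole prefix ending at
    that block, to a stand-in for its stored K/V entries."""

    def __init__(self):
        self.blocks = {}

    def lookup_and_insert(self, tokens):
        """Return how many leading tokens are already cached, then cache the rest."""
        cached = 0
        for end in range(BLOCK, len(tokens) + 1, BLOCK):
            key = tuple(tokens[:end])  # whole prefix, not just this block's tokens
            if key in self.blocks:
                cached = end  # reuse: skip prefill for this block
            else:
                self.blocks[key] = f"kv[{end - BLOCK}:{end}]"  # placeholder for tensors
        return cached

system = [1, 2, 3, 4, 5, 6, 7, 8]  # shared "system prompt" token ids (2 blocks)
cache = PrefixCache()
print(cache.lookup_and_insert(system + [9, 10, 11]))           # 0: nothing cached yet
print(cache.lookup_and_insert(system + [12, 13, 14, 15, 16]))  # 8: both system blocks reused
```

Only full blocks are cached, so the trailing partial block of each prompt is always recomputed.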

Decode

The second stage of inference is decoding, in which the model sequentially predicts, or informally “writes”, one token at a time. Here, inference and its optimizations are mostly concerned with memory bandwidth: each decode step processes only one new token per request, yet must read the model weights and the KV cache from memory.
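One way to see why prefill is compute-bound and decode is memory-bound is arithmetic intensity: FLOPs performed per byte moved. The sketch below estimates it for a single FP16 weight-matrix multiply, under the simplifying assumption that weights, inputs, and outputs each cross memory exactly once; the hidden size of 4096 is an assumed example.

```python
def matmul_intensity(n_tokens, d_in, d_out, bytes_per_elem=2):
    """Approximate FLOPs per byte for (n_tokens x d_in) @ (d_in x d_out),
    assuming weights, inputs, and outputs each move through memory once."""
    flops = 2 * n_tokens * d_in * d_out  # one multiply and one add per weight per token
    bytes_moved = bytes_per_elem * (d_in * d_out + n_tokens * d_in + n_tokens * d_out)
    return flops / bytes_moved

d = 4096  # assumed hidden size
print(round(matmul_intensity(2048, d, d)))  # 1024: prefill of a 2048-token prompt
print(round(matmul_intensity(1, d, d), 2))  # 1.0: one decode step for one request
```

With one token per request, each weight byte loaded supports only about one FLOP, far below the many tens to hundreds of FLOPs per byte that modern GPUs need to be compute-bound. Decoding many requests together raises the token count per matrix multiply, which is one reason continuous batching matters.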

Paged Attention

As mentioned earlier, the KV cache is one of the most important optimizations in decoding, but it also introduces a major bottleneck: memory. As context lengths grow, storing the key and value matrices for every layer and every token becomes extremely expensive in VRAM.

A naive implementation would allocate one large contiguous chunk of memory for each request’s KV cache. This quickly becomes inefficient. Different requests have different sequence lengths, finish at different times, and grow dynamically as new tokens are generated. This can lead to fragmentation and poor memory utilization.

Paged attention addresses this by splitting memory into fixed-size blocks, often called pages. Rather than requiring a request’s KV cache to occupy one contiguous region of VRAM, its cache can be distributed across many blocks. This allows memory to be allocated and reused much more flexibly.

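The memory requirement of one such block, assuming it holds the keys and values of $S_{\text{block}}$ tokens across all layers, can be written as:

$$\text{block bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv}} \times d_{\text{head}} \times S_{\text{block}} \times b$$

where $n_{\text{kv}}$ is the number of key/value heads, $d_{\text{head}}$ is the head dimension, and $b$ is the number of bytes per element.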

This approach is conceptually similar to paging in operating systems. The logical sequence of tokens remains the same, but the physical layout in memory no longer needs to be contiguous. This enables much more efficient KV cache management and allows more requests to be served simultaneously.
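A toy sketch of the bookkeeping behind paged attention, assuming a fixed pool of physical blocks and a per-request block table. It tracks only block ids, not the K/V tensors themselves, and the class and method names are hypothetical.

```python
class PagedKVAllocator:
    """Toy block manager: each request owns a block table mapping its logical
    blocks (token_index // block_size) to physical block ids in a shared pool."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}  # request id -> list of physical block ids
        self.num_tokens = {}    # request id -> tokens stored so far

    def append_token(self, req_id):
        """Reserve space for one more token; allocate a block only when the last one is full."""
        table = self.block_tables.setdefault(req_id, [])
        n = self.num_tokens.get(req_id, 0)
        if n % self.block_size == 0:  # current block is full (or none exists yet)
            if not self.free_blocks:
                raise MemoryError("KV cache pool exhausted")
            table.append(self.free_blocks.pop())
        self.num_tokens[req_id] = n + 1

    def physical_slot(self, req_id, token_idx):
        """Translate a logical token position into (physical block, offset)."""
        block = self.block_tables[req_id][token_idx // self.block_size]
        return block, token_idx % self.block_size

    def free(self, req_id):
        """Return all of a finished request's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(req_id, []))
        self.num_tokens.pop(req_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=4)
for _ in range(6):
    alloc.append_token("a")  # 6 tokens -> 2 blocks
for _ in range(3):
    alloc.append_token("b")  # 3 tokens -> 1 block
print(alloc.block_tables)           # {'a': [7, 6], 'b': [5]}
print(alloc.physical_slot("a", 5))  # (6, 1): token 5 lives in a's second block
alloc.free("a")                     # finished request returns its blocks
alloc.append_token("c")
print(alloc.block_tables["c"])      # [6]: a freed block is immediately reused
```

The attention kernel then follows each request's block table to find its keys and values, so no request ever needs a contiguous region of memory.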

Continuous Batching

Another key optimization during decoding is continuous batching.

In a naive inference system, requests are grouped into a fixed batch and decoded together. The issue is that not all requests complete at the same time. Some may finish early, while others continue generating tokens. If the batch remains fixed, the GPU can end up underutilized as finished requests leave behind empty space in the batch.

Continuous batching solves this by treating a batch not as a fixed set of requests, but as a pool of available decoding capacity. When one request finishes, another can immediately take its place. In other words, incoming requests are continuously inserted into the batch as space becomes available.

This keeps the GPU busy and improves throughput significantly.

Paged attention is a major enabler of this design. Since the KV cache is stored in flexible memory blocks, blocks belonging to finished requests can be freed and reassigned to new ones efficiently. Together, paged attention and continuous batching form the basis of many modern inference engines.
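A back-of-the-envelope simulation of the difference, under simplifying assumptions: each request needs a known number of decode steps, every active request generates one token per step, and prefill and memory limits are ignored. The function names are illustrative.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each fixed batch runs until its longest request finishes."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))

def continuous_batching_steps(lengths, batch_size):
    """Refill a slot as soon as its request finishes."""
    queue = deque(lengths)
    running = []  # remaining decode steps for each occupied slot
    steps = 0
    while queue or running:
        while queue and len(running) < batch_size:
            running.append(queue.popleft())  # admit waiting requests into free slots
        running = [r - 1 for r in running if r > 1]  # decode one token; drop finished
        steps += 1
    return steps

lengths = [8, 1, 1, 1, 1, 1, 1, 8]  # decode steps needed per request, in arrival order
print(static_batching_steps(lengths, batch_size=2))      # 18
print(continuous_batching_steps(lengths, batch_size=2))  # 14
```

In the static case, the short requests in the first and last batches sit idle beside a long one; with continuous batching they flow through the free slot while the long request keeps decoding.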

Quantization

A further class of optimization is quantization.

Quantization refers to storing and computing with lower precision numbers so that models use less memory and often run faster. Instead of representing values with high precision floating point formats, we approximate them with lower precision formats such as 8-bit or 4-bit integers.

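The central mathematical idea is that a real value $x$ can be represented approximately by an integer $q$ using a scale $s$ and zero point $z$:

$$q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right)$$

and then reconstructed approximately as:

$$\hat{x} = s\,(q - z)$$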

The scale controls the tradeoff between precision and range. A smaller scale allows more precise representation over a smaller range of values, while a larger scale covers a wider range with less precision.
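A minimal NumPy sketch of the scale and zero-point scheme (asymmetric, per-tensor quantization), assuming the scale and zero point are chosen from the minimum and maximum of the tensor. Real systems often quantize per channel or per group and pack 4-bit values two to a byte; this sketch simply stores each value in its own `uint8`.

```python
import numpy as np

def quantize(x, num_bits=8):
    """Map floats to unsigned integers: q = clamp(round(x / s) + z, 0, 2**num_bits - 1)."""
    assert 1 <= num_bits <= 8, "this sketch stores values in uint8"
    qmin, qmax = 0, 2**num_bits - 1
    x_min = min(float(x.min()), 0.0)  # include 0 so it is exactly representable
    x_max = max(float(x.max()), 0.0)
    s = (x_max - x_min) / (qmax - qmin) or 1.0  # guard against an all-zero tensor
    z = int(round(qmin - x_min / s))
    q = np.clip(np.round(x / s) + z, qmin, qmax).astype(np.uint8)
    return q, s, z

def dequantize(q, s, z):
    """Approximate reconstruction: x_hat = s * (q - z)."""
    return s * (q.astype(np.float64) - z)

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
q, s, z = quantize(x)
x_hat = dequantize(q, s, z)
print(q.dtype, q.nbytes, x.nbytes)              # uint8 1000 8000
print(np.abs(x - x_hat).max() <= s / 2 + 1e-9)  # True
```

The reconstruction error stays within half a quantization step, and dropping to 4 bits makes the scale larger (coarser) for the same range, which is the precision/range tradeoff described above.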

Quantization is especially important during decoding because decoding is heavily constrained by memory bandwidth. If weights or KV cache values can be stored in fewer bits, less data must be moved through memory, often improving throughput significantly.

KV cache quantization is particularly valuable, since the KV cache can become one of the dominant consumers of VRAM for long-context inference.

Multi-Query Attention and Grouped-Query Attention

The structure of attention itself can also be modified to reduce decoding costs.

In standard multi-head attention, every attention head has its own query, key, and value projections. This gives the model flexibility, but it also means that a large amount of key and value data must be stored in the KV cache.

Multi-query attention (MQA) reduces this overhead by having all query heads share a single key head and a single value head. In other words, each head still has its own query projection, but the keys and values are shared.

This greatly reduces the size of the KV cache and the amount of memory bandwidth required during decoding.

However, sharing one set of keys and values across all heads may reduce model quality in some settings. This motivates grouped-query attention (GQA), which acts as a compromise.

In GQA, instead of every query head having its own key and value head, several query heads share one key-value head. This preserves more flexibility than MQA while still significantly reducing KV cache size and bandwidth requirements compared to full multi-head attention.
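MHA, GQA, and MQA differ only in the number of key/value heads, which the following minimal NumPy sketch of causal attention for a single sequence makes explicit. It assumes query heads are split into consecutive groups that each share one key/value head; the function name is hypothetical.

```python
import numpy as np

def grouped_query_attention(q, keys, values):
    """Causal attention for one sequence.
    q: (n_q_heads, T, d); keys, values: (n_kv_heads, T, d), n_q_heads % n_kv_heads == 0.
    n_kv_heads == n_q_heads is standard MHA; n_kv_heads == 1 is MQA; in between is GQA."""
    n_q, T, d = q.shape
    group = n_q // keys.shape[0]
    # Only the n_kv_heads-sized keys/values would be cached; each is broadcast
    # to the `group` consecutive query heads that share it.
    keys = np.repeat(keys, group, axis=0)
    values = np.repeat(values, group, axis=0)
    scores = q @ keys.transpose(0, 2, 1) / np.sqrt(d)   # (n_q, T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)  # positions after each query
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values                              # (n_q, T, d)

rng = np.random.default_rng(0)
n_q_heads, T, d = 8, 5, 16
q = rng.standard_normal((n_q_heads, T, d))
for n_kv_heads in (8, 2, 1):  # MHA, GQA, MQA
    k = rng.standard_normal((n_kv_heads, T, d))
    v = rng.standard_normal((n_kv_heads, T, d))
    out = grouped_query_attention(q, k, v)
    print(n_kv_heads, out.shape, k.size + v.size)  # KV cache entries for this layer
```

The output shape is the same in all three cases, while the number of cached key/value entries drops from 1280 (MHA) to 320 (GQA with 2 KV heads) to 160 (MQA).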

Because decoding is so sensitive to memory movement, these changes can have a major impact on inference performance.

Multi-Head Latent Attention

An even more aggressive optimization is multi-head latent attention (MLA).

Rather than caching the full keys and values for each token, MLA stores a smaller latent representation for each token and reconstructs the relevant key and value information when needed.

If the hidden representation of a token is $h_t$, a latent vector $c_t$ can be formed using a learned down-projection matrix $W^{DKV}$:

$$c_t = W^{DKV} h_t$$

Instead of storing the full $k_t$ and $v_t$, we cache this smaller latent vector. The idea is that this latent contains enough information to later reconstruct the parts of the key and value representations needed for attention, using learned up-projection matrices:

$$k_t = W^{UK} c_t, \qquad v_t = W^{UV} c_t$$

The intuition is simple: rather than caching the full attention state for every token, cache a compressed version of it.

This can reduce memory usage substantially, which is especially important for long sequence lengths where KV cache storage becomes a major bottleneck.
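A minimal sketch of this caching pattern, assuming random projection matrices and hypothetical sizes: only a small latent per token is cached, and per-head keys and values are rebuilt from it when needed. Published MLA designs differ in the details, e.g. handling rotary position embeddings separately and folding the up-projections into other matrices instead of materializing K and V; this sketch omits all of that.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads, d_head, d_latent = 64, 4, 16, 8  # hypothetical sizes

# Learned in a real model; random here.
W_dkv = rng.standard_normal((d_latent, d_model)) / np.sqrt(d_model)          # down-projection
W_uk = rng.standard_normal((n_heads * d_head, d_latent)) / np.sqrt(d_latent)  # latent -> keys
W_uv = rng.standard_normal((n_heads * d_head, d_latent)) / np.sqrt(d_latent)  # latent -> values

latent_cache = []  # one d_latent-sized vector per token instead of full K and V

def append_token(h):
    """Cache only the compressed latent c_t = W_dkv @ h_t."""
    latent_cache.append(W_dkv @ h)

def reconstruct_kv():
    """Rebuild per-head keys and values from the cached latents at attention time."""
    C = np.stack(latent_cache)                         # (T, d_latent)
    K = (C @ W_uk.T).reshape(len(C), n_heads, d_head)  # (T, n_heads, d_head)
    V = (C @ W_uv.T).reshape(len(C), n_heads, d_head)
    return K, V

for _ in range(10):
    append_token(rng.standard_normal(d_model))  # stand-in hidden states h_t

K, V = reconstruct_kv()
mla_floats = len(latent_cache) * d_latent              # cached by this sketch
mha_floats = 2 * len(latent_cache) * n_heads * d_head  # full K and V for the same tokens
print(K.shape, mla_floats, mha_floats)  # (10, 4, 16) 80 1280
```

With these made-up sizes the cache shrinks by 16x; the real saving depends on how small the latent can be made without hurting model quality.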

Speculative Decoding

One of the most elegant optimizations in inference is speculative decoding.

The basic idea is to use a smaller and faster draft model to propose several future tokens, and then use the larger main model to verify them. If the draft model’s guesses are correct, the large model can accept multiple tokens at once rather than generating them one by one.

At first this may seem surprising. Since the smaller model has a different distribution from the larger one, it might appear that this would change the output distribution. However, speculative decoding includes a statistical correction step that makes the final distribution identical to that of the larger model. In this sense, it is a lossless optimization.

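Suppose the draft model has distribution $q(x)$ and the large model has distribution $p(x)$. A proposed token $x$, sampled from $q$, is accepted with probability:

$$\min\left(1, \frac{p(x)}{q(x)}\right)$$

If a token is rejected, we must sample from a corrected residual distribution:

$$p'(x) = \frac{\max\left(0,\, p(x) - q(x)\right)}{\sum_{x'} \max\left(0,\, p(x') - q(x')\right)}$$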

This correction ensures that, despite using the smaller draft model to accelerate generation, the final sequence of tokens is distributed exactly as though we had sampled from the larger model alone.

The result is a speedup in decoding without sacrificing correctness.
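A sketch of the verification step in NumPy, assuming the draft distribution q and the large model's distribution p at each position are already available as probability arrays over the vocabulary. Each drafted token is accepted with probability min(1, p(x)/q(x)); on the first rejection we resample from the normalized residual max(0, p - q) and stop; if every draft is accepted, a bonus token is drawn from the large model. The usage example uses toy, context-independent distributions to check empirically that the first output token follows p.

```python
import numpy as np

def verify(draft_tokens, q_probs, p_probs, rng):
    """One round of speculative sampling verification.
    draft_tokens: k token ids sampled from the draft model.
    q_probs: (k, V) draft distributions the drafts were sampled from.
    p_probs: (k + 1, V) target distributions at those k positions plus the next one.
    Returns the accepted prefix plus exactly one token from the target side."""
    out = []
    for i, x in enumerate(draft_tokens):
        p, q = p_probs[i], q_probs[i]
        if rng.random() < min(1.0, p[x] / q[x]):  # accept with probability min(1, p/q)
            out.append(int(x))
            continue
        residual = np.maximum(p - q, 0.0)  # rejected: sample from normalized max(0, p - q)
        residual /= residual.sum()
        out.append(int(rng.choice(len(p), p=residual)))
        return out
    # Every draft was accepted: take a bonus token from the target's next distribution.
    out.append(int(rng.choice(p_probs.shape[1], p=p_probs[-1])))
    return out

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])  # toy target distribution (context-independent)
q = np.array([0.2, 0.2, 0.6])  # toy draft distribution
k = 3
first = []
for _ in range(20_000):
    drafts = rng.choice(3, size=k, p=q)
    tokens = verify(drafts, np.tile(q, (k, 1)), np.tile(p, (k + 1, 1)), rng)
    first.append(tokens[0])
freq = np.bincount(first, minlength=3) / len(first)
print(freq)  # close to [0.5, 0.3, 0.2], even though drafts came from q
```

Every call returns at least one token, and up to k + 1 when the draft agrees with the large model, which is where the speedup comes from.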

Mixture of Experts

Another important architectural optimization is mixture of experts (MoE).

In a standard transformer block, each token passes through a dense feed-forward network, often called an MLP. In an MoE model, this dense feed-forward network is replaced by multiple separate expert networks.

A small routing network examines each token representation and decides which experts should process it. Rather than activating all experts, only a sparse subset is used for each token. The outputs of the selected experts are then combined, usually through a weighted sum.

The intuition is that instead of one large general-purpose feed-forward network, we have many specialists plus a dispatcher. This allows the model to have much greater total capacity while only using a fraction of that capacity on each forward pass.

This gives a favorable tradeoff: the total parameter count grows, but compute per token does not grow proportionally.
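A minimal NumPy sketch of a top-k routed MoE layer, assuming a linear router, a softmax over only the selected experts' logits, and small two-layer ReLU MLPs as experts. These are common choices but not the only ones, and all names here are illustrative.

```python
import numpy as np

def moe_layer(x, W_router, experts, top_k=2):
    """Sparse MoE feed-forward layer.
    x: (T, d) token representations; W_router: (d, n_experts);
    experts: list of callables mapping (n, d) -> (n, d)."""
    logits = x @ W_router                              # (T, n_experts) router scores
    top = np.argsort(logits, axis=-1)[:, -top_k:]      # top-k expert ids per token
    top_logits = np.take_along_axis(logits, top, axis=-1)
    gates = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
    gates /= gates.sum(axis=-1, keepdims=True)         # softmax over selected experts only
    out = np.zeros_like(x)
    for e, expert in enumerate(experts):
        token_idx, slot = np.nonzero(top == e)         # tokens that picked expert e, and in which slot
        if token_idx.size:                             # unselected experts do no work
            out[token_idx] += gates[token_idx, slot][:, None] * expert(x[token_idx])
    return out

rng = np.random.default_rng(0)
d, n_experts, T = 16, 8, 6

def make_expert():
    W1 = rng.standard_normal((d, 4 * d)) / np.sqrt(d)
    W2 = rng.standard_normal((4 * d, d)) / np.sqrt(4 * d)
    return lambda h: np.maximum(h @ W1, 0.0) @ W2  # two-layer ReLU MLP

experts = [make_expert() for _ in range(n_experts)]
W_router = rng.standard_normal((d, n_experts))
x = rng.standard_normal((T, d))
y = moe_layer(x, W_router, experts, top_k=2)
print(y.shape)  # (6, 16)
```

Each token here touches only 2 of the 8 experts, so it uses roughly a quarter of the expert parameters per forward pass. Grouping tokens by expert, as the loop does, is also the shape of the work that expert parallelism distributes across GPUs.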

MoE also introduces systems challenges. Since different tokens may be routed to different experts, work must be balanced across GPUs efficiently. This is where expert parallelism becomes important, distributing experts across devices and coordinating communication so that no single expert becomes a bottleneck.

In this sense, MoE is not only a modeling idea, but also a systems and inference idea.

Disaggregated Prefill and Decode

This is the final optimization I’d like to discuss, as it concerns the inference system as a whole. As highlighted above, the prefill and decode stages have different resource demands: prefill is more compute-intensive, while decoding is more memory-bandwidth-intensive.

Disaggregation runs the two stages on separate GPU resources. This enables a variety of optimizations, including:

  • Independent scaling, allowing GPU resources for each stage to be allocated according to its needs
  • Preventing interference, ensuring that decoding is not blocked by prefill tasks
  • Hardware specialization, allowing prefill to run on GPUs with high compute and decoding to run on GPUs with higher memory bandwidth

There are drawbacks to this approach, however. The main one is the cost of transferring the KV cache from the GPUs running prefill to those running decoding. For this reason, specialized libraries like NVIDIA NIXL aim to reduce this overhead by speeding up data transfer between GPUs.

Conclusion

Inference economics matter. During training, your costs (time and money) are bounded: they stop when training ends. With inference, however, costs (latency and money) are incurred with every request. Efficient resource utilization is crucial for production ML systems at scale, and any optimization you make compounds into savings over time.

This post is just an introduction to the optimizations being made for inference. The field advances constantly, and there is much to explore beyond the scope of this blog. We hope that everything we’ve covered up to this point has helped you gain a deeper understanding of machine learning, and we aim to keep developing in this space.