AI Interview Series #4: Explain KV Caching


Question:

You are deploying an LLM in production. Generating the first few tokens is fast, but as the sequence grows, each additional token takes progressively longer to generate, even though the model architecture and hardware remain the same.

If computation is not the primary bottleneck, what inefficiencies are causing this slowdown, and how would you redesign the inference process to make token generation significantly faster?

What is KV Caching and how does it make token generation faster?

KV caching is an optimization technique used during text generation in large language models to avoid unnecessary computation. In autoregressive generation, the model generates text one token at a time, and at each step it attends to all previous tokens. However, the keys (K) and values (V) calculated for earlier tokens never change.

With KV caching, the model stores these keys and values the first time they are computed. When generating the next token, it reuses the cached K and V instead of recalculating them from scratch, and computes the query (Q), key, and value only for the new token. Attention is then calculated using the cached information and the new token.
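
The reuse described above can be sketched with a toy single-head attention layer in NumPy. This is only an illustration under simplifying assumptions: the projection matrices `W_q`, `W_k`, `W_v`, the toy sizes, and the random "token embeddings" are all hypothetical, and a real transformer keeps one such cache per layer and per head.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 8   # toy embedding size (assumption for illustration)
seq_len = 6   # number of decoding steps to simulate

# Hypothetical fixed projection matrices for a single attention head.
W_q = rng.normal(size=(d_model, d_model))
W_k = rng.normal(size=(d_model, d_model))
W_v = rng.normal(size=(d_model, d_model))

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def attend(q, K, V):
    """Attention output for one query vector over all keys and values."""
    scores = K @ q / np.sqrt(d_model)
    return softmax(scores) @ V

tokens = rng.normal(size=(seq_len, d_model))  # stand-in token embeddings

# Without a cache: recompute K and V for the whole prefix at every step.
outputs_no_cache = []
for t in range(1, seq_len + 1):
    prefix = tokens[:t]
    K, V = prefix @ W_k, prefix @ W_v
    q = tokens[t - 1] @ W_q
    outputs_no_cache.append(attend(q, K, V))

# With a cache: project only the new token and append it to the cache.
K_cache = np.empty((0, d_model))
V_cache = np.empty((0, d_model))
outputs_cache = []
for x in tokens:
    K_cache = np.vstack([K_cache, x @ W_k])
    V_cache = np.vstack([V_cache, x @ W_v])
    outputs_cache.append(attend(x @ W_q, K_cache, V_cache))

same = np.allclose(outputs_no_cache, outputs_cache)
print("outputs match:", same)
```

Both loops produce the same attention outputs; the cached version simply performs one key projection and one value projection per step instead of re-projecting the whole prefix.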

This reuse of previous calculations substantially reduces redundant work, making inference faster and more efficient, especially for long sequences, at the expense of additional memory to store the cache. Check out the Practice Notebook here.
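
To get a feel for the memory cost, here is a rough back-of-the-envelope estimate. The helper `kv_cache_bytes` is a hypothetical function written for this illustration. It assumes the cache holds one key and one value vector of width `hidden_size` per token per layer, stored in 32-bit floats, and it ignores any framework overhead. The figures plugged in (24 layers, hidden size 1024) correspond to gpt2-medium.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, batch_size=1, bytes_per_value=4):
    """Approximate KV cache size: 2 tensors (K and V) per layer,
    each of shape (batch_size, seq_len, hidden_size)."""
    return 2 * n_layers * batch_size * seq_len * hidden_size * bytes_per_value

# gpt2-medium: 24 layers, hidden size 1024; 1000 cached tokens in float32.
size = kv_cache_bytes(n_layers=24, hidden_size=1024, seq_len=1000)
print(f"{size / 2**20:.1f} MiB")  # 187.5 MiB
```

The estimate grows linearly with both sequence length and batch size, which is why the cache becomes a real consideration for long contexts and large batches.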

Evaluating the impact of KV caching on inference speed

In this code, we benchmark the impact of KV caching during autoregressive text generation. We run the same prompt through the model five times with KV caching enabled and five times without it, and report the mean and standard deviation of the generation time. Keeping the model, prompt, and generation length constant isolates how reusing cached keys and values reduces redundant attention computation and speeds up inference.

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "gpt2-medium"  
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

prompt = "Explain KV caching in transformers."

inputs = tokenizer(prompt, return_tensors="pt").to(device)

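for use_cache in (True, False):
    times = []  # a list (not a tuple), so each run's duration can be appended
    for _ in range(5):  # repeat to average out run-to-run noise
        start = time.time()
        model.generate(
            **inputs,
            use_cache=use_cache,
            max_new_tokens=1000
        )
        times.append(time.time() - start)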

    print(
        f"{'with' if use_cache else 'without'} KV caching: "
        f"{round(np.mean(times), 3)} ± {round(np.std(times), 3)} seconds"
    )

The results clearly demonstrate the impact of KV caching on inference speed. With KV caching enabled, generating 1000 tokens takes about 21.7 seconds, while disabling it increases the generation time to over 107 seconds, roughly a 5× slowdown. This sharp difference arises because, without KV caching, the model reprocesses every previously generated token at each step, recomputing their keys and values, so the total computation grows quadratically with sequence length.

With KV caching, previous keys and values are reused, eliminating redundant work and keeping the generation time approximately linear as the sequence grows. This experiment highlights why KV caching is essential for efficient, real-world deployment of autoregressive language models.
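
The quadratic-versus-linear contrast can be made concrete by counting how many per-token key/value projections each strategy performs. The function `kv_projections` below is a hypothetical helper for this illustration. It counts projections for a single layer, ignores the prompt, and says nothing about wall-clock time, since hardware processes the tokens of one forward pass largely in parallel. That is one plausible reason the measured slowdown is far smaller than the ratio of these counts.

```python
def kv_projections(n_tokens, use_cache):
    """Number of per-token K/V projections needed to generate n_tokens."""
    if use_cache:
        # Each token's key and value are computed once and then reused.
        return n_tokens
    # Step t re-projects all t tokens seen so far: 1 + 2 + ... + n.
    return sum(range(1, n_tokens + 1))

n = 1000
print("with cache:   ", kv_projections(n, use_cache=True))   # 1000
print("without cache:", kv_projections(n, use_cache=False))  # 500500
```

Without the cache the count follows n(n + 1) / 2, which is quadratic in the number of generated tokens; with the cache it is exactly n.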


I am a Civil Engineering graduate (2022) from Jamia Millia Islamia, New Delhi, and I have a keen interest in Data Science, especially Neural Networks and their application in various fields.
