KV Cache

During generation in an autoregressive model, the Masked Self-Attention module computes a query only for the newly generated token, but it needs the keys and values of all previous tokens. Since a past token's key and value never change, they can be computed once and reused at every future step.

That is what the KV cache does: it stores the key and value vectors of past tokens, growing by one entry with every generated token.
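A minimal sketch of the idea, assuming made-up weights and a toy model dimension (all names here are illustrative, not from any particular library):

```python
# Minimal KV caching in masked self-attention (NumPy; toy shapes).
import numpy as np

d_model = 64
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))

k_cache = []  # one (d_model,) key per past token
v_cache = []  # one (d_model,) value per past token

def decode_step(x):
    """Attend the newest token x over all cached tokens, then cache its K/V."""
    q = x @ W_q                       # query only for the new token
    k_cache.append(x @ W_k)           # K/V computed once, reused forever
    v_cache.append(x @ W_v)
    K = np.stack(k_cache)             # (seq_len, d_model)
    V = np.stack(v_cache)
    scores = K @ q / np.sqrt(d_model)           # (seq_len,)
    weights = np.exp(scores - scores.max())     # softmax over past tokens
    weights /= weights.sum()
    return weights @ V                # attention output for the new token

for _ in range(5):                    # each step reuses all earlier K/V
    out = decode_step(rng.standard_normal(d_model))
```

No causal mask is needed here because the cache only ever contains past (and current) tokens.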

Why can the KV cache become large?

  1. Number of tokens: each generated token adds another key-value pair to the cache
  2. Batch size: each sequence in the batch keeps its own cache, so memory grows linearly with batch size
  3. Number of beams: each beam keeps its own KV cache during decoding; the prompt's KV cache can be shared across beams, but the decoded entries cannot
  4. Number of heads / head dimension: the cache grows linearly with the number of KV heads, since each head stores its own keys and values, and linearly with the head dimension (the estimate after this list puts these factors together)
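Putting the factors together (the cache also exists per layer): total bytes ≈ 2 (K and V) × layers × batch × seq_len × kv_heads × head_dim × bytes per element, multiplied again by the number of beams under beam search. A back-of-envelope estimate, assuming Llama-2-7B-like numbers purely for illustration:

```python
# Back-of-envelope KV cache size; the model numbers below are
# Llama-2-7B-like and used only for illustration.
layers, kv_heads, head_dim = 32, 32, 128
batch, seq_len, bytes_fp16 = 8, 4096, 2

kv_bytes = 2 * layers * batch * seq_len * kv_heads * head_dim * bytes_fp16
print(f"{kv_bytes / 2**30:.1f} GiB")  # 16.0 GiB for this configuration
```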

Issues

  1. Low utilization (often 20-40%): cache memory is pre-allocated up front for the maximum sequence length, but most requests never reach that length (solved by the paged KV cache)
  2. Contiguous memory: a sequence's cache must occupy one contiguous allocation, which fragments memory (also addressed by the paged KV cache; a sketch of the idea follows this list)
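A toy sketch of the paged idea (popularized by vLLM's PagedAttention); the block size and pool size below are arbitrary assumptions. The cache is carved into fixed-size physical blocks handed out on demand, and a per-sequence block table maps logical token positions to physical blocks, so neither pre-allocation at max length nor contiguous memory is required:

```python
# Toy paged KV-cache allocation: blocks are reserved on demand, not up front.
BLOCK = 16                       # tokens per physical block (assumption)
free_blocks = list(range(1024))  # pool of physical block ids (assumption)
block_tables = {}                # seq_id -> list of physical block ids

def append_token(seq_id, num_tokens_so_far):
    """Return (block, offset) where this token's K/V will be written."""
    table = block_tables.setdefault(seq_id, [])
    if num_tokens_so_far % BLOCK == 0:    # current block full (or first token)
        table.append(free_blocks.pop())   # any free block works: no contiguity needed
    return table[num_tokens_so_far // BLOCK], num_tokens_so_far % BLOCK

# Sequences grow block-by-block, so memory tracks actual length
# instead of the maximum sequence length.
for t in range(40):
    append_token("req-0", t)
```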

How to optimize?

  1. Sliding-window KV cache: keep only the last W tokens' keys and values
  2. Paged KV cache: allocate the cache in fixed-size blocks on demand
  3. Grouped-Query Attention (GQA): groups of query heads share one KV head
  4. Multi-Query Attention (MQA): all query heads share a single KV head
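GQA and MQA shrink the cache at the source by storing K/V for fewer heads than there are query heads. A sketch in PyTorch with illustrative shapes (num_kv_heads = 8 is an assumption; MQA is the num_kv_heads = 1 case):

```python
# GQA/MQA sketch: K/V are stored for only num_kv_heads heads and
# shared across groups of query heads. Shapes are illustrative.
import torch

batch, seq_len, head_dim = 1, 128, 64
num_q_heads, num_kv_heads = 32, 8          # MHA: 32, GQA: 8, MQA: 1

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)  # cached: 4x smaller than MHA
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Each group of num_q_heads // num_kv_heads query heads reuses one KV head.
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```

A sliding-window cache is simpler still: truncate the cached tensors to the last W positions, e.g. `k = k[:, :, -W:, :]`, trading away attention to older tokens for a bounded cache.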
