Grouped-Query Attention

In Multi-Head Attention, every head (e.g. 8 heads) has its own query, key and value projection. So with 8 heads, there are 8 queries, 8 keys and 8 values.

In Multi-Query Attention, each head still has its own query, but all heads share a single key and a single value.
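The difference between the two schemes comes down to which key/value head each query head reads from. The helper below is hypothetical (not from the sources) and assumes contiguous grouping, i.e. query head `h` uses key/value head `h // group_size`:

```python
def kv_head_for_query(q_head: int, n_q_heads: int, n_kv_heads: int) -> int:
    """Return the index of the key/value head that query head `q_head` reads from."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be divisible by n_kv_heads")
    group_size = n_q_heads // n_kv_heads  # number of query heads per key/value head
    return q_head // group_size


n_q_heads = 8
# Multi-head attention: one key/value head per query head.
mha = [kv_head_for_query(h, n_q_heads, n_kv_heads=8) for h in range(n_q_heads)]
# Multi-query attention: a single key/value head shared by all query heads.
mqa = [kv_head_for_query(h, n_q_heads, n_kv_heads=1) for h in range(n_q_heads)]
print(mha)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(mqa)  # [0, 0, 0, 0, 0, 0, 0, 0]
```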

Multi-head attention generally gives the best quality, but it needs a lot of VRAM (largely for the key/value cache during inference) and makes inference slower.
Multi-query attention uses much less VRAM and is faster at inference, but at the cost of some quality degradation.
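A major part of that VRAM cost is the key/value cache kept during autoregressive decoding, and its size scales linearly with the number of key/value heads. The back-of-the-envelope sketch below uses a hypothetical helper and made-up model sizes (32 layers, 32 heads of dimension 128, 4096 tokens, fp16), none of which come from the sources:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """Memory needed to cache keys and values for a batch of sequences."""
    # Factor 2: one cached tensor for keys and one for values, in every layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Hypothetical config: 32 layers, head dimension 128, 4096 tokens, fp16 (2 bytes per value).
config = dict(n_layers=32, head_dim=128, seq_len=4096)
mha_bytes = kv_cache_bytes(n_kv_heads=32, **config)  # one K/V head per query head
mqa_bytes = kv_cache_bytes(n_kv_heads=1, **config)   # one K/V head shared by all 32 heads
print(f"MHA KV cache: {mha_bytes / 2**30:.2f} GiB")  # 2.00 GiB
print(f"MQA KV cache: {mqa_bytes / 2**20:.2f} MiB")  # 64.00 MiB
```

The model weights still take the same memory in both cases, so the saving applies to the cache, which grows with sequence length and batch size.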

Grouped-query attention tries to get the best of both worlds: the query heads are split into groups, and all heads in a group share one key and one value. For example, with 8 query heads and a group size of 2, there will be 8 queries, 4 keys and 4 values, and every 2 query heads share the same key and value.
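The example above (8 query heads, 4 key/value heads) can be sketched as a minimal forward pass in pure Python. This is an illustration under stated assumptions, not the referenced implementation: random tensors stand in for already-projected queries, keys and values, query head `h` uses key/value head `h // group_size`, and there is no causal mask or output projection.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(q, k, v):
    """q: [n_q_heads][seq_len][head_dim]; k, v: [n_kv_heads][seq_len][head_dim].

    Returns [n_q_heads][seq_len][head_dim]. No causal mask, for brevity.
    """
    n_q_heads, n_kv_heads = len(q), len(k)
    assert n_q_heads % n_kv_heads == 0
    group_size = n_q_heads // n_kv_heads
    head_dim = len(q[0][0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for h in range(n_q_heads):
        kv = h // group_size  # key/value head shared by this query head's group
        head_out = []
        for q_vec in q[h]:
            # Scaled dot-product scores against every key of the shared head.
            scores = [scale * sum(a * b for a, b in zip(q_vec, k_vec)) for k_vec in k[kv]]
            weights = softmax(scores)
            # Weighted sum of the shared head's values.
            head_out.append([sum(w * v_vec[d] for w, v_vec in zip(weights, v[kv]))
                             for d in range(head_dim)])
        out.append(head_out)
    return out


def rand_tensor(*shape):
    """Nested lists of standard-normal floats with the given shape."""
    if len(shape) == 1:
        return [random.gauss(0, 1) for _ in range(shape[0])]
    return [rand_tensor(*shape[1:]) for _ in range(shape[0])]


random.seed(0)
seq_len, head_dim = 4, 16
q = rand_tensor(8, seq_len, head_dim)  # 8 query heads
k = rand_tensor(4, seq_len, head_dim)  # 4 key heads
v = rand_tensor(4, seq_len, head_dim)  # 4 value heads
out = grouped_query_attention(q, k, v)
print(len(out), len(out[0]), len(out[0][0]))  # 8 4 16
```

Passing 8 key/value heads makes the same function behave as multi-head attention, and passing 1 makes it multi-query attention, so both are special cases of this grouping.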

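Compared to multi-head attention, this (1) lowers the total number of parameters, because there are fewer key and value projections, and (2) reduces memory usage and memory bandwidth for keys and values, because fewer of them have to be cached and read back at every decoding step.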
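To put numbers on point (1) and the cache side of point (2), here is a small sketch with hypothetical sizes (model dimension 512, 8 query heads of dimension 64, biases ignored). The helper names are illustrative, not from the sources:

```python
def kv_projection_params(d_model, n_kv_heads, head_dim):
    # W_k and W_v each map d_model -> n_kv_heads * head_dim (no bias).
    return 2 * d_model * n_kv_heads * head_dim


def kv_cache_values_per_token(n_kv_heads, head_dim):
    # Per layer, every token stores one key and one value vector per K/V head.
    return 2 * n_kv_heads * head_dim


# Hypothetical sizes: d_model = 512, 8 query heads of dimension 64.
d_model, head_dim = 512, 64
for name, n_kv in [("MHA", 8), ("GQA (group size 2)", 4), ("MQA", 1)]:
    params = kv_projection_params(d_model, n_kv, head_dim)
    cached = kv_cache_values_per_token(n_kv, head_dim)
    print(f"{name:<20} K/V params={params:,}  cached values/token/layer={cached}")
```

With these sizes the K/V projection parameters drop from 524,288 (MHA) to 262,144 (GQA) to 65,536 (MQA), and the per-token cache shrinks by the same factors; the query and output projections are unchanged.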


References

  1. https://sebastianraschka.com/llms-from-scratch/ch04/04_gqa/
  2. https://www.youtube.com/watch?v=rCJfyw7XBx8&list=PLfSv7CK7EjD2fC9S6MAKRNDgTSCYgdGgz&index=3

Related Notes