Grouped-Query Attention
In Multi-Head Attention, every head (e.g. 8 of them) has its own query, key and value. So for 8-head attention, there are 8 queries, 8 keys and 8 values.
In Multi-Query Attention, every head keeps its own query, but all heads share a single key and value.
While multi-head gives the best quality, it takes the most VRAM and inference time, because the keys and values of every head have to be stored and read back during generation.
Multi-query uses far less VRAM and gives faster inference, but at the cost of some quality degradation.
Grouped-query attention, on the other hand, aims for the best of both worlds: the heads are split into groups, and the heads within a group share one key and value. For example, if the group size is 2 and the number of heads is 8, there will be 8 queries, 4 keys and 4 values. Every 2 queries share the same key and value.
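The sharing pattern in the example above can be sketched as a small mapping from key/value heads to query heads. This is only an illustration, not any particular library's implementation: the function name `head_groups` is hypothetical, and "group size" is taken, as in the text, to mean the number of query heads that share one key/value head.

```python
def head_groups(num_query_heads: int, group_size: int) -> dict[int, list[int]]:
    """Map each key/value head index to the query heads that share it.

    Assumption: `group_size` is the number of query heads per key/value head,
    and consecutive query heads are grouped together.
    """
    if num_query_heads % group_size != 0:
        raise ValueError("num_query_heads must be divisible by group_size")

    num_kv_heads = num_query_heads // group_size
    groups: dict[int, list[int]] = {kv: [] for kv in range(num_kv_heads)}
    for query_head in range(num_query_heads):
        # Integer division assigns consecutive query heads to the same KV head.
        groups[query_head // group_size].append(query_head)
    return groups


# The example from the text: 8 query heads, group size 2 -> 4 key/value heads.
gqa = head_groups(num_query_heads=8, group_size=2)
for kv_head, query_heads in gqa.items():
    print(f"key/value head {kv_head} is shared by query heads {query_heads}")
```

With this convention, a group size of 1 reproduces multi-head attention (one key/value per query head), and a group size equal to the number of heads reproduces multi-query attention (one key/value for all heads).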
Compared with multi-head attention, this (1) lowers the total number of parameters and (2) reduces the memory bandwidth used for keys and values.
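Both effects can be estimated with simple arithmetic. The sketch below uses assumed sizes that are not from the text (32 layers, head dimension 128, 16-bit values) together with the 8 heads from the example; the function names are hypothetical, and bias terms are ignored.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int,
                             head_dim: int, bytes_per_value: int = 2) -> int:
    """Bytes of keys and values stored per generated token.

    The leading factor 2 accounts for storing both a key and a value.
    """
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value


def kv_projection_params(d_model: int, num_kv_heads: int, head_dim: int) -> int:
    """Weights in the key and value projections of one layer (no biases)."""
    return 2 * d_model * num_kv_heads * head_dim


# Assumed, illustrative model sizes (not from the text).
NUM_LAYERS = 32
HEAD_DIM = 128
NUM_QUERY_HEADS = 8
D_MODEL = NUM_QUERY_HEADS * HEAD_DIM

variants = {
    "multi-head": 8,      # one key/value head per query head
    "grouped-query": 4,   # group size 2, as in the example
    "multi-query": 1,     # one key/value head for all query heads
}

cache = {name: kv_cache_bytes_per_token(NUM_LAYERS, kv, HEAD_DIM)
         for name, kv in variants.items()}
params = {name: kv_projection_params(D_MODEL, kv, HEAD_DIM)
          for name, kv in variants.items()}

for name in variants:
    print(f"{name}: {cache[name] // 1024} KiB of cache per token, "
          f"{params[name]} key/value projection weights per layer")
```

Under these assumed sizes, the cache per token comes to 128 KiB for multi-head, 64 KiB for grouped-query and 16 KiB for multi-query. Since the cached keys and values are re-read at every generation step, a smaller cache also means less memory traffic.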
