Group Normalization
- Proposed in the Group Normalization paper (Wu & He, 2018)
Group Normalization is a middle-ground approach between Batch Normalization and Layer Normalization: for each sample, the channels are divided into groups, and the mean and variance are computed within each group to normalize it.
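Concretely, each value $x$ in group $g$ is normalized with that group's mean $\mu_g$ and variance $\sigma_g^2$, where $\epsilon$ is a small constant for numerical stability:

$$\hat{x} = \frac{x - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}}$$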
Code
import torch

BATCH_SIZE = 2
HIDDEN_DIM = 100
NUM_OF_GROUPS = 4
EPS = 1e-5  # small constant for numerical stability

features = torch.randn((BATCH_SIZE, HIDDEN_DIM)) # [BATCH_SIZE, HIDDEN_DIM]
# Split the hidden dimension into NUM_OF_GROUPS groups of HIDDEN_DIM // NUM_OF_GROUPS channels.
features_grp = features.view(BATCH_SIZE, NUM_OF_GROUPS, -1) # [BATCH_SIZE, NUM_OF_GROUPS, 25 (100//4)]
# Statistics are computed per group (over the last dim), not over the full feature vector.
mean = features_grp.mean(dim=-1, keepdim=True) # [BATCH_SIZE, NUM_OF_GROUPS, 1]
var = features_grp.var(dim=-1, unbiased=False, keepdim=True) # [BATCH_SIZE, NUM_OF_GROUPS, 1]
# Normalize each group by its own mean and standard deviation.
features_grp = (features_grp - mean) / torch.sqrt(var + EPS) # [BATCH_SIZE, NUM_OF_GROUPS, 25 (100//4)]
features_norm = features_grp.view(BATCH_SIZE, HIDDEN_DIM) # [BATCH_SIZE, HIDDEN_DIM]
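As a sanity check, the manual computation above should match PyTorch's built-in torch.nn.GroupNorm (with affine=False to disable the learnable scale/shift and the same eps), up to floating-point tolerance. A minimal sketch, reusing the names from the snippet above:

gn = torch.nn.GroupNorm(num_groups=NUM_OF_GROUPS, num_channels=HIDDEN_DIM, eps=EPS, affine=False)
print(torch.allclose(gn(features), features_norm, atol=1e-6))  # expected: True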