Group Normalization

Group normalization is a middle ground between Batch Normalization and Layer Normalization: the features of each sample are divided into groups, and the mean and variance are computed within each group to normalize it.

Code

import torch

BATCH_SIZE = 2
HIDDEN_DIM = 100
NUM_OF_GROUPS = 4
EPS = 1e-5
features = torch.randn((BATCH_SIZE, HIDDEN_DIM)) # [BATCH_SIZE, HIDDEN_DIM]
features_grp = features.view(BATCH_SIZE, NUM_OF_GROUPS, -1) # [BATCH_SIZE, NUM_OF_GROUPS, 25 (100//4)]
mean = features_grp.mean(dim=-1, keepdim=True) # [BATCH_SIZE, NUM_OF_GROUPS, 1]
var = features_grp.var(dim=-1, unbiased=False, keepdim=True) # biased variance, as in normalization layers
features_grp = (features_grp - mean) / torch.sqrt(var + EPS) # [BATCH_SIZE, NUM_OF_GROUPS, 25 (100//4)]
features = features_grp.view(BATCH_SIZE, HIDDEN_DIM) # [BATCH_SIZE, HIDDEN_DIM]
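As a sanity check, the manual computation above can be compared against PyTorch's built-in `torch.nn.GroupNorm`. This is a minimal sketch: `affine=False` is passed so the layer has no learnable scale/shift, making its output directly comparable to the manual version.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
BATCH_SIZE, HIDDEN_DIM, NUM_OF_GROUPS = 2, 100, 4
features = torch.randn((BATCH_SIZE, HIDDEN_DIM))

# Manual group normalization (same steps as above)
grp = features.view(BATCH_SIZE, NUM_OF_GROUPS, -1)
mean = grp.mean(dim=-1, keepdim=True)
var = grp.var(dim=-1, unbiased=False, keepdim=True)
manual = ((grp - mean) / torch.sqrt(var + 1e-5)).view(BATCH_SIZE, HIDDEN_DIM)

# Built-in layer; affine=False disables the learnable gamma/beta
gn = nn.GroupNorm(num_groups=NUM_OF_GROUPS, num_channels=HIDDEN_DIM, affine=False)
builtin = gn(features)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```

Note that `nn.GroupNorm` uses the biased variance and an epsilon of `1e-5` by default, which is why the manual version above does the same.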

References


Related Notes