# Adapted from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/precise_bn.py # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
import logging
import time

import mmcv
import torch
from mmcv.parallel import MMDistributedDataParallel
from mmcv.runner import Hook
from mmcv.utils import print_log
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader


def is_parallel_module(module):
"""Check if a module is a parallel module.
The following 3 modules (and their subclasses) are regarded as parallel
modules: DataParallel, DistributedDataParallel,
MMDistributedDataParallel (the deprecated version).
Args:
module (nn.Module): The module to be checked.
Returns:
bool: True if the input module is a parallel module.
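    Example:
        >>> # Illustrative sketch: a plain module is not parallel, while
        >>> # wrapping it in ``DataParallel`` makes the check return True.
        >>> import torch.nn as nn
        >>> is_parallel_module(nn.Linear(2, 2))
        False
        >>> is_parallel_module(DataParallel(nn.Linear(2, 2)))
        True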
"""
parallels = (DataParallel, DistributedDataParallel,
MMDistributedDataParallel)
return bool(isinstance(module, parallels))
@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters=200, logger=None):
"""Recompute and update the batch norm stats to make them more precise.
During
training both BN stats and the weight are changing after every iteration,
so the running average can not precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights, to make
the running average more precise. Specifically, it computes the true
average of per-batch mean/variance instead of the running average.
Args:
model (nn.Module): The model whose bn stats will be recomputed.
data_loader (iterator): The DataLoader iterator.
num_iters (int): number of iterations to compute the stats.
logger (:obj:`logging.Logger` | None): Logger for logging.
Default: None.
"""
    model.train()

    assert len(data_loader) >= num_iters, (
        f'length of dataloader {len(data_loader)} must be at least '
        f'the iteration number {num_iters}')
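    # If the model is wrapped, keep running the forward pass through the
    # wrapper (so data scattering/gathering still happens), but search for
    # BN layers in the underlying module.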
    if is_parallel_module(model):
        parallel_module = model
        model = model.module
    else:
        parallel_module = model
    # Finds all the bn layers with training=True.
    bn_layers = [
        m for m in model.modules() if m.training and isinstance(m, _BatchNorm)
    ]
    if len(bn_layers) == 0:
        print_log('No BN found in model', logger=logger, level=logging.WARNING)
        return
    print_log(f'{len(bn_layers)} BN found', logger=logger)
    # Finds all the other norm layers with training=True.
    for m in model.modules():
        if m.training and isinstance(m, (_InstanceNorm, GroupNorm)):
            print_log(
                'IN/GN stats will be updated like training.',
                logger=logger,
                level=logging.WARNING)
    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled.
    # bn.running_mean = (1 - momentum) * bn.running_mean + momentum *
    #   batch_mean
    # Setting the momentum to 1.0 to compute the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]  # pyre-ignore
    for bn in bn_layers:
        bn.momentum = 1.0
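    # With momentum = 1.0 the update above collapses to
    #   bn.running_mean = batch_mean
    #   bn.running_var = batch_var
    # i.e. after each forward pass the buffers hold exactly the statistics
    # of the latest batch, which are then averaged by hand below.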
    # Note that running_var actually means "running average of variance".
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    finish_before_loader = False
    prog_bar = mmcv.ProgressBar(len(data_loader))
    for ind, data in enumerate(data_loader):
        with torch.no_grad():
            parallel_module(**data, return_loss=False)
        prog_bar.update()
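        # Incremental average: after the (ind + 1)-th batch, running_mean[i]
        # equals the arithmetic mean of all per-batch means seen so far
        # (likewise running_var), via avg_n = avg_{n-1} + (x_n - avg_{n-1})/n.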
        for i, bn in enumerate(bn_layers):
            # Accumulates the bn stats.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            # running_var here is actually the running average of the
            # per-batch variance.
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
        if (ind + 1) >= num_iters:
            finish_before_loader = True
            break
    assert finish_before_loader, 'Dataloader stopped before ' \
        f'iteration {num_iters}'
    for i, bn in enumerate(bn_layers):
        # Sets the precise bn stats.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]


class PreciseBNHook(Hook):
"""Precise BN hook.
Attributes:
dataloader (DataLoader): A PyTorch dataloader.
num_iters (int): Number of iterations to update the bn stats.
Default: 200.
interval (int): Perform precise bn interval (by epochs). Default: 1.
"""
def __init__(self, dataloader, num_iters=200, interval=1):
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got'
f' {type(dataloader)}')
self.dataloader = dataloader
self.interval = interval
self.num_iters = num_iters
def after_train_epoch(self, runner):
if self.every_n_epochs(runner, self.interval):
# sleep to avoid possible deadlock
time.sleep(2.)
print_log(
f'Running Precise BN for {self.num_iters} iterations',
logger=runner.logger)
update_bn_stats(
runner.model,
self.dataloader,
self.num_iters,
logger=runner.logger)
print_log('BN stats updated', logger=runner.logger)
# sleep to avoid possible deadlock
time.sleep(2.)
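# Usage sketch (hedged): how this hook would typically be registered with an
# mmcv runner. ``runner`` and ``precise_bn_loader`` are hypothetical names,
# assumed to be created by the surrounding training script.
#
#   precise_bn_hook = PreciseBNHook(precise_bn_loader, num_iters=200)
#   runner.register_hook(precise_bn_hook)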