
--- /dev/null
+++ b/mmaction/utils/precise_bn.py
@@ -0,0 +1,176 @@
+# Adapted from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/precise_bn.py  # noqa: E501
+# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0  # noqa: E501
+
+import logging
+import time
+
+import mmcv
+import torch
+from mmcv.parallel import MMDistributedDataParallel
+from mmcv.runner import Hook
+from mmcv.utils import print_log
+from torch.nn import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+from torch.utils.data import DataLoader
+
+
+def is_parallel_module(module):
+    """Check if a module is a parallel module.
+
+    The following 3 modules (and their subclasses) are regarded as parallel
+    modules: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version).
+
+    Args:
+        module (nn.Module): The module to be checked.
+    Returns:
+        bool: True if the input module is a parallel module.
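+
+    Example:
+        >>> # A minimal doctest sketch; it uses only torch, which is
+        >>> # already imported at module level above.
+        >>> is_parallel_module(torch.nn.Linear(1, 1))
+        False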
+    """
+    parallels = (DataParallel, DistributedDataParallel,
+                 MMDistributedDataParallel)
+    return isinstance(module, parallels)
+
+
+@torch.no_grad()
+def update_bn_stats(model, data_loader, num_iters=200, logger=None):
+    """Recompute and update the batch norm stats to make them more precise.
+
+    During
+    training both BN stats and the weight are changing after every iteration,
+    so the running average can not precisely reflect the actual stats of the
+    current model.
+    In this function, the BN stats are recomputed with fixed weights, to make
+    the running average more precise. Specifically, it computes the true
+    average of per-batch mean/variance instead of the running average.
+
+    Args:
+        model (nn.Module): The model whose bn stats will be recomputed.
+        data_loader (iterator): The DataLoader iterator.
+        num_iters (int): Number of iterations to compute the stats.
+            Default: 200.
+        logger (:obj:`logging.Logger` | None): Logger for logging.
+            Default: None.
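+
+    Example:
+        >>> # An illustrative usage sketch; ``model`` and ``data_loader``
+        >>> # stand for a user-provided model and DataLoader, respectively.
+        >>> update_bn_stats(model, data_loader, num_iters=200)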
+    """
+
+    model.train()
+
+    assert len(data_loader) >= num_iters, (
+        f'length of dataloader {len(data_loader)} must be no less than '
+        f'the iteration number {num_iters}')
+
+    if is_parallel_module(model):
+        parallel_module = model
+        model = model.module
+    else:
+        parallel_module = model
+    # Finds all the bn layers with training=True.
+    bn_layers = [
+        m for m in model.modules() if m.training and isinstance(m, _BatchNorm)
+    ]
+
+    if len(bn_layers) == 0:
+        print_log('No BN found in model', logger=logger, level=logging.WARNING)
+        return
+    print_log(f'{len(bn_layers)} BN found', logger=logger)
+
+    # Finds all the other norm layers with training=True.
+    for m in model.modules():
+        if m.training and isinstance(m, (_InstanceNorm, GroupNorm)):
+            print_log(
+                'IN/GN stats will be updated like training.',
+                logger=logger,
+                level=logging.WARNING)
+
+    # In order to make the running stats only reflect the current batch, the
+    # momentum is disabled:
+    # bn.running_mean = (1 - momentum) * bn.running_mean
+    #                   + momentum * batch_mean
+    # Setting the momentum to 1.0 computes the stats without momentum.
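+    # With momentum = 1.0, the update above reduces to
+    # bn.running_mean = batch_mean, so after each forward pass the running
+    # stats hold exactly the stats of the last batch.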
+    momentum_actual = [bn.momentum for bn in bn_layers]  # pyre-ignore
+    for bn in bn_layers:
+        bn.momentum = 1.0
+
+    # Note that running_var actually means "running average of variance"
+    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
+    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
+
+    reached_num_iters = False
+    prog_bar = mmcv.ProgressBar(num_iters)
+    for ind, data in enumerate(data_loader):
+        parallel_module(**data, return_loss=False)
+        prog_bar.update()
+        for i, bn in enumerate(bn_layers):
+            # Accumulates the per-batch means with an incremental average.
+            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
+            # running_var is likewise an incremental average of the
+            # per-batch variances, not the variance over all the samples.
+            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
+
+        if (ind + 1) >= num_iters:
+            reached_num_iters = True
+            break
+    assert reached_num_iters, 'Dataloader stopped before ' \
+                              f'iteration {num_iters}'
+
+    for i, bn in enumerate(bn_layers):
+        # Sets the precise bn stats and restores the original momentum.
+        bn.running_mean = running_mean[i]
+        bn.running_var = running_var[i]
+        bn.momentum = momentum_actual[i]
+
+
+class PreciseBNHook(Hook):
+    """Precise BN hook.
+
+    Attributes:
+        dataloader (DataLoader): A PyTorch dataloader.
+        num_iters (int): Number of iterations to update the bn stats.
+            Default: 200.
+        interval (int): Interval (in epochs) at which precise bn is
+            performed. Default: 1.
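+
+    Example:
+        >>> # An illustrative sketch; ``val_loader`` stands for an assumed
+        >>> # user-provided DataLoader, and ``runner`` for an mmcv runner.
+        >>> hook = PreciseBNHook(val_loader, num_iters=200, interval=1)
+        >>> runner.register_hook(hook)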
+    """
+
+    def __init__(self, dataloader, num_iters=200, interval=1):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError('dataloader must be a PyTorch DataLoader, but got'
+                            f' {type(dataloader)}')
+        self.dataloader = dataloader
+        self.interval = interval
+        self.num_iters = num_iters
+
+    def after_train_epoch(self, runner):
+        if self.every_n_epochs(runner, self.interval):
+            # sleep to avoid possible deadlock
+            time.sleep(2.)
+            print_log(
+                f'Running Precise BN for {self.num_iters} iterations',
+                logger=runner.logger)
+            update_bn_stats(
+                runner.model,
+                self.dataloader,
+                self.num_iters,
+                logger=runner.logger)
+            print_log('BN stats updated', logger=runner.logger)
+            # sleep to avoid possible deadlock
+            time.sleep(2.)