Diff of /mmseg/utils/optimizer.py [000000] .. [4e96d3]

Switch to side-by-side view

--- a
+++ b/mmseg/utils/optimizer.py
@@ -0,0 +1,258 @@
+# gradient cumulative optimizer hooks are from
+# https://github.com/open-mmlab/mmcv/pull/1221
+# TODO use mmcv if the PR is merged
+from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
+                         allreduce_grads)
+from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+
+try:
+    import apex
+except ImportError:
+    apex = None
+
+
+@HOOKS.register_module()
+class ApexOptimizerHook(OptimizerHook):
+    """Optimizer hook for distributed training."""
+
+    def __init__(self,
+                 update_interval=1,
+                 grad_clip=None,
+                 coalesce=True,
+                 bucket_size_mb=-1,
+                 use_fp16=False):
+        self.grad_clip = grad_clip
+        self.coalesce = coalesce
+        self.bucket_size_mb = bucket_size_mb
+        self.update_interval = update_interval
+        self.use_fp16 = use_fp16
+
+    def before_run(self, runner):
+        runner.optimizer.zero_grad()
+
+    def after_train_iter(self, runner):
+        runner.outputs['loss'] /= self.update_interval
+        if self.use_fp16:
+            with apex.amp.scale_loss(runner.outputs['loss'],
+                                     runner.optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            runner.outputs['loss'].backward()
+        if self.every_n_iters(runner, self.update_interval):
+            if self.grad_clip is not None:
+                self.clip_grads(runner.model.parameters())
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
+
+
+@HOOKS.register_module(force=True)
+class GradientCumulativeOptimizerHook(OptimizerHook):
+    """Optimizer Hook implements multi-iters gradient cumulating.
+    Args:
+        cumulative_iters (int, optional): Num of gradient cumulative iters.
+            The optimizer will step every `cumulative_iters` iters.
+            Defaults to 1.
+    Examples:
+        >>> # Use cumulative_iters to simulate a large batch size
+        >>> # It is helpful when the hardware cannot handle a large batch size.
+        >>> loader = DataLoader(data, batch_size=64)
+        >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+        >>> # almost equals to
+        >>> loader = DataLoader(data, batch_size=256)
+        >>> optim_hook = OptimizerHook()
+    """
+
+    def __init__(self, cumulative_iters=1, **kwargs):
+        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+            f'cumulative_iters only accepts positive int, but got ' \
+            f'{type(cumulative_iters)} instead.'
+
+        self.cumulative_iters = cumulative_iters
+        self.divisible_iters = 0
+        self.remainder_iters = 0
+        self.initialized = False
+
+    def has_batch_norm(self, module):
+        if isinstance(module, _BatchNorm):
+            return True
+        for m in module.children():
+            if self.has_batch_norm(m):
+                return True
+        return False
+
+    def _init(self, runner):
+        if runner.iter % self.cumulative_iters != 0:
+            runner.logger.warning(
+                'Resume iter number is not divisible by cumulative_iters in '
+                'GradientCumulativeOptimizerHook, which means the gradient of '
+                'some iters is lost and the result may be influenced slightly.'
+            )
+
+        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+            runner.logger.warning(
+                'GradientCumulativeOptimizerHook may slightly decrease '
+                'performance if the model has BatchNorm layers.')
+
+        residual_iters = runner.max_iters - runner.iter
+
+        self.divisible_iters = (
+            residual_iters // self.cumulative_iters * self.cumulative_iters)
+        self.remainder_iters = residual_iters - self.divisible_iters
+
+        self.initialized = True
+
+    def after_train_iter(self, runner):
+        if not self.initialized:
+            self._init(runner)
+
+        if runner.iter < self.divisible_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+        loss = runner.outputs['loss']
+        loss = loss / loss_factor
+        loss.backward()
+
+        if (self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner)):
+
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
+
+
+if (TORCH_VERSION != 'parrots'
+        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+    @HOOKS.register_module(force=True)
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using PyTorch's implementation) implements
+        multi-iters gradient cumulating.
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+        """
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+            loss = runner.outputs['loss']
+            loss = loss / loss_factor
+
+            self.loss_scaler.scale(loss).backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # copy fp16 grads in the model to fp32 params in the optimizer
+                self.loss_scaler.unscale_(runner.optimizer)
+
+                if self.grad_clip is not None:
+                    grad_norm = self.clip_grads(runner.model.parameters())
+                    if grad_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(grad_norm)},
+                            runner.outputs['num_samples'])
+
+                # backward and update scaler
+                self.loss_scaler.step(runner.optimizer)
+                self.loss_scaler.update(self._scale_update_param)
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
+
+else:
+
+    @HOOKS.register_module(force=True)
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+        iters gradient cumulating."""
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+
+            loss = runner.outputs['loss']
+            loss = loss / loss_factor
+
+            # scale the loss value
+            scaled_loss = loss * self.loss_scaler.loss_scale
+            scaled_loss.backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # copy fp16 grads in the model to fp32 params in the optimizer
+                fp32_weights = []
+                for param_group in runner.optimizer.param_groups:
+                    fp32_weights += param_group['params']
+                self.copy_grads_to_fp32(runner.model, fp32_weights)
+                # allreduce grads
+                if self.distributed:
+                    allreduce_grads(fp32_weights, self.coalesce,
+                                    self.bucket_size_mb)
+
+                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+                # if has overflow, skip this iteration
+                if not has_overflow:
+                    # scale the gradients back
+                    for param in fp32_weights:
+                        if param.grad is not None:
+                            param.grad.div_(self.loss_scaler.loss_scale)
+                    if self.grad_clip is not None:
+                        grad_norm = self.clip_grads(fp32_weights)
+                        if grad_norm is not None:
+                            # Add grad norm to the logger
+                            runner.log_buffer.update(
+                                {'grad_norm': float(grad_norm)},
+                                runner.outputs['num_samples'])
+                    # update fp32 params
+                    runner.optimizer.step()
+                    # copy fp32 params to the fp16 model
+                    self.copy_params_to_fp16(runner.model, fp32_weights)
+                else:
+                    runner.logger.warning(
+                        'Check overflow, downscale loss scale '
+                        f'to {self.loss_scaler.cur_scale}')
+
+                self.loss_scaler.update_scale(has_overflow)
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
\ No newline at end of file