mmseg/utils/optimizer.py

# gradient cumulative optimizer hooks are from
# https://github.com/open-mmlab/mmcv/pull/1221
# TODO use mmcv if the PR is merged
from mmcv.runner import (HOOKS, Fp16OptimizerHook, OptimizerHook,
                         allreduce_grads)
from mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version

# apex is optional; fall back to None so importing this module does not
# fail when NVIDIA apex is not installed.
try:
    import apex
except ImportError:
    apex = None


@HOOKS.register_module()
class ApexOptimizerHook(OptimizerHook):
    """Optimizer hook for distributed training.

    Accumulates gradients over ``update_interval`` iterations and, when
    ``use_fp16=True``, scales the loss through NVIDIA apex's amp.
    """

    def __init__(self,
                 update_interval=1,
                 grad_clip=None,
                 coalesce=True,
                 bucket_size_mb=-1,
                 use_fp16=False):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb
        self.update_interval = update_interval
        self.use_fp16 = use_fp16

    def before_run(self, runner):
        runner.optimizer.zero_grad()

    def after_train_iter(self, runner):
        # average the loss over the accumulation window
        runner.outputs['loss'] /= self.update_interval
        if self.use_fp16:
            assert apex is not None, (
                'apex is required when use_fp16=True, '
                'see https://github.com/NVIDIA/apex')
            with apex.amp.scale_loss(runner.outputs['loss'],
                                     runner.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            runner.outputs['loss'].backward()
        # step the optimizer only once per accumulation window
        if self.every_n_iters(runner, self.update_interval):
            if self.grad_clip is not None:
                self.clip_grads(runner.model.parameters())
            runner.optimizer.step()
            runner.optimizer.zero_grad()


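# A minimal sketch of how this hook could be enabled from a config file.
# The values and the amp setup below are illustrative, not part of this
# module; apex.amp.initialize must have wrapped the model and optimizer
# beforehand for scale_loss to work:
#
#   model, optimizer = apex.amp.initialize(model, optimizer, opt_level='O1')
#   optimizer_config = dict(
#       type='ApexOptimizerHook', update_interval=4, use_fp16=True)

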
@HOOKS.register_module(force=True)
class GradientCumulativeOptimizerHook(OptimizerHook):
    """Optimizer hook that accumulates gradients over multiple iterations.

    Args:
        cumulative_iters (int, optional): Number of iterations over which
            gradients are accumulated. The optimizer will step every
            `cumulative_iters` iters. Defaults to 1.

    Examples:
        >>> # Use cumulative_iters to simulate a large batch size
        >>> # It is helpful when the hardware cannot handle a large batch size.
        >>> loader = DataLoader(data, batch_size=64)
        >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
        >>> # almost equivalent to
        >>> loader = DataLoader(data, batch_size=256)
        >>> optim_hook = OptimizerHook()
    """

    def __init__(self, cumulative_iters=1, **kwargs):
        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)

        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
            f'cumulative_iters only accepts positive int, but got ' \
            f'{cumulative_iters!r} instead.'

        self.cumulative_iters = cumulative_iters
        self.divisible_iters = 0
        self.remainder_iters = 0
        self.initialized = False

    def has_batch_norm(self, module):
        if isinstance(module, _BatchNorm):
            return True
        for m in module.children():
            if self.has_batch_norm(m):
                return True
        return False

    def _init(self, runner):
        if runner.iter % self.cumulative_iters != 0:
            runner.logger.warning(
                'Resume iter number is not divisible by cumulative_iters '
                'in GradientCumulativeOptimizerHook, which means the '
                'gradient of some iters is lost and the result may be '
                'influenced slightly.')

        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
            runner.logger.warning(
                'GradientCumulativeOptimizerHook may slightly decrease '
                'performance if the model has BatchNorm layers.')

        residual_iters = runner.max_iters - runner.iter

        # split the remaining schedule into full accumulation windows and a
        # (possibly shorter) trailing window
        self.divisible_iters = (
            residual_iters // self.cumulative_iters * self.cumulative_iters)
        self.remainder_iters = residual_iters - self.divisible_iters

        self.initialized = True

    def after_train_iter(self, runner):
        if not self.initialized:
            self._init(runner)

        # divide by the size of the current accumulation window so that the
        # summed gradients form an average
        if runner.iter < self.divisible_iters:
            loss_factor = self.cumulative_iters
        else:
            loss_factor = self.remainder_iters
        loss = runner.outputs['loss']
        loss = loss / loss_factor
        loss.backward()

        if (self.every_n_iters(runner, self.cumulative_iters)
                or self.is_last_iter(runner)):

            if self.grad_clip is not None:
                grad_norm = self.clip_grads(runner.model.parameters())
                if grad_norm is not None:
                    # Add grad norm to the logger
                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                             runner.outputs['num_samples'])
            runner.optimizer.step()
            runner.optimizer.zero_grad()


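# Worked example of the loss scaling above (hypothetical numbers): with
# runner.max_iters=10 and cumulative_iters=4, _init computes
# divisible_iters=8 and remainder_iters=2, so the first 8 iterations divide
# the loss by 4 and the final 2 divide it by 2; every optimizer step
# therefore applies an average over the iterations it accumulated.

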
if (TORCH_VERSION != 'parrots'
        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):

    @HOOKS.register_module(force=True)
    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
                                              Fp16OptimizerHook):
        """Fp16 optimizer hook (using PyTorch's implementation) that
        accumulates gradients over multiple iterations.

        On PyTorch >= 1.6, torch.cuda.amp is used as the backend to take
        care of the loss scaling.
        """

        def __init__(self, *args, **kwargs):
            super(GradientCumulativeFp16OptimizerHook,
                  self).__init__(*args, **kwargs)

        def after_train_iter(self, runner):
            if not self.initialized:
                self._init(runner)

            if runner.iter < self.divisible_iters:
                loss_factor = self.cumulative_iters
            else:
                loss_factor = self.remainder_iters
            loss = runner.outputs['loss']
            loss = loss / loss_factor

            self.loss_scaler.scale(loss).backward()

            if (self.every_n_iters(runner, self.cumulative_iters)
                    or self.is_last_iter(runner)):

                # unscale the gradients in-place so that clipping operates
                # on their true values
                self.loss_scaler.unscale_(runner.optimizer)

                if self.grad_clip is not None:
                    grad_norm = self.clip_grads(runner.model.parameters())
                    if grad_norm is not None:
                        # Add grad norm to the logger
                        runner.log_buffer.update(
                            {'grad_norm': float(grad_norm)},
                            runner.outputs['num_samples'])

                # step the optimizer through the scaler and update the
                # loss scale
                self.loss_scaler.step(runner.optimizer)
                self.loss_scaler.update(self._scale_update_param)

                # save state_dict of loss_scaler
                runner.meta.setdefault(
                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()

                # clear grads
                runner.model.zero_grad()
                runner.optimizer.zero_grad()

else:

    @HOOKS.register_module(force=True)
    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
                                              Fp16OptimizerHook):
        """Fp16 optimizer hook (using the mmcv implementation) that
        accumulates gradients over multiple iterations."""

        def __init__(self, *args, **kwargs):
            super(GradientCumulativeFp16OptimizerHook,
                  self).__init__(*args, **kwargs)

        def after_train_iter(self, runner):
            if not self.initialized:
                self._init(runner)

            if runner.iter < self.divisible_iters:
                loss_factor = self.cumulative_iters
            else:
                loss_factor = self.remainder_iters

            loss = runner.outputs['loss']
            loss = loss / loss_factor

            # scale the loss value
            scaled_loss = loss * self.loss_scaler.loss_scale
            scaled_loss.backward()

            if (self.every_n_iters(runner, self.cumulative_iters)
                    or self.is_last_iter(runner)):

                # copy fp16 grads in the model to fp32 params in the optimizer
                fp32_weights = []
                for param_group in runner.optimizer.param_groups:
                    fp32_weights += param_group['params']
                self.copy_grads_to_fp32(runner.model, fp32_weights)
                # allreduce grads
                if self.distributed:
                    allreduce_grads(fp32_weights, self.coalesce,
                                    self.bucket_size_mb)

                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
                # if has overflow, skip this iteration
                if not has_overflow:
                    # scale the gradients back
                    for param in fp32_weights:
                        if param.grad is not None:
                            param.grad.div_(self.loss_scaler.loss_scale)
                    if self.grad_clip is not None:
                        grad_norm = self.clip_grads(fp32_weights)
                        if grad_norm is not None:
                            # Add grad norm to the logger
                            runner.log_buffer.update(
                                {'grad_norm': float(grad_norm)},
                                runner.outputs['num_samples'])
                    # update fp32 params
                    runner.optimizer.step()
                    # copy fp32 params to the fp16 model
                    self.copy_params_to_fp16(runner.model, fp32_weights)
                else:
                    runner.logger.warning(
                        'Detected overflow, skipping this step and '
                        f'downscaling loss scale to {self.loss_scaler.cur_scale}')

                self.loss_scaler.update_scale(has_overflow)

                # save state_dict of loss_scaler
                runner.meta.setdefault(
                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()

                # clear grads
                runner.model.zero_grad()
                runner.optimizer.zero_grad()
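

# A minimal sketch of config usage (hypothetical values; whichever branch
# above matched the installed PyTorch provides the registered class):
#
#   optimizer_config = dict(
#       type='GradientCumulativeFp16OptimizerHook',
#       cumulative_iters=4,
#       loss_scale=512.,
#       grad_clip=dict(max_norm=1.0))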