mmaction/apis/train.py
# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
import os
import os.path as osp
import time

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
                         build_optimizer, get_dist_info)
from mmcv.runner.hooks import Fp16OptimizerHook

from ..core import (DistEvalHook, EvalHook, OmniSourceDistSamplerSeedHook,
                    OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger
from .test import multi_gpu_test


def init_random_seed(seed=None, device='cuda'):
    """Initialize random seed.

    If the seed is not set, the seed will be automatically randomized,
    and then broadcast to all processes to prevent some potential bugs.
    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.
    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    seed = np.random.randint(2**31)

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)
    return random_num.item()
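
# Note: the seed returned by init_random_seed() is typically stored in
# cfg.seed by the launcher script before train_model() below is called;
# train_model() reads cfg.seed when it builds the training dataloaders.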


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                test=dict(test_best=False, test_last=False),
                timestamp=None,
                meta=None):
    """Train model entry function.

    Args:
        model (nn.Module): The model to be trained.
        dataset (:obj:`Dataset`): Train dataset.
        cfg (dict): The config dict for training.
        distributed (bool): Whether to use distributed training.
            Default: False.
        validate (bool): Whether to do evaluation. Default: False.
        test (dict): The testing option, with two keys: test_last & test_best.
            The value is True or False, indicating whether to test the
            corresponding checkpoint.
            Default: dict(test_best=False, test_last=False).
        timestamp (str | None): Local time for runner. Default: None.
        meta (dict | None): Meta dict to record some important information.
            Default: None
    """
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
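
    # `videos_per_gpu` is the per-GPU batch size; keys given in
    # cfg.data.train_dataloader override the per-key defaults taken from
    # cfg.data below.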
    dataloader_setting = dict(
        videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
        workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
        persistent_workers=cfg.data.get('persistent_workers', False),
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed)
    dataloader_setting = dict(dataloader_setting,
                              **cfg.data.get('train_dataloader', {}))
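
    # For OmniSource training, every source dataset gets its own dataloader;
    # `omni_videos_per_gpu` can assign each source a different batch size and
    # `train_ratio` controls how OmniSourceRunner mixes the sources.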
    if cfg.omnisource:
        # The option can override videos_per_gpu
        train_ratio = cfg.data.get('train_ratio', [1] * len(dataset))
        omni_videos_per_gpu = cfg.data.get('omni_videos_per_gpu', None)
        if omni_videos_per_gpu is None:
            dataloader_settings = [dataloader_setting] * len(dataset)
        else:
            dataloader_settings = []
            for videos_per_gpu in omni_videos_per_gpu:
                this_setting = cp.deepcopy(dataloader_setting)
                this_setting['videos_per_gpu'] = videos_per_gpu
                dataloader_settings.append(this_setting)
        data_loaders = [
            build_dataloader(ds, **setting)
            for ds, setting in zip(dataset, dataloader_settings)
        ]

    else:
        data_loaders = [
            build_dataloader(ds, **dataloader_setting) for ds in dataset
        ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
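        # (set it to True in the config when some parameters may receive no
        # gradients in a forward pass, otherwise DDP raises an error)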
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    Runner = OmniSourceRunner if cfg.omnisource else EpochBasedRunner
    runner = Runner(
        model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp
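
    # Optimizer hook selection: Fp16OptimizerHook adds loss scaling for
    # mixed-precision training when `cfg.fp16` is set; otherwise a plain
    # OptimizerHook (or whatever hook type cfg.optimizer_config names) is used.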
    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
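    # The sampler seed hooks call set_epoch() on the distributed sampler at
    # the start of every epoch so that shuffling differs between epochs.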
    if distributed:
        if cfg.omnisource:
            runner.register_hook(OmniSourceDistSamplerSeedHook())
        else:
            runner.register_hook(DistSamplerSeedHook())

    # precise bn setting
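    # PreciseBNHook recomputes the BatchNorm running statistics from training
    # batches, which typically gives slightly more accurate evaluation results.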
    if cfg.get('precise_bn', False):
        precise_bn_dataset = build_dataset(cfg.data.train)
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
            workers_per_gpu=1,  # save memory and time
            persistent_workers=cfg.data.get('persistent_workers', False),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed)
        data_loader_precise_bn = build_dataloader(precise_bn_dataset,
                                                  **dataloader_setting)
        precise_bn_hook = PreciseBNHook(data_loader_precise_bn,
                                        **cfg.get('precise_bn'))
        runner.register_hook(precise_bn_hook)
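
    # Validation is handled by (Dist)EvalHook; cfg.evaluation supplies its
    # options, e.g. the interval, the metrics to compute and save_best.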
    if validate:
        eval_cfg = cfg.get('evaluation', {})
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            persistent_workers=cfg.data.get('persistent_workers', False),
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('val_dataloader', {}))
        val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
        eval_hook = DistEvalHook(val_dataloader, **eval_cfg) if distributed \
            else EvalHook(val_dataloader, **eval_cfg)
        runner.register_hook(eval_hook)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner_kwargs = dict()
    if cfg.omnisource:
        runner_kwargs = dict(train_ratio=train_ratio)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs, **runner_kwargs)
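
    # Let all ranks finish training and give the filesystem a moment so the
    # final checkpoint is completely written before testing starts.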
    if distributed:
        dist.barrier()
    time.sleep(5)
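
    # Optionally test the last and/or the best checkpoint after training,
    # as requested through the `test` argument.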
    if test['test_last'] or test['test_best']:
        best_ckpt_path = None
        if test['test_best']:
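            # The best checkpoint is the one saved by EvalHook with 'best' in
            # its filename; if several exist, keep the one from the latest
            # epoch.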
            ckpt_paths = [x for x in os.listdir(cfg.work_dir) if 'best' in x]
            ckpt_paths = [x for x in ckpt_paths if x.endswith('.pth')]
            if len(ckpt_paths) == 0:
                runner.logger.info('Warning: test_best set, but no ckpt found')
                test['test_best'] = False
                if not test['test_last']:
                    return
            elif len(ckpt_paths) > 1:
                epoch_ids = [
                    int(x.split('epoch_')[-1][:-4]) for x in ckpt_paths
                ]
                best_ckpt_path = ckpt_paths[np.argmax(epoch_ids)]
            else:
                best_ckpt_path = ckpt_paths[0]
            if best_ckpt_path:
                best_ckpt_path = osp.join(cfg.work_dir, best_ckpt_path)

        test_dataset = build_dataset(cfg.data.test, dict(test_mode=True))
        gpu_collect = cfg.get('evaluation', {}).get('gpu_collect', False)
        tmpdir = cfg.get('evaluation', {}).get('tmpdir',
                                               osp.join(cfg.work_dir, 'tmp'))
        dataloader_setting = dict(
            videos_per_gpu=cfg.data.get('videos_per_gpu', 1),
            workers_per_gpu=cfg.data.get('workers_per_gpu', 1),
            persistent_workers=cfg.data.get('persistent_workers', False),
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            shuffle=False)
        dataloader_setting = dict(dataloader_setting,
                                  **cfg.data.get('test_dataloader', {}))

        test_dataloader = build_dataloader(test_dataset, **dataloader_setting)

        names, ckpts = [], []

        if test['test_last']:
            names.append('last')
            ckpts.append(None)
        if test['test_best'] and best_ckpt_path is not None:
            names.append('best')
            ckpts.append(best_ckpt_path)
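
        # A ckpt of None means "test the weights currently held by the
        # runner", i.e. the state at the end of the last training epoch.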
        for name, ckpt in zip(names, ckpts):
            if ckpt is not None:
                runner.load_checkpoint(ckpt)

            outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir,
                                     gpu_collect)
            rank, _ = get_dist_info()
            if rank == 0:
                out = osp.join(cfg.work_dir, f'{name}_pred.pkl')
                test_dataset.dump_results(outputs, out)
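
                # Drop the EvalHook-only options so that only metric-related
                # kwargs are passed to the dataset's evaluate() method.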
                eval_cfg = cfg.get('evaluation', {})
                for key in [
                        'interval', 'tmpdir', 'start', 'gpu_collect',
                        'save_best', 'rule', 'by_epoch', 'broadcast_bn_buffers'
                ]:
                    eval_cfg.pop(key, None)

                eval_res = test_dataset.evaluate(outputs, **eval_cfg)
                runner.logger.info(f'Testing results of the {name} checkpoint')
                for metric_name, val in eval_res.items():
                    runner.logger.info(f'{metric_name}: {val:.04f}')
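

# Minimal usage sketch (illustrative only, with a hypothetical config path;
# the normal entry point is tools/train.py, and the config is expected to
# already define everything this function reads directly, e.g. work_dir,
# omnisource, resume_from, load_from, optimizer and data settings):
#
#     import mmcv
#     from mmaction.apis import init_random_seed, train_model
#     from mmaction.datasets import build_dataset
#     from mmaction.models import build_model
#
#     cfg = mmcv.Config.fromfile('path/to/config.py')
#     cfg.gpu_ids = [0]
#     cfg.seed = init_random_seed(0)
#     model = build_model(cfg.model)
#     datasets = [build_dataset(cfg.data.train)]
#     train_model(model, datasets, cfg, distributed=False, validate=True)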