|
a |
|
b/tools/train.py |
|
|
1 |
# Copyright (c) OpenMMLab. All rights reserved. |
|
|
2 |
import argparse |
|
|
3 |
import copy |
|
|
4 |
import multiprocessing as mp |
|
|
5 |
import os |
|
|
6 |
import os.path as osp |
|
|
7 |
import platform |
|
|
8 |
import time |
|
|
9 |
import warnings |
|
|
10 |
|
|
|
11 |
import cv2 |
|
|
12 |
import mmcv |
|
|
13 |
import torch |
|
|
14 |
from mmcv import Config, DictAction |
|
|
15 |
from mmcv.runner import get_dist_info, init_dist, set_random_seed |
|
|
16 |
from mmcv.utils import get_git_hash |
|
|
17 |
|
|
|
18 |
from mmaction import __version__ |
|
|
19 |
from mmaction.apis import init_random_seed, train_model |
|
|
20 |
from mmaction.datasets import build_dataset |
|
|
21 |
from mmaction.models import build_model |
|
|
22 |
from mmaction.utils import collect_env, get_root_logger, register_module_hooks |
|
|
23 |
|
|
|
24 |
def count_parameters(model): |
|
|
25 |
return sum(p.numel() for p in model.parameters() if p.requires_grad) |
|
|
26 |
|
|
|
27 |
def setup_multi_processes(cfg): |
|
|
28 |
# set multi-process start method as `fork` to speed up the training |
|
|
29 |
if platform.system() != 'Windows': |
|
|
30 |
mp_start_method = cfg.get('mp_start_method', 'fork') |
|
|
31 |
mp.set_start_method(mp_start_method) |
|
|
32 |
|
|
|
33 |
# disable opencv multithreading to avoid system being overloaded |
|
|
34 |
opencv_num_threads = cfg.get('opencv_num_threads', 0) |
|
|
35 |
cv2.setNumThreads(opencv_num_threads) |
|
|
36 |
|
|
|
37 |
# setup OMP threads |
|
|
38 |
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa |
|
|
39 |
if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1): |
|
|
40 |
omp_num_threads = 1 |
|
|
41 |
warnings.warn( |
|
|
42 |
f'Setting OMP_NUM_THREADS environment variable for each process ' |
|
|
43 |
f'to be {omp_num_threads} in default, to avoid your system being ' |
|
|
44 |
f'overloaded, please further tune the variable for optimal ' |
|
|
45 |
f'performance in your application as needed.') |
|
|
46 |
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) |
|
|
47 |
|
|
|
48 |
# setup MKL threads |
|
|
49 |
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: |
|
|
50 |
mkl_num_threads = 1 |
|
|
51 |
warnings.warn( |
|
|
52 |
f'Setting MKL_NUM_THREADS environment variable for each process ' |
|
|
53 |
f'to be {mkl_num_threads} in default, to avoid your system being ' |
|
|
54 |
f'overloaded, please further tune the variable for optimal ' |
|
|
55 |
f'performance in your application as needed.') |
|
|
56 |
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) |
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def parse_args(): |
|
|
60 |
parser = argparse.ArgumentParser(description='Train a recognizer') |
|
|
61 |
parser.add_argument('config', help='train config file path') |
|
|
62 |
parser.add_argument('--work-dir', help='the dir to save logs and models') |
|
|
63 |
parser.add_argument( |
|
|
64 |
'--resume-from', help='the checkpoint file to resume from') |
|
|
65 |
parser.add_argument( |
|
|
66 |
'--validate', |
|
|
67 |
action='store_true', |
|
|
68 |
help='whether to evaluate the checkpoint during training') |
|
|
69 |
parser.add_argument( |
|
|
70 |
'--test-last', |
|
|
71 |
action='store_true', |
|
|
72 |
help='whether to test the checkpoint after training') |
|
|
73 |
parser.add_argument( |
|
|
74 |
'--test-best', |
|
|
75 |
action='store_true', |
|
|
76 |
help=('whether to test the best checkpoint (if applicable) after ' |
|
|
77 |
'training')) |
|
|
78 |
group_gpus = parser.add_mutually_exclusive_group() |
|
|
79 |
group_gpus.add_argument( |
|
|
80 |
'--gpus', |
|
|
81 |
type=int, |
|
|
82 |
help='number of gpus to use ' |
|
|
83 |
'(only applicable to non-distributed training)') |
|
|
84 |
group_gpus.add_argument( |
|
|
85 |
'--gpu-ids', |
|
|
86 |
type=int, |
|
|
87 |
nargs='+', |
|
|
88 |
help='ids of gpus to use ' |
|
|
89 |
'(only applicable to non-distributed training)') |
|
|
90 |
parser.add_argument('--seed', type=int, default=None, help='random seed') |
|
|
91 |
parser.add_argument( |
|
|
92 |
'--deterministic', |
|
|
93 |
action='store_true', |
|
|
94 |
help='whether to set deterministic options for CUDNN backend.') |
|
|
95 |
parser.add_argument( |
|
|
96 |
'--cfg-options', |
|
|
97 |
nargs='+', |
|
|
98 |
action=DictAction, |
|
|
99 |
default={}, |
|
|
100 |
help='override some settings in the used config, the key-value pair ' |
|
|
101 |
'in xxx=yyy format will be merged into config file. For example, ' |
|
|
102 |
"'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") |
|
|
103 |
parser.add_argument( |
|
|
104 |
'--launcher', |
|
|
105 |
choices=['none', 'pytorch', 'slurm', 'mpi'], |
|
|
106 |
default='none', |
|
|
107 |
help='job launcher') |
|
|
108 |
parser.add_argument('--local_rank', type=int, default=0) |
|
|
109 |
args = parser.parse_args() |
|
|
110 |
if 'LOCAL_RANK' not in os.environ: |
|
|
111 |
os.environ['LOCAL_RANK'] = str(args.local_rank) |
|
|
112 |
|
|
|
113 |
return args |
|
|
114 |
|
|
|
115 |
|
|
|
116 |
def main(): |
|
|
117 |
args = parse_args() |
|
|
118 |
|
|
|
119 |
cfg = Config.fromfile(args.config) |
|
|
120 |
|
|
|
121 |
cfg.merge_from_dict(args.cfg_options) |
|
|
122 |
|
|
|
123 |
# set multi-process settings |
|
|
124 |
setup_multi_processes(cfg) |
|
|
125 |
|
|
|
126 |
# set cudnn_benchmark |
|
|
127 |
if cfg.get('cudnn_benchmark', False): |
|
|
128 |
torch.backends.cudnn.benchmark = True |
|
|
129 |
|
|
|
130 |
# work_dir is determined in this priority: |
|
|
131 |
# CLI > config file > default (base filename) |
|
|
132 |
if args.work_dir is not None: |
|
|
133 |
# update configs according to CLI args if args.work_dir is not None |
|
|
134 |
cfg.work_dir = args.work_dir |
|
|
135 |
elif cfg.get('work_dir', None) is None: |
|
|
136 |
# use config filename as default work_dir if cfg.work_dir is None |
|
|
137 |
cfg.work_dir = osp.join('./work_dirs', |
|
|
138 |
osp.splitext(osp.basename(args.config))[0]) |
|
|
139 |
if args.resume_from is not None: |
|
|
140 |
cfg.resume_from = args.resume_from |
|
|
141 |
if args.gpu_ids is not None: |
|
|
142 |
cfg.gpu_ids = args.gpu_ids |
|
|
143 |
else: |
|
|
144 |
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) |
|
|
145 |
|
|
|
146 |
# init distributed env first, since logger depends on the dist info. |
|
|
147 |
if args.launcher == 'none': |
|
|
148 |
distributed = False |
|
|
149 |
else: |
|
|
150 |
distributed = True |
|
|
151 |
init_dist(args.launcher, **cfg.dist_params) |
|
|
152 |
_, world_size = get_dist_info() |
|
|
153 |
cfg.gpu_ids = range(world_size) |
|
|
154 |
|
|
|
155 |
# The flag is used to determine whether it is omnisource training |
|
|
156 |
cfg.setdefault('omnisource', False) |
|
|
157 |
|
|
|
158 |
# The flag is used to register module's hooks |
|
|
159 |
cfg.setdefault('module_hooks', []) |
|
|
160 |
|
|
|
161 |
# create work_dir |
|
|
162 |
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) |
|
|
163 |
# dump config |
|
|
164 |
cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) |
|
|
165 |
# init logger before other steps |
|
|
166 |
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) |
|
|
167 |
log_file = osp.join(cfg.work_dir, f'{timestamp}.log') |
|
|
168 |
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) |
|
|
169 |
|
|
|
170 |
# init the meta dict to record some important information such as |
|
|
171 |
# environment info and seed, which will be logged |
|
|
172 |
meta = dict() |
|
|
173 |
# log env info |
|
|
174 |
env_info_dict = collect_env() |
|
|
175 |
env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) |
|
|
176 |
dash_line = '-' * 60 + '\n' |
|
|
177 |
logger.info('Environment info:\n' + dash_line + env_info + '\n' + |
|
|
178 |
dash_line) |
|
|
179 |
meta['env_info'] = env_info |
|
|
180 |
|
|
|
181 |
# log some basic info |
|
|
182 |
logger.info(f'Distributed training: {distributed}') |
|
|
183 |
logger.info(f'Config: {cfg.pretty_text}') |
|
|
184 |
|
|
|
185 |
# set random seeds |
|
|
186 |
seed = init_random_seed(args.seed) |
|
|
187 |
logger.info(f'Set random seed to {seed}, ' |
|
|
188 |
f'deterministic: {args.deterministic}') |
|
|
189 |
set_random_seed(seed, deterministic=args.deterministic) |
|
|
190 |
|
|
|
191 |
cfg.seed = seed |
|
|
192 |
meta['seed'] = seed |
|
|
193 |
meta['config_name'] = osp.basename(args.config) |
|
|
194 |
meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\')) |
|
|
195 |
|
|
|
196 |
model = build_model( |
|
|
197 |
cfg.model, |
|
|
198 |
train_cfg=cfg.get('train_cfg'), |
|
|
199 |
test_cfg=cfg.get('test_cfg')) |
|
|
200 |
#print(model) |
|
|
201 |
print("$$$$NUM", count_parameters(model)) |
|
|
202 |
|
|
|
203 |
if len(cfg.module_hooks) > 0: |
|
|
204 |
register_module_hooks(model, cfg.module_hooks) |
|
|
205 |
|
|
|
206 |
if cfg.omnisource: |
|
|
207 |
# If omnisource flag is set, cfg.data.train should be a list |
|
|
208 |
assert isinstance(cfg.data.train, list) |
|
|
209 |
datasets = [build_dataset(dataset) for dataset in cfg.data.train] |
|
|
210 |
else: |
|
|
211 |
datasets = [build_dataset(cfg.data.train)] |
|
|
212 |
|
|
|
213 |
if len(cfg.workflow) == 2: |
|
|
214 |
# For simplicity, omnisource is not compatible with val workflow, |
|
|
215 |
# we recommend you to use `--validate` |
|
|
216 |
assert not cfg.omnisource |
|
|
217 |
if args.validate: |
|
|
218 |
warnings.warn('val workflow is duplicated with `--validate`, ' |
|
|
219 |
'it is recommended to use `--validate`. see ' |
|
|
220 |
'https://github.com/open-mmlab/mmaction2/pull/123') |
|
|
221 |
val_dataset = copy.deepcopy(cfg.data.val) |
|
|
222 |
datasets.append(build_dataset(val_dataset)) |
|
|
223 |
if cfg.checkpoint_config is not None: |
|
|
224 |
# save mmaction version, config file content and class names in |
|
|
225 |
# checkpoints as meta data |
|
|
226 |
cfg.checkpoint_config.meta = dict( |
|
|
227 |
mmaction_version=__version__ + get_git_hash(digits=7), |
|
|
228 |
config=cfg.pretty_text) |
|
|
229 |
|
|
|
230 |
test_option = dict(test_last=args.test_last, test_best=args.test_best) |
|
|
231 |
train_model( |
|
|
232 |
model, |
|
|
233 |
datasets, |
|
|
234 |
cfg, |
|
|
235 |
distributed=distributed, |
|
|
236 |
validate=args.validate, |
|
|
237 |
test=test_option, |
|
|
238 |
timestamp=timestamp, |
|
|
239 |
meta=meta) |
|
|
240 |
|
|
|
241 |
|
|
|
242 |
if __name__ == '__main__': |
|
|
243 |
main() |