Diff of /tools/train.py [000000] .. [6d389a]

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings

import cv2
import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist, set_random_seed
from mmcv.utils import get_git_hash

from mmaction import __version__
from mmaction.apis import init_random_seed, train_model
from mmaction.datasets import build_dataset
from mmaction.models import build_model
from mmaction.utils import collect_env, get_root_logger, register_module_hooks


def count_parameters(model):
    """Return the number of trainable parameters of the model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

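# (count_parameters is only used for the one-off parameter-count log in main()
# below; the training loop itself does not depend on it.)
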
def setup_multi_processes(cfg):
    """Set up multi-processing start method and thread-count defaults."""
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        mp.set_start_method(mp_start_method)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads
    # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
        omp_num_threads = 1
        warnings.warn(
            f'Setting OMP_NUM_THREADS environment variable for each process '
            f'to be {omp_num_threads} by default, to avoid your system being '
            f'overloaded. Please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
        mkl_num_threads = 1
        warnings.warn(
            f'Setting MKL_NUM_THREADS environment variable for each process '
            f'to be {mkl_num_threads} by default, to avoid your system being '
            f'overloaded. Please further tune the variable for optimal '
            f'performance in your application as needed.')
        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)

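# Note: the OMP/MKL defaults above are applied only when the variables are not
# already present in the environment, so exporting OMP_NUM_THREADS or
# MKL_NUM_THREADS before launching takes precedence.
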
def parse_args():
    parser = argparse.ArgumentParser(description='Train a recognizer')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--test-last',
        action='store_true',
        help='whether to test the checkpoint after training')
    parser.add_argument(
        '--test-best',
        action='store_true',
        help=('whether to test the best checkpoint (if applicable) after '
              'training'))
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        default={},
        help='override some settings in the used config, the key-value pairs '
        'in xxx=yyy format will be merged into the config file. For example, '
        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

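# Example invocations (sketch; CONFIG_FILE is a placeholder, not a file
# shipped with this change):
#   python tools/train.py CONFIG_FILE --validate --seed 0 --deterministic
#   python tools/train.py CONFIG_FILE --work-dir ./work_dirs/my_exp --gpus 1
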
def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)

    cfg.merge_from_dict(args.cfg_options)

    # set multi-process settings
    setup_multi_processes(cfg)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority:
    # CLI > config file > default (base filename)
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)
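    # Distributed jobs are typically started by an external launcher (sketch,
    # assuming the standard PyTorch launcher; adjust to your setup):
    #   python -m torch.distributed.launch --nproc_per_node=8 \
    #       tools/train.py CONFIG_FILE --launcher pytorch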

    # The flag is used to determine whether it is omnisource training
    cfg.setdefault('omnisource', False)

    # The flag is used to register module's hooks
    cfg.setdefault('module_hooks', [])

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config: {cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)

    cfg.seed = seed
    meta['seed'] = seed
    meta['config_name'] = osp.basename(args.config)
    meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))

    model = build_model(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    logger.info(f'Number of trainable parameters: {count_parameters(model)}')

    if len(cfg.module_hooks) > 0:
        register_module_hooks(model, cfg.module_hooks)

    if cfg.omnisource:
        # If omnisource flag is set, cfg.data.train should be a list
        assert isinstance(cfg.data.train, list)
        datasets = [build_dataset(dataset) for dataset in cfg.data.train]
    else:
        datasets = [build_dataset(cfg.data.train)]

    if len(cfg.workflow) == 2:
        # For simplicity, omnisource is not compatible with the val workflow;
        # we recommend using `--validate` instead
        assert not cfg.omnisource
        if args.validate:
            warnings.warn('val workflow is duplicated with `--validate`, '
                          'it is recommended to use `--validate`. see '
                          'https://github.com/open-mmlab/mmaction2/pull/123')
        val_dataset = copy.deepcopy(cfg.data.val)
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmaction version and config file content in checkpoints as
        # meta data
        cfg.checkpoint_config.meta = dict(
            mmaction_version=__version__ + get_git_hash(digits=7),
            config=cfg.pretty_text)

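    # (The checkpoint hook saves this `meta` dict into each checkpoint it
    # writes, so the exact config travels with the trained weights.)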
    test_option = dict(test_last=args.test_last, test_best=args.test_best)
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=args.validate,
        test=test_option,
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()