tools/train.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import Config, DictAction, get_git_hash

from mmseg import __version__
from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import collect_env, get_root_logger


def parse_args():
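    """Parse command-line arguments for training."""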
    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='--options is deprecated in favor of --cfg-options and it will '
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b. It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]". Note that the '
        'quotation marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args


def main():
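    """Build the config, model and datasets, then start training."""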
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > config file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
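    # resolve GPU ids: --gpu-ids wins, else --gpus (defaults to one GPU)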
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # gpu_ids is used to calculate iter when resuming checkpoint
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

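    # build the segmentor; init_weights() applies init_cfg (e.g. pretrained)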
    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    model.init_weights()

    # SyncBN is not supported for DP
    if not distributed:
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which can '
            'avoid this error.')
        model = revert_sync_batchnorm(model)

    logger.info(model)

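    # build the train set; a 2-phase workflow also needs a val set that
    # reuses the train pipeline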
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
        # passing checkpoint meta for saving best checkpoint; guarded so a
        # None checkpoint_config does not raise an AttributeError
        meta.update(cfg.checkpoint_config.meta)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
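    # hand off to the mmseg training loop; validation runs during training
    # unless --no-validate is set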
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()
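
# Example invocations (config path and GPU count are placeholders):
#   single GPU:  python tools/train.py <config.py> --work-dir <work_dir>
#   distributed: bash tools/dist_train.sh <config.py> <num_gpus>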