Diff of /tools/train.py [000000] .. [6d389a]

--- a
+++ b/tools/train.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import copy
+import multiprocessing as mp
+import os
+import os.path as osp
+import platform
+import time
+import warnings
+
+import cv2
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist, set_random_seed
+from mmcv.utils import get_git_hash
+
+from mmaction import __version__
+from mmaction.apis import init_random_seed, train_model
+from mmaction.datasets import build_dataset
+from mmaction.models import build_model
+from mmaction.utils import collect_env, get_root_logger, register_module_hooks
+
+
+def count_parameters(model):
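+    """Return the number of trainable parameters in ``model``."""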
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def setup_multi_processes(cfg):
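+    """Set up multi-processing related environment variables."""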
+    # set the multi-process start method to `fork` to speed up training
+    if platform.system() != 'Windows':
+        mp_start_method = cfg.get('mp_start_method', 'fork')
+        mp.set_start_method(mp_start_method)
+
+    # disable opencv multithreading to avoid the system being overloaded
+    opencv_num_threads = cfg.get('opencv_num_threads', 0)
+    cv2.setNumThreads(opencv_num_threads)
+
+    # setup OMP threads
+    # This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
+    if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
+        omp_num_threads = 1
+        warnings.warn(
+            f'Setting OMP_NUM_THREADS environment variable for each process '
+            f'to be {omp_num_threads} by default, to avoid your system '
+            f'being overloaded. Please further tune the variable for '
+            f'optimal performance in your application as needed.')
+        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+    # setup MKL threads
+    if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
+        mkl_num_threads = 1
+        warnings.warn(
+            f'Setting MKL_NUM_THREADS environment variable for each process '
+            f'to be {mkl_num_threads} by default, to avoid your system '
+            f'being overloaded. Please further tune the variable for '
+            f'optimal performance in your application as needed.')
+        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def parse_args():
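+    """Parse command-line arguments."""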
+    parser = argparse.ArgumentParser(description='Train a recognizer')
+    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--work-dir', help='the dir to save logs and models')
+    parser.add_argument(
+        '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--validate',
+        action='store_true',
+        help='whether to evaluate the checkpoint during training')
+    parser.add_argument(
+        '--test-last',
+        action='store_true',
+        help='whether to test the checkpoint after training')
+    parser.add_argument(
+        '--test-best',
+        action='store_true',
+        help=('whether to test the best checkpoint (if applicable) after '
+              'training'))
+    group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--gpus',
+        type=int,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    group_gpus.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='ids of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=None, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        default={},
+        help='override some settings in the used config; key-value pairs '
+        'in xxx=yyy format will be merged into the config file. For example, '
+        "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'")
+    parser.add_argument(
+        '--launcher',
+        choices=['none', 'pytorch', 'slurm', 'mpi'],
+        default='none',
+        help='job launcher')
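+    # `--local_rank` is supplied automatically by torch.distributed.launch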
+    parser.add_argument('--local_rank', type=int, default=0)
+    args = parser.parse_args()
+    if 'LOCAL_RANK' not in os.environ:
+        os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+    return args
+
+
+def main():
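+    """Entry point: build the model and datasets, then launch training."""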
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+
+    cfg.merge_from_dict(args.cfg_options)
+
+    # set multi-process settings
+    setup_multi_processes(cfg)
+
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # work_dir is determined in this priority:
+    # CLI > config file > default (base filename)
+    if args.work_dir is not None:
+        # update configs according to CLI args if args.work_dir is not None
+        cfg.work_dir = args.work_dir
+    elif cfg.get('work_dir', None) is None:
+        # use config filename as default work_dir if cfg.work_dir is None
+        cfg.work_dir = osp.join('./work_dirs',
+                                osp.splitext(osp.basename(args.config))[0])
+    if args.resume_from is not None:
+        cfg.resume_from = args.resume_from
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+
+    # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'none':
+        distributed = False
+    else:
+        distributed = True
+        init_dist(args.launcher, **cfg.dist_params)
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
+
+    # The flag determines whether omnisource training is enabled
+    cfg.setdefault('omnisource', False)
+
+    # The list of module hooks to register on the model
+    cfg.setdefault('module_hooks', [])
+
+    # create work_dir
+    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+    # dump config
+    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+    # init logger before other steps
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+    # init the meta dict to record some important information such as
+    # environment info and seed, which will be logged
+    meta = dict()
+    # log env info
+    env_info_dict = collect_env()
+    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
+    dash_line = '-' * 60 + '\n'
+    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+                dash_line)
+    meta['env_info'] = env_info
+
+    # log some basic info
+    logger.info(f'Distributed training: {distributed}')
+    logger.info(f'Config: {cfg.pretty_text}')
+
+    # set random seeds
+    seed = init_random_seed(args.seed)
+    logger.info(f'Set random seed to {seed}, '
+                f'deterministic: {args.deterministic}')
+    set_random_seed(seed, deterministic=args.deterministic)
+
+    cfg.seed = seed
+    meta['seed'] = seed
+    meta['config_name'] = osp.basename(args.config)
+    meta['work_dir'] = osp.basename(cfg.work_dir.rstrip('/\\'))
+
+    model = build_model(
+        cfg.model,
+        train_cfg=cfg.get('train_cfg'),
+        test_cfg=cfg.get('test_cfg'))
+    logger.info(f'Number of trainable parameters: {count_parameters(model)}')
+
+    if len(cfg.module_hooks) > 0:
+        register_module_hooks(model, cfg.module_hooks)
+
+    if cfg.omnisource:
+        # If omnisource flag is set, cfg.data.train should be a list
+        assert isinstance(cfg.data.train, list)
+        datasets = [build_dataset(dataset) for dataset in cfg.data.train]
+    else:
+        datasets = [build_dataset(cfg.data.train)]
+
+    if len(cfg.workflow) == 2:
+        # For simplicity, omnisource is not compatible with the val
+        # workflow; we recommend using `--validate` instead
+        assert not cfg.omnisource
+        if args.validate:
+            warnings.warn('val workflow is duplicated with `--validate`; '
+                          'it is recommended to use `--validate`. See '
+                          'https://github.com/open-mmlab/mmaction2/pull/123')
+        val_dataset = copy.deepcopy(cfg.data.val)
+        datasets.append(build_dataset(val_dataset))
+    if cfg.checkpoint_config is not None:
+        # save mmaction version and config file content in checkpoints
+        # as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmaction_version=__version__ + get_git_hash(digits=7),
+            config=cfg.pretty_text)
+
+    test_option = dict(test_last=args.test_last, test_best=args.test_best)
+    train_model(
+        model,
+        datasets,
+        cfg,
+        distributed=distributed,
+        validate=args.validate,
+        test=test_option,
+        timestamp=timestamp,
+        meta=meta)
+
+
+if __name__ == '__main__':
+    main()
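
For reference, a typical invocation of this script might look like the
following (config path and work dir are illustrative, not taken from the
diff):

    # single-GPU training with validation, a fixed seed, and deterministic CUDNN
    python tools/train.py configs/some_config.py \
        --work-dir work_dirs/some_experiment \
        --validate --seed 0 --deterministic

    # distributed training on 4 GPUs via torch.distributed.launch
    python -m torch.distributed.launch --nproc_per_node=4 \
        tools/train.py configs/some_config.py --launcher pytorch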