Diff of /opengait/main.py [000000] .. [fd9ef4]

import os
import argparse
import torch
import torch.nn as nn
from modeling import models
from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr

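# Command-line arguments. The distributed launcher injects the local rank for
# each process; PyTorch >= 2.0 launchers pass it as --local-rank, older ones
# as --local_rank, so both spellings are registered.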
parser = argparse.ArgumentParser(description='Main program for opengait.')
parser.add_argument('--local_rank', type=int, default=0,
                    help="passed by torch.distributed.launch module")
parser.add_argument('--local-rank', type=int, default=0,
                    help="passed by torch.distributed.launch module, for pytorch >=2.0")
parser.add_argument('--cfgs', type=str,
                    default='config/default.yaml', help="path of config file")
parser.add_argument('--phase', default='train',
                    choices=['train', 'test'], help="choose train or test phase")
parser.add_argument('--log_to_file', action='store_true',
                    help="log to file, default path is: output/<dataset>/<model>/<save_name>/<logs>/<Datetime>.txt")
parser.add_argument('--iter', default=0, help="iter to restore")
opt = parser.parse_args()


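# initialization(): set up the message manager / logger under
# output/<dataset>/<model>/<save_name>/ and seed each process with its
# distributed rank.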
def initialization(cfgs, training):
    msg_mgr = get_msg_mgr()
    engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']
    output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'],
                               cfgs['model_cfg']['model'], engine_cfg['save_name'])
    if training:
        msg_mgr.init_manager(output_path, opt.log_to_file, engine_cfg['log_iter'],
                             engine_cfg['restore_hint'] if isinstance(engine_cfg['restore_hint'], int) else 0)
    else:
        msg_mgr.init_logger(output_path, opt.log_to_file)

    msg_mgr.log_info(engine_cfg)

    seed = torch.distributed.get_rank()
    init_seeds(seed)


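# run_model(): build the model class named in model_cfg, optionally convert it
# to SyncBatchNorm and/or freeze BN layers, wrap it for distributed training,
# then run the requested train or test phase.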
def run_model(cfgs, training):
    msg_mgr = get_msg_mgr()
    model_cfg = cfgs['model_cfg']
    msg_mgr.log_info(model_cfg)
    Model = getattr(models, model_cfg['model'])
    model = Model(cfgs, training)
    if training and cfgs['trainer_cfg']['sync_BN']:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if cfgs['trainer_cfg']['fix_BN']:
        model.fix_BN()
    model = get_ddp_module(model, cfgs['trainer_cfg']['find_unused_parameters'])
    msg_mgr.log_info(params_count(model))
    msg_mgr.log_info("Model Initialization Finished!")

    if training:
        Model.run_train(model)
    else:
        Model.run_test(model)


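# Entry point: one process per GPU. The NCCL process group is initialized from
# environment variables set by the launcher, and the world size must match the
# number of visible GPUs.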
if __name__ == '__main__':
    torch.distributed.init_process_group('nccl', init_method='env://')
    if torch.distributed.get_world_size() != torch.cuda.device_count():
        raise ValueError("Expect the number of available GPUs ({}) to equal the world size ({}).".format(
            torch.cuda.device_count(), torch.distributed.get_world_size()))
    cfgs = config_loader(opt.cfgs)
    if opt.iter != 0:
        cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter)
        cfgs['trainer_cfg']['restore_hint'] = int(opt.iter)

    training = (opt.phase == 'train')
    initialization(cfgs, training)
    run_model(cfgs, training)
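
A typical launch, implied by the argument parser and the world-size check above, is one process per GPU via the PyTorch distributed launcher; the GPU count and config path below are illustrative, not taken from the file:

    CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.run --nproc_per_node=2 \
        opengait/main.py --cfgs <path/to/config.yaml> --phase train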