--- a +++ b/opengait/main.py @@ -0,0 +1,73 @@ + +import os +import argparse +import torch +import torch.nn as nn +from modeling import models +from utils import config_loader, get_ddp_module, init_seeds, params_count, get_msg_mgr + +parser = argparse.ArgumentParser(description='Main program for opengait.') +parser.add_argument('--local_rank', type=int, default=0, + help="passed by torch.distributed.launch module") +parser.add_argument('--local-rank', type=int, default=0, + help="passed by torch.distributed.launch module, for pytorch >=2.0") +parser.add_argument('--cfgs', type=str, + default='config/default.yaml', help="path of config file") +parser.add_argument('--phase', default='train', + choices=['train', 'test'], help="choose train or test phase") +parser.add_argument('--log_to_file', action='store_true', + help="log to file, default path is: output/<dataset>/<model>/<save_name>/<logs>/<Datetime>.txt") +parser.add_argument('--iter', default=0, help="iter to restore") +opt = parser.parse_args() + + +def initialization(cfgs, training): + msg_mgr = get_msg_mgr() + engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg'] + output_path = os.path.join('output/', cfgs['data_cfg']['dataset_name'], + cfgs['model_cfg']['model'], engine_cfg['save_name']) + if training: + msg_mgr.init_manager(output_path, opt.log_to_file, engine_cfg['log_iter'], + engine_cfg['restore_hint'] if isinstance(engine_cfg['restore_hint'], (int)) else 0) + else: + msg_mgr.init_logger(output_path, opt.log_to_file) + + msg_mgr.log_info(engine_cfg) + + seed = torch.distributed.get_rank() + init_seeds(seed) + + +def run_model(cfgs, training): + msg_mgr = get_msg_mgr() + model_cfg = cfgs['model_cfg'] + msg_mgr.log_info(model_cfg) + Model = getattr(models, model_cfg['model']) + model = Model(cfgs, training) + if training and cfgs['trainer_cfg']['sync_BN']: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + if cfgs['trainer_cfg']['fix_BN']: + model.fix_BN() + model = get_ddp_module(model, cfgs['trainer_cfg']['find_unused_parameters']) + msg_mgr.log_info(params_count(model)) + msg_mgr.log_info("Model Initialization Finished!") + + if training: + Model.run_train(model) + else: + Model.run_test(model) + + +if __name__ == '__main__': + torch.distributed.init_process_group('nccl', init_method='env://') + if torch.distributed.get_world_size() != torch.cuda.device_count(): + raise ValueError("Expect number of available GPUs({}) equals to the world size({}).".format( + torch.cuda.device_count(), torch.distributed.get_world_size())) + cfgs = config_loader(opt.cfgs) + if opt.iter != 0: + cfgs['evaluator_cfg']['restore_hint'] = int(opt.iter) + cfgs['trainer_cfg']['restore_hint'] = int(opt.iter) + + training = (opt.phase == 'train') + initialization(cfgs, training) + run_model(cfgs, training)