Diff of /tools/train.py [000000] .. [98e649]

import os
import argparse
import logging
import importlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision
from tensorboardX import SummaryWriter

import _init_paths
from libs.configs.config_acdc import cfg

from libs.datasets import AcdcDataset
from libs.datasets import joint_augment as joint_augment
from libs.datasets import augment as standard_augment
from libs.datasets.collate_batch import BatchCollator
# from libs.losses.df_loss import EuclideanLossWithOHEM
# from libs.losses.surface_loss import SurfaceLoss
from libs.losses.create_losses import Total_loss
import train_utils.train_utils as train_utils
from train_utils.train_utils import load_checkpoint
from utils.init_net import init_weights
from utils.comm import get_rank, synchronize


parser = argparse.ArgumentParser(description="arg parser")
parser.add_argument("--local_rank", type=int, default=0, required=True, help="local process rank provided by torch.distributed.launch (used as the DistributedDataParallel device id)")
parser.add_argument("--batch_size", type=int, default=32, required=False, help="batch size for training")
parser.add_argument("--epochs", type=int, default=50, required=False, help="number of epochs to train for")
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
parser.add_argument("--ckpt_save_interval", type=int, default=5, help="number of epochs between checkpoint saves")
parser.add_argument('--output_dir', type=str, default=None, help='specify an output directory if needed')
parser.add_argument('--mgpus', type=str, default=None, help="comma-separated GPU ids to expose via CUDA_VISIBLE_DEVICES, e.g. '2,3'")
parser.add_argument("--ckpt", type=str, default=None, help="continue training from this checkpoint")
parser.add_argument('--train_with_eval', action='store_true', default=False, help='whether to train with evaluation')
args = parser.parse_args()

FILE_DIR = os.path.dirname(os.path.abspath(__file__))

# Restrict visible GPUs before any CUDA context is created
if args.mgpus is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus

def create_logger(log_file, dist_rank):
    if dist_rank > 0:
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.WARNING)
        return logger
    log_format = '%(asctime)s  %(levelname)5s  %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger(__name__).addHandler(console)
    return logging.getLogger(__name__)

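# Only rank 0 configures the file and console handlers at DEBUG level; other
# ranks get a bare WARNING-level logger, so training output is not duplicated
# across processes.
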
def create_dataloader(logger):
    train_joint_transform = joint_augment.Compose([
                            joint_augment.To_PIL_Image(),
                            joint_augment.RandomAffine(0, translate=(0.125, 0.125)),
                            joint_augment.RandomRotate((-180, 180)),
                            joint_augment.FixResize(256)
                            ])
    transform = standard_augment.Compose([
                    standard_augment.to_Tensor(),
                    standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])
    target_transform = standard_augment.Compose([
                        standard_augment.to_Tensor()])

    if cfg.DATASET.NAME == 'acdc':
        train_set = AcdcDataset(data_list=cfg.DATASET.TRAIN_LIST,
                                df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM,
                                boundary=cfg.DATASET.BOUNDARY,
                                joint_augment=train_joint_transform,
                                augment=transform, target_augment=target_transform)
    else:
        raise NotImplementedError("unsupported dataset: %s" % cfg.DATASET.NAME)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
                            num_replicas=dist.get_world_size(), rank=dist.get_rank())
    # shuffle must stay False here: shuffling is delegated to the DistributedSampler
    train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.workers, shuffle=False, sampler=train_sampler,
                              collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED,
                                                       boundary=cfg.DATASET.BOUNDARY))

    if args.train_with_eval:
        eval_transform = joint_augment.Compose([
                         joint_augment.To_PIL_Image(),
                         joint_augment.FixResize(256),
                         joint_augment.To_Tensor()])
        evalImg_transform = standard_augment.Compose([
                            standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])

        if cfg.DATASET.NAME == 'acdc':
            test_set = AcdcDataset(data_list=cfg.DATASET.TEST_LIST,
                                   df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM,
                                   boundary=cfg.DATASET.BOUNDARY,
                                   joint_augment=eval_transform,
                                   augment=evalImg_transform)

        test_sampler = torch.utils.data.distributed.DistributedSampler(test_set,
                            num_replicas=dist.get_world_size(), rank=dist.get_rank())
        test_loader = DataLoader(test_set, batch_size=args.batch_size, pin_memory=True,
                                 num_workers=args.workers, shuffle=False, sampler=test_sampler,
                                 collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED,
                                                          boundary=cfg.DATASET.BOUNDARY))
    else:
        test_loader = None

    return train_loader, test_loader

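# Assumption about train_utils.Trainer (not shown in this file): for the
# DistributedSampler to draw a different shuffled order each epoch, the training
# loop needs to call train_loader.sampler.set_epoch(epoch) at the start of every
# epoch; otherwise each rank iterates the same ordering in every epoch.
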
def create_optimizer(model):
    if cfg.TRAIN.OPTIMIZER == "adam":
        optimizer = optim.Adam(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    elif cfg.TRAIN.OPTIMIZER == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                              momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise NotImplementedError
    return optimizer

def create_scheduler(model, optimizer, total_steps, last_epoch):
    # note: model and total_steps are currently unused; the decay factor
    # returned by lr_lbmd depends only on the current epoch
    def lr_lbmd(cur_epoch):
        cur_decay = 1
        for decay_step in cfg.TRAIN.DECAY_STEP_LIST:
            if cur_epoch >= decay_step:
                cur_decay = cur_decay * cfg.TRAIN.LR_DECAY
        return max(cur_decay, cfg.TRAIN.LR_CLIP / cfg.TRAIN.LR)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lbmd, last_epoch=last_epoch)
    return lr_scheduler

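# Worked example with hypothetical values (not taken from config_acdc): if
# cfg.TRAIN.DECAY_STEP_LIST = [30, 40] and cfg.TRAIN.LR_DECAY = 0.1, lr_lbmd
# returns 1.0 for epochs 0-29, 0.1 for epochs 30-39 and 0.01 from epoch 40 on.
# The max(..., cfg.TRAIN.LR_CLIP / cfg.TRAIN.LR) floor guarantees the scheduled
# learning rate cfg.TRAIN.LR * lr_lbmd(epoch) never drops below cfg.TRAIN.LR_CLIP.
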
def create_model(cfg):
    network = cfg.TRAIN.NET

    # cfg.TRAIN.NET is expected to be '<module>.<class>' under libs.network
    module = 'libs.network.' + network[:network.rfind('.')]
    model = network[network.rfind('.')+1:]

    mod = importlib.import_module(module)
    mod_func = importlib.import_module('libs.network.train_functions')
    net_func = getattr(mod, model)

    net = net_func(num_class=cfg.DATASET.NUM_CLASS)
    if network == 'unet.U_Net':
        train_func = getattr(mod_func, 'model_fn_decorator')
    elif network == 'unet_df.U_NetDF':
        net = net_func(selfeat=cfg.MODEL.SELFEATURE, num_class=cfg.DATASET.NUM_CLASS, shift_n=cfg.MODEL.SHIFT_N, auxseg=cfg.MODEL.AUXSEG)
        train_func = getattr(mod_func, 'model_DF_decorator')
    else:
        raise NotImplementedError("unsupported network: %s" % network)

    return net, train_func

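# Example of the name resolution above: with cfg.TRAIN.NET = 'unet_df.U_NetDF'
# (one of the two values handled explicitly), module resolves to
# 'libs.network.unet_df' and model to 'U_NetDF', so net_func is the U_NetDF
# class and train_func is libs.network.train_functions.model_DF_decorator.
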
def train():
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    synchronize()

    # create dataloader & network & optimizer
    model, model_fn_decorator = create_model(cfg)
    init_weights(model, init_type='kaiming')
    # model.to('cuda')
    model.cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    root_result_dir = args.output_dir
    assert root_result_dir is not None, "--output_dir must be specified"
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, "log_train.txt")
    logger = create_logger(log_file, get_rank())
    logger.info("**********************Start logging**********************")

    # log to file
    gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ else 'ALL'
    logger.info("CUDA_VISIBLE_DEVICES=%s" % gpu_list)

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    logger.info("***********************config infos**********************")
    for key, val in vars(cfg).items():
        logger.info("{:16} {}".format(key, val))

    # log tensorboard
    if get_rank() == 0:
        tb_log = SummaryWriter(log_dir=os.path.join(root_result_dir, "tensorboard"))
    else:
        tb_log = None

    train_loader, test_loader = create_dataloader(logger)

    optimizer = create_optimizer(model)

    # load checkpoint if it is possible
    start_epoch = it = best_res = 0
    last_epoch = -1
    if args.ckpt is not None:
        pure_model = model.module if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) else model
        it, start_epoch, best_res = load_checkpoint(pure_model, optimizer, args.ckpt, logger)
        last_epoch = start_epoch + 1

    lr_scheduler = create_scheduler(model, optimizer, total_steps=len(train_loader)*args.epochs,
                                    last_epoch=last_epoch)

    if cfg.DATASET.DF_USED:
        criterion = Total_loss(boundary=cfg.DATASET.BOUNDARY)
    else:
        criterion = nn.CrossEntropyLoss()

    # start training
    logger.info('**********************Start training**********************')
    ckpt_dir = os.path.join(root_result_dir, "ckpt")
    os.makedirs(ckpt_dir, exist_ok=True)
    trainer = train_utils.Trainer(model,
                                  model_fn=model_fn_decorator(),
                                  criterion=criterion,
                                  optimizer=optimizer,
                                  ckpt_dir=ckpt_dir,
                                  lr_scheduler=lr_scheduler,
                                  model_fn_eval=model_fn_decorator(),
                                  tb_log=tb_log,
                                  logger=logger,
                                  eval_frequency=1,
                                  grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP,
                                  cfg=cfg)

    trainer.train(start_it=it,
                  start_epoch=start_epoch,
                  n_epochs=args.epochs,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  ckpt_save_interval=args.ckpt_save_interval,
                  lr_scheduler_each_iter=False,
                  best_res=best_res)

    logger.info('**********************End training**********************')


# python -m torch.distributed.launch --nproc_per_node 2 --master_port $RANDOM tools/train.py --batch_size 20 --mgpus 2,3 --output_dir logs/... --train_with_eval
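# With --mgpus 2,3 as above, CUDA_VISIBLE_DEVICES is set to "2,3", so the two
# launched processes (local_rank 0 and 1) run on physical GPUs 2 and 3.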
if __name__ == "__main__":
    train()