tools/train.py

import os
import argparse
import logging
import importlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
import torchvision
from tensorboardX import SummaryWriter

import _init_paths
from libs.configs.config_acdc import cfg

from libs.datasets import AcdcDataset
from libs.datasets import joint_augment as joint_augment
from libs.datasets import augment as standard_augment
from libs.datasets.collate_batch import BatchCollator
# from libs.losses.df_loss import EuclideanLossWithOHEM
# from libs.losses.surface_loss import SurfaceLoss
from libs.losses.create_losses import Total_loss
import train_utils.train_utils as train_utils
from train_utils.train_utils import load_checkpoint
from utils.init_net import init_weights
from utils.comm import get_rank, synchronize


parser = argparse.ArgumentParser(description="arg parser")
parser.add_argument("--local_rank", type=int, default=0, required=True, help="local device id for DistributedDataParallel")
parser.add_argument("--batch_size", type=int, default=32, required=False, help="batch size for training")
parser.add_argument("--epochs", type=int, default=50, required=False, help="number of epochs to train for")
parser.add_argument('--workers', type=int, default=4, help='number of workers for dataloader')
parser.add_argument("--ckpt_save_interval", type=int, default=5, help="number of epochs between checkpoint saves")
parser.add_argument('--output_dir', type=str, default=None, help='specify an output directory if needed')
parser.add_argument('--mgpus', type=str, default=None, help='GPU ids to expose, e.g. 2,3 (sets CUDA_VISIBLE_DEVICES)')
parser.add_argument("--ckpt", type=str, default=None, help="continue training from this checkpoint")
parser.add_argument('--train_with_eval', action='store_true', default=False, help='whether to evaluate during training')
args = parser.parse_args()

FILE_DIR = os.path.dirname(os.path.abspath(__file__))

if args.mgpus is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.mgpus

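# Per-process logger: rank 0 logs DEBUG and above to both log_file and the console,
# while non-zero ranks get a WARNING-level logger so only the main process writes the full log.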
def create_logger(log_file, dist_rank):
    if dist_rank > 0:
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.WARNING)
        return logger
    log_format = '%(asctime)s %(levelname)5s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger(__name__).addHandler(console)
    return logging.getLogger(__name__)

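# Builds the train (and optional eval) dataloaders. Joint spatial augmentations are applied
# to image and label together, intensity transforms to the image only, and a
# DistributedSampler shards each dataset across the distributed ranks.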
def create_dataloader(logger):
    train_joint_transform = joint_augment.Compose([
        joint_augment.To_PIL_Image(),
        joint_augment.RandomAffine(0, translate=(0.125, 0.125)),
        joint_augment.RandomRotate((-180, 180)),
        joint_augment.FixResize(256)
        ])
    transform = standard_augment.Compose([
        standard_augment.to_Tensor(),
        standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])
    target_transform = standard_augment.Compose([
        standard_augment.to_Tensor()])

    if cfg.DATASET.NAME == 'acdc':
        train_set = AcdcDataset(data_list=cfg.DATASET.TRAIN_LIST,
                                df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM,
                                boundary=cfg.DATASET.BOUNDARY,
                                joint_augment=train_joint_transform,
                                augment=transform, target_augment=target_transform)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set,
                                                                    num_replicas=dist.get_world_size(),
                                                                    rank=dist.get_rank())
    train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.workers, shuffle=False, sampler=train_sampler,
                              collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED,
                                                       boundary=cfg.DATASET.BOUNDARY))

    if args.train_with_eval:
        eval_transform = joint_augment.Compose([
            joint_augment.To_PIL_Image(),
            joint_augment.FixResize(256),
            joint_augment.To_Tensor()])
        evalImg_transform = standard_augment.Compose([
            standard_augment.normalize([cfg.DATASET.MEAN], [cfg.DATASET.STD])])

        if cfg.DATASET.NAME == 'acdc':
            test_set = AcdcDataset(data_list=cfg.DATASET.TEST_LIST,
                                   df_used=cfg.DATASET.DF_USED, df_norm=cfg.DATASET.DF_NORM,
                                   boundary=cfg.DATASET.BOUNDARY,
                                   joint_augment=eval_transform,
                                   augment=evalImg_transform)

        test_sampler = torch.utils.data.distributed.DistributedSampler(test_set,
                                                                       num_replicas=dist.get_world_size(),
                                                                       rank=dist.get_rank())
        test_loader = DataLoader(test_set, batch_size=args.batch_size, pin_memory=True,
                                 num_workers=args.workers, shuffle=False, sampler=test_sampler,
                                 collate_fn=BatchCollator(size_divisible=32, df_used=cfg.DATASET.DF_USED,
                                                          boundary=cfg.DATASET.BOUNDARY))
    else:
        test_loader = None

    return train_loader, test_loader

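# Optimizer is chosen from the config; only "adam" and "sgd" are supported.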
def create_optimizer(model):
    if cfg.TRAIN.OPTIMIZER == "adam":
        optimizer = optim.Adam(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    elif cfg.TRAIN.OPTIMIZER == "sgd":
        optimizer = optim.SGD(model.parameters(), lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                              momentum=cfg.TRAIN.MOMENTUM)
    else:
        raise NotImplementedError
    return optimizer

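# Step-decay schedule: the LR multiplier shrinks by LR_DECAY at each epoch in DECAY_STEP_LIST
# and is floored at LR_CLIP / LR. Illustrative (hypothetical) config values: with LR=1e-3,
# LR_DECAY=0.1, LR_CLIP=1e-6 and DECAY_STEP_LIST=[30, 40], epochs 0-29 train at 1e-3,
# epochs 30-39 at 1e-4, and epochs 40+ at 1e-5.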
def create_scheduler(model, optimizer, total_steps, last_epoch):
    def lr_lbmd(cur_epoch):
        cur_decay = 1
        for decay_step in cfg.TRAIN.DECAY_STEP_LIST:
            if cur_epoch >= decay_step:
                cur_decay = cur_decay * cfg.TRAIN.LR_DECAY
        return max(cur_decay, cfg.TRAIN.LR_CLIP / cfg.TRAIN.LR)

    lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lbmd, last_epoch=last_epoch)
    return lr_scheduler

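# cfg.TRAIN.NET names the architecture as "<module>.<ClassName>" (e.g. "unet_df.U_NetDF").
# The module is imported dynamically from libs.network and the matching model-function
# decorator is looked up in libs.network.train_functions.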
def create_model(cfg):
    network = cfg.TRAIN.NET

    module = 'libs.network.' + network[:network.rfind('.')]
    model = network[network.rfind('.')+1:]

    mod = importlib.import_module(module)
    mod_func = importlib.import_module('libs.network.train_functions')
    net_func = getattr(mod, model)

    net = net_func(num_class=cfg.DATASET.NUM_CLASS)
    if network == 'unet.U_Net':
        train_func = getattr(mod_func, 'model_fn_decorator')
    elif network == 'unet_df.U_NetDF':
        net = net_func(selfeat=cfg.MODEL.SELFEATURE, num_class=cfg.DATASET.NUM_CLASS,
                       shift_n=cfg.MODEL.SHIFT_N, auxseg=cfg.MODEL.AUXSEG)
        train_func = getattr(mod_func, 'model_DF_decorator')
    else:
        # guard against silently returning an unbound train_func for unknown networks
        raise NotImplementedError("unsupported network: %s" % network)

    return net, train_func

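# Training entry point: initializes the NCCL process group for DistributedDataParallel,
# builds model, dataloaders, optimizer, scheduler and loss, optionally resumes from a
# checkpoint, and hands everything to train_utils.Trainer.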
def train():
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    synchronize()

    # create dataloader & network & optimizer
    model, model_fn_decorator = create_model(cfg)
    init_weights(model, init_type='kaiming')
    # model.to('cuda')
    model.cuda()
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, "log_train.txt")
    logger = create_logger(log_file, get_rank())
    logger.info("**********************Start logging**********************")

    # log to file
    gpu_list = os.environ['CUDA_VISIBLE_DEVICES'] if 'CUDA_VISIBLE_DEVICES' in os.environ.keys() else 'ALL'
    logger.info("CUDA_VISIBLE_DEVICES=%s" % gpu_list)

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    logger.info("***********************config infos**********************")
    for key, val in vars(cfg).items():
        logger.info("{:16} {}".format(key, val))

    # tensorboard logging on rank 0 only
    if get_rank() == 0:
        tb_log = SummaryWriter(log_dir=os.path.join(root_result_dir, "tensorboard"))
    else:
        tb_log = None

    train_loader, test_loader = create_dataloader(logger)

    optimizer = create_optimizer(model)

    # load checkpoint if possible
    start_epoch = it = best_res = 0
    last_epoch = -1
    if args.ckpt is not None:
        pure_model = model.module if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) else model
        it, start_epoch, best_res = load_checkpoint(pure_model, optimizer, args.ckpt, logger)
        last_epoch = start_epoch + 1

    lr_scheduler = create_scheduler(model, optimizer, total_steps=len(train_loader) * args.epochs,
                                    last_epoch=last_epoch)

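    # When the direction-field (DF) output is enabled, use the project's combined Total_loss;
    # otherwise fall back to plain cross-entropy on the segmentation logits.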
    if cfg.DATASET.DF_USED:
        criterion = Total_loss(boundary=cfg.DATASET.BOUNDARY)
    else:
        criterion = nn.CrossEntropyLoss()

    # start training
    logger.info('**********************Start training**********************')
    ckpt_dir = os.path.join(root_result_dir, "ckpt")
    os.makedirs(ckpt_dir, exist_ok=True)
    trainer = train_utils.Trainer(model,
                                  model_fn=model_fn_decorator(),
                                  criterion=criterion,
                                  optimizer=optimizer,
                                  ckpt_dir=ckpt_dir,
                                  lr_scheduler=lr_scheduler,
                                  model_fn_eval=model_fn_decorator(),
                                  tb_log=tb_log,
                                  logger=logger,
                                  eval_frequency=1,
                                  grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP,
                                  cfg=cfg)

    trainer.train(start_it=it,
                  start_epoch=start_epoch,
                  n_epochs=args.epochs,
                  train_loader=train_loader,
                  test_loader=test_loader,
                  ckpt_save_interval=args.ckpt_save_interval,
                  lr_scheduler_each_iter=False,
                  best_res=best_res)

    logger.info('**********************End training**********************')

# Example launch (2 GPUs):
# python -m torch.distributed.launch --nproc_per_node 2 --master_port $RANDOM tools/train.py --batch_size 20 --mgpus 2,3 --output_dir logs/... --train_with_eval
if __name__ == "__main__":
    train()