|
a |
|
b/trainner_baseline.py |
|
|
1 |
from model.Models import * |
|
|
2 |
from dataprocess import * |
|
|
3 |
import loss.losses as losses |
|
|
4 |
from metrics import * |
|
|
5 |
import torch.optim as optim |
|
|
6 |
import time |
|
|
7 |
import numpy as np |
|
|
8 |
import os |
|
|
9 |
import torch |
|
|
10 |
from config import Config |
|
|
11 |
import shutil |
|
|
12 |
from tqdm import tqdm |
|
|
13 |
import imageio |
|
|
14 |
import math |
|
|
15 |
from bisect import bisect_right |
|
|
16 |
|
|
|
17 |
config = Config() |
|
|
18 |
|
|
|
19 |
torch.cuda.set_device(config.gpu) |
|
|
20 |
|
|
|
21 |
model_name = config.arch |
|
|
22 |
if not os.path.isdir('result'): |
|
|
23 |
os.mkdir('result') |
|
|
24 |
if config.resume is False: |
|
|
25 |
with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f: |
|
|
26 |
f.seek(0) |
|
|
27 |
f.truncate() |
|
|
28 |
model = U_Net() |
|
|
29 |
model.cuda() |
|
|
30 |
best_dice = 0 # best test accuracy |
|
|
31 |
start_epoch = 0 # start from epoch 0 or last checkpoint epoch |
|
|
32 |
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.5, 0.999)) |
|
|
33 |
dataloader, dataloader_val = get_dataloader(config, batchsize=config.batch_size, mode='row') # 64 |
|
|
34 |
criterion = losses.init_loss('BCE_logit').cuda() |
|
|
35 |
|
|
|
36 |
if config.resume: |
|
|
37 |
# Load checkpoint. |
|
|
38 |
print('==> Resuming from checkpoint..') |
|
|
39 |
if config.evaluate: |
|
|
40 |
checkpoint = torch.load('./checkpoint/' + str(model_name) + '_best.pth.tar') |
|
|
41 |
else: |
|
|
42 |
checkpoint = torch.load('./checkpoint/' + str(model_name) + '.pth.tar') |
|
|
43 |
model.load_state_dict(checkpoint['model']) |
|
|
44 |
optimizer.load_state_dict(checkpoint['optimizer']) |
|
|
45 |
best_dice = checkpoint['dice'] |
|
|
46 |
start_epoch = config.epochs |
|
|
47 |
|
|
|
48 |
def adjust_lr(optimizer, epoch, eta_max=0.0001, eta_min=0.): |
|
|
49 |
cur_lr = 0. |
|
|
50 |
if config.lr_type == 'SGDR': |
|
|
51 |
i = int(math.log2(epoch / config.sgdr_t + 1)) |
|
|
52 |
T_cur = epoch - config.sgdr_t * (2 ** (i) - 1) |
|
|
53 |
T_i = (config.sgdr_t * 2 ** i) |
|
|
54 |
|
|
|
55 |
cur_lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * T_cur / T_i)) |
|
|
56 |
|
|
|
57 |
elif config.lr_type == 'multistep': |
|
|
58 |
cur_lr = config.learning_rate * 0.1 ** bisect_right(config.milestones, epoch) |
|
|
59 |
|
|
|
60 |
for param_group in optimizer.param_groups: |
|
|
61 |
param_group['lr'] = cur_lr |
|
|
62 |
return cur_lr |
|
|
63 |
|
|
|
64 |
def train(epoch): |
|
|
65 |
model.train() |
|
|
66 |
train_loss = 0 |
|
|
67 |
|
|
|
68 |
start_time = time.time() |
|
|
69 |
for batch_idx, (inputs, lungs, medias, targets_u, targets_i, targets_s) in enumerate(dataloader): |
|
|
70 |
iter_start_time = time.time() |
|
|
71 |
inputs = inputs.cuda() |
|
|
72 |
lungs = lungs.cuda() |
|
|
73 |
medias = medias.cuda() |
|
|
74 |
targets_i = targets_i.cuda() |
|
|
75 |
targets_u = targets_u.cuda() |
|
|
76 |
targets_s = targets_s.cuda() |
|
|
77 |
|
|
|
78 |
outputs = model(medias) |
|
|
79 |
|
|
|
80 |
outputs_sig = torch.sigmoid(outputs) |
|
|
81 |
|
|
|
82 |
loss_seg = criterion(outputs_sig, targets_u) |
|
|
83 |
|
|
|
84 |
loss_all = loss_seg |
|
|
85 |
|
|
|
86 |
optimizer.zero_grad() |
|
|
87 |
loss_all.backward() |
|
|
88 |
optimizer.step() |
|
|
89 |
|
|
|
90 |
train_loss += loss_all.item() |
|
|
91 |
|
|
|
92 |
print('Epoch:{}\t batch_idx:{}/All_batch:{}\t duration:{:.3f}\t loss_all:{:.3f}' |
|
|
93 |
.format(epoch, batch_idx, len(dataloader), time.time()-iter_start_time, loss_all.item())) |
|
|
94 |
iter_start_time = time.time() |
|
|
95 |
print('Epoch:{0}\t duration:{1:.3f}\ttrain_loss:{2:.6f}'.format(epoch, time.time()-start_time, train_loss/len(dataloader))) |
|
|
96 |
|
|
|
97 |
with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f: |
|
|
98 |
f.write('Epoch:{0}\t duration:{1:.3f}\t learning_rate:{2:.6f}\t train_loss:{3:.4f}' |
|
|
99 |
.format(epoch, time.time()-start_time, config.learning_rate, train_loss/len(dataloader))) |
|
|
100 |
|
|
|
101 |
def test(epoch): |
|
|
102 |
global best_dice |
|
|
103 |
model.eval() |
|
|
104 |
dices_all = [] |
|
|
105 |
ious_all = [] |
|
|
106 |
nsds_all = [] |
|
|
107 |
with torch.no_grad(): |
|
|
108 |
for batch_idx, (inputs, lungs, medias, targets_u, targets_i, targets_s) in enumerate(dataloader_val): |
|
|
109 |
inputs = inputs.cuda() |
|
|
110 |
medias = medias.cuda() |
|
|
111 |
targets_i = targets_i.cuda() |
|
|
112 |
targets_u = targets_u.cuda() |
|
|
113 |
targets_s = targets_s.cuda() |
|
|
114 |
|
|
|
115 |
outputs = model(medias) |
|
|
116 |
|
|
|
117 |
outputs_final_sig = torch.sigmoid(outputs) |
|
|
118 |
|
|
|
119 |
dices_all = meandice(outputs_final_sig, targets_u, dices_all) |
|
|
120 |
ious_all = meandIoU(outputs_final_sig, targets_u, ious_all) |
|
|
121 |
nsds_all = meanNSD(outputs_final_sig, targets_u, nsds_all) |
|
|
122 |
|
|
|
123 |
print('Epoch:{}\tbatch_idx:{}/All_batch:{}\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}' |
|
|
124 |
.format(epoch, batch_idx, len(dataloader_val), np.mean(np.array(dices_all)), np.mean(np.array(ious_all)), np.mean(np.array(nsds_all)))) |
|
|
125 |
with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f: |
|
|
126 |
f.write('\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}'.format(np.mean(np.array(dices_all)), np.mean(np.array(ious_all)), np.mean(np.array(nsds_all)))+'\n') |
|
|
127 |
|
|
|
128 |
# Save checkpoint. |
|
|
129 |
if config.resume is False: |
|
|
130 |
dice = np.mean(np.array(dices_all)) |
|
|
131 |
print('Test accuracy: ', dice) |
|
|
132 |
state = { |
|
|
133 |
'model': model.state_dict(), |
|
|
134 |
'dice': dice, |
|
|
135 |
'epoch': epoch, |
|
|
136 |
'optimizer': optimizer.state_dict() |
|
|
137 |
} |
|
|
138 |
if not os.path.isdir('checkpoint'): |
|
|
139 |
os.mkdir('checkpoint') |
|
|
140 |
torch.save(state, './checkpoint/'+str(model_name)+'.pth.tar') |
|
|
141 |
|
|
|
142 |
is_best = False |
|
|
143 |
if best_dice < dice: |
|
|
144 |
best_dice = dice |
|
|
145 |
is_best = True |
|
|
146 |
|
|
|
147 |
if is_best: |
|
|
148 |
shutil.copyfile('./checkpoint/' + str(model_name) + '.pth.tar', |
|
|
149 |
'./checkpoint/' + str(model_name) + '_best.pth.tar') |
|
|
150 |
print('Save Successfully') |
|
|
151 |
print('------------------------------------------------------------------------') |
|
|
152 |
|
|
|
153 |
if __name__ == '__main__': |
|
|
154 |
|
|
|
155 |
if config.resume: |
|
|
156 |
test(start_epoch) |
|
|
157 |
else: |
|
|
158 |
for epoch in tqdm(range(start_epoch, config.epochs)): |
|
|
159 |
train(epoch) |
|
|
160 |
test(epoch) |