|
b/CaraNet/train_blood.py
|
|
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 29 17:41:30 2021

@author: angelou
"""

import torch
import os
import argparse
from datetime import datetime
from utils.dataloader import get_loader, test_dataset
from utils.utils import clip_gradient, adjust_lr, AvgMeter
import torch.nn.functional as F
import numpy as np
from CaraNet import caranet


def structure_loss(pred, mask):
    # Weighted BCE + weighted IoU: pixels that differ from their local
    # neighbourhood average (i.e. boundary pixels) receive up to 6x weight.
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    pred = torch.sigmoid(pred)
    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()


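# A minimal sanity check for structure_loss (illustrative only; the
# (B, 1, H, W) shape convention is assumed from the dataloader):
#
#   pred = torch.randn(2, 1, 352, 352)                     # raw logits
#   mask = torch.randint(0, 2, (2, 1, 352, 352)).float()   # binary ground truth
#   print(structure_loss(pred, mask))                      # scalar loss tensor
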
def test(model, path):

    ##### put your data_path of TestDataSet/Kvasir here #####
    data_path = path
    #########################################################

    model.eval()
    image_root = '{}/images/'.format(data_path)
    gt_root = '{}/masks/'.format(data_path)
    test_loader = test_dataset(image_root, gt_root, 512)
    b = 0.0
    print('[test_size]', test_loader.size)
    for i in range(test_loader.size):
        image, gt, name = test_loader.load_data()
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)
        image = image.cuda()

        res5, res3, res2, res1 = model(image)
        res = res5
        res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
        res = res.sigmoid().data.cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)

        # per-image Dice coefficient: (2*|P.G| + s) / (|P| + |G| + s)
        pred_flat = np.reshape(res, (-1))
        target_flat = np.reshape(np.array(gt), (-1))
        smooth = 1
        intersection = pred_flat * target_flat
        dice = (2 * intersection.sum() + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)

        b += round(float(dice), 4)

    # mean Dice over the whole test set
    return b / test_loader.size


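# Usage sketch (assuming the default directory layout): 'test' returns the
# mean Dice coefficient over every image in '<path>/images/' scored against
# the binary masks in '<path>/masks/'.
#
#   mean_dice = test(model, '/home/data/spleen_blood/CaraNet/TestDataset/test/')
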
def train(train_loader, model, optimizer, epoch, test_path):
    model.train()
    # ---- multi-scale training ----
    # e.g. trainsize 352 with rates [0.75, 1, 1.25] trains at 256, 352, 448
    size_rates = [0.75, 1, 1.25]
    loss_record1, loss_record2, loss_record3, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
    for i, pack in enumerate(train_loader, start=1):
        for rate in size_rates:
            optimizer.zero_grad()
            # ---- data prepare ----
            images, gts = pack
            images = images.cuda()
            gts = gts.cuda()
            # ---- rescale to the nearest multiple of 32 ----
            trainsize = int(round(opt.trainsize * rate / 32) * 32)
            if rate != 1:
                images = F.interpolate(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
                gts = F.interpolate(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
            # ---- forward ----
            lateral_map_5, lateral_map_3, lateral_map_2, lateral_map_1 = model(images)
            # ---- loss function ----
            loss5 = structure_loss(lateral_map_5, gts)
            loss3 = structure_loss(lateral_map_3, gts)
            loss2 = structure_loss(lateral_map_2, gts)
            loss1 = structure_loss(lateral_map_1, gts)

            loss = loss5 + loss3 + loss2 + loss1
            # ---- backward ----
            loss.backward()
            clip_gradient(optimizer, opt.clip)
            optimizer.step()
            # ---- recording loss ----
            if rate == 1:
                loss_record5.update(loss5.data, opt.batchsize)
                loss_record3.update(loss3.data, opt.batchsize)
                loss_record2.update(loss2.data, opt.batchsize)
                loss_record1.update(loss1.data, opt.batchsize)
        # ---- train visualization ----
        if i % 20 == 0 or i == total_step:
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
                  'lateral-5: {:0.4f}, lateral-3: {:0.4f}, lateral-2: {:0.4f}, lateral-1: {:0.4f}'.
                  format(datetime.now(), epoch, opt.epoch, i, total_step,
                         loss_record5.show(), loss_record3.show(), loss_record2.show(), loss_record1.show()))
    save_path = 'snapshots/{}/'.format(opt.train_save)
    os.makedirs(save_path, exist_ok=True)

    # ---- evaluate and checkpoint after every epoch ----
    meandice = test(model, test_path)

    os.makedirs('log', exist_ok=True)
    with open('log/log.txt', 'a') as fp:
        fp.write(str(meandice) + '\n')

    # 'log/best.txt' tracks the best mean Dice so far; start from 0 if absent
    if not os.path.exists('log/best.txt'):
        with open('log/best.txt', 'w') as fp:
            fp.write('0')
    with open('log/best.txt', 'r') as fp:
        best = fp.read()

    if meandice > float(best):
        with open('log/best.txt', 'w') as fp:
            fp.write(str(meandice))
        best = str(meandice)
        torch.save(model.state_dict(), save_path + 'CaraNet-best.pth')
        print('[Saving Snapshot:]', save_path + 'CaraNet-best.pth', meandice, '[best:]', best)


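# Directory layout assumed by this script (inferred from the paths used above
# and below):
#   <train_path>/image/  and  <train_path>/mask/    -- training pairs
#   <test_path>/images/  and  <test_path>/masks/    -- evaluation pairs
#   log/                     -- per-epoch mean Dice and best score
#   snapshots/<train_save>/  -- model checkpoints
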
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--epoch', type=int,
                        default=10, help='epoch number')

    parser.add_argument('--lr', type=float,
                        default=1e-4, help='learning rate')

    parser.add_argument('--optimizer', type=str,
                        default='Adam', help='choosing optimizer Adam or SGD')

    # store_true avoids the argparse pitfall where any string value
    # (even 'False') parses as truthy
    parser.add_argument('--augmentation', action='store_true',
                        help='choose to do random flip rotation')

    parser.add_argument('--batchsize', type=int,
                        default=6, help='training batch size')

    parser.add_argument('--trainsize', type=int,
                        default=352, help='training dataset size')

    parser.add_argument('--clip', type=float,
                        default=0.5, help='gradient clipping margin')

    parser.add_argument('--decay_rate', type=float,
                        default=0.1, help='decay rate of learning rate')

    parser.add_argument('--decay_epoch', type=int,
                        default=50, help='every n epochs decay learning rate')

    parser.add_argument('--train_path', type=str,
                        default='/home/data/spleen_blood/CaraNet/TrainDataset/train/', help='path to train dataset')

    parser.add_argument('--test_path', type=str,
                        default='/home/data/spleen_blood/CaraNet/TestDataset/test/', help='path to testing Kvasir dataset')

    parser.add_argument('--train_save', type=str,
                        default='')

    opt = parser.parse_args()

    # ---- build models ----
    torch.cuda.set_device(4)  # set your gpu device
    model = caranet().cuda()
    # ---- flops and params ----

    # from utils.utils import CalParams
    # x = torch.randn(1, 3, 352, 352).cuda()
    # CalParams(model, x)

    params = model.parameters()

    if opt.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params, opt.lr)
    else:
        optimizer = torch.optim.SGD(params, opt.lr, weight_decay=1e-4, momentum=0.9)

    print(optimizer)

    image_root = '{}/image/'.format(opt.train_path)
    gt_root = '{}/mask/'.format(opt.train_path)

    train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize, augmentation=opt.augmentation)
    total_step = len(train_loader)

    print("#" * 20, "Start Training", "#" * 20)

    for epoch in range(1, opt.epoch + 1):
        adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
        train(train_loader, model, optimizer, epoch, opt.test_path)
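
# Example invocation (a sketch; the GPU index in torch.cuda.set_device and the
# dataset paths will likely need adapting to your machine):
#
#   python train_blood.py --epoch 10 --batchsize 6 --trainsize 352 \
#       --train_path /home/data/spleen_blood/CaraNet/TrainDataset/train/ \
#       --test_path /home/data/spleen_blood/CaraNet/TestDataset/test/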
|
|