Diff of /CaraNet/train_blood.py [000000] .. [6f3ba0]

Switch to unified view

a b/CaraNet/train_blood.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Thu Jul 29 17:41:30 2021
4
5
@author: angelou
6
"""
7
8
import torch
9
from torch.autograd import Variable
10
import os
11
import argparse
12
from datetime import datetime
13
from utils.dataloader import get_loader,test_dataset
14
from utils.utils import clip_gradient, adjust_lr, AvgMeter
15
import torch.nn.functional as F
16
import numpy as np
17
from torchstat import stat
18
from CaraNet import caranet
19
20
21
22
def structure_loss(pred, mask):
23
    
24
    weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)
25
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none')
26
    wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))
27
28
    pred = torch.sigmoid(pred)
29
    inter = ((pred * mask)*weit).sum(dim=(2, 3))
30
    union = ((pred + mask)*weit).sum(dim=(2, 3))
31
    wiou = 1 - (inter + 1)/(union - inter+1)
32
    
33
    return (wbce + wiou).mean()
34
35
36
37
38
39
def test(model, path):
40
    
41
    ##### put your data_path of TestDataSet/Kvasir here #####
42
    data_path = path
43
    #########################################################
44
    
45
    model.eval()
46
    image_root = '{}/images/'.format(data_path)
47
    gt_root = '{}/masks/'.format(data_path)
48
    test_loader = test_dataset(image_root, gt_root, 512)
49
    b=0.0
50
    print('[test_size]',test_loader.size)
51
    for i in range(test_loader.size):
52
        image, gt, name = test_loader.load_data()
53
        gt = np.asarray(gt, np.float32)
54
        gt /= (gt.max() + 1e-8)
55
        image = image.cuda()
56
        
57
        res5,res3,res2,res1 = model(image)
58
        res = res5
59
        res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
60
        res = res.sigmoid().data.cpu().numpy().squeeze()
61
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
62
        
63
        input = res
64
        target = np.array(gt)
65
        N = gt.shape
66
        smooth = 1
67
        input_flat = np.reshape(input,(-1))
68
        target_flat = np.reshape(target,(-1))
69
 
70
        intersection = (input_flat*target_flat)
71
        
72
        loss =  (2 * intersection.sum() + smooth) / (input.sum() + target.sum() + smooth)
73
74
        a =  '{:.4f}'.format(loss)
75
        a = float(a)
76
        b = b + a
77
        
78
    return b/60
79
80
81
82
def train(train_loader, model, optimizer, epoch, test_path):
83
    model.train()
84
    # ---- multi-scale training ----
85
    size_rates = [0.75, 1, 1.25]
86
    loss_record1, loss_record2, loss_record3, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter()
87
    for i, pack in enumerate(train_loader, start=1):
88
        for rate in size_rates:
89
            optimizer.zero_grad()
90
            # ---- data prepare ----
91
            images, gts = pack
92
            images = Variable(images).cuda()
93
            gts = Variable(gts).cuda()
94
            # ---- rescale ----
95
            trainsize = int(round(opt.trainsize*rate/32)*32)
96
            if rate != 1:
97
                images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
98
                gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
99
            # ---- forward ----
100
            lateral_map_5,lateral_map_3,lateral_map_2,lateral_map_1 = model(images)
101
            # ---- loss function ----
102
            loss5 = structure_loss(lateral_map_5, gts)
103
            loss3 = structure_loss(lateral_map_3, gts)
104
            loss2 = structure_loss(lateral_map_2, gts)
105
            loss1 = structure_loss(lateral_map_1, gts)
106
            
107
            
108
            loss = loss5 +loss3 + loss2 + loss1
109
            # ---- backward ----
110
            loss.backward()
111
            clip_gradient(optimizer, opt.clip)
112
            optimizer.step()
113
            # ---- recording loss ----
114
            if rate == 1:
115
                
116
                loss_record5.update(loss5.data, opt.batchsize)
117
                loss_record3.update(loss3.data, opt.batchsize)
118
                loss_record2.update(loss2.data, opt.batchsize)
119
                loss_record1.update(loss1.data, opt.batchsize)
120
        # ---- train visualization ----
121
        if i % 20 == 0 or i == total_step:
122
            print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], '
123
                  ' lateral-5: {:0.4f}], lateral-3: {:0.4f}], lateral-2: {:0.4f}], lateral-1: {:0.4f}]'.
124
                  format(datetime.now(), epoch, opt.epoch, i, total_step,
125
                          loss_record5.show(),loss_record3.show(),loss_record2.show(),loss_record1.show()))
126
    save_path = 'snapshots/{}/'.format(opt.train_save)
127
    os.makedirs(save_path, exist_ok=True)
128
    
129
    
130
    
131
    
132
    
133
    if (epoch+1) % 1 == 0:
134
        meandice = test(model,test_path)
135
        
136
        fp = open('log/log.txt','a')
137
        fp.write(str(meandice)+'\n')
138
        fp.close()
139
        
140
        fp = open('log/best.txt','r')
141
        best = fp.read()
142
        fp.close()
143
        
144
        if meandice > float(best):
145
            fp = open('log/best.txt','w')
146
            fp.write(str(meandice))
147
            fp.close()
148
            # best = meandice
149
            fp = open('log/best.txt','r')
150
            best = fp.read()
151
            fp.close()
152
            torch.save(model.state_dict(), save_path + 'CaraNet-best.pth' )
153
            print('[Saving Snapshot:]', save_path + 'CaraNet-best.pth',meandice,'[best:]',best)
154
            
155
156
if __name__ == '__main__':
157
    parser = argparse.ArgumentParser()
158
    
159
    parser.add_argument('--epoch', type=int,
160
                        default=10, help='epoch number')
161
    
162
    parser.add_argument('--lr', type=float,
163
                        default=1e-4, help='learning rate')
164
    
165
    parser.add_argument('--optimizer', type=str,
166
                        default='Adam', help='choosing optimizer Adam or SGD')
167
    
168
    parser.add_argument('--augmentation',
169
                        default=False, help='choose to do random flip rotation')
170
    
171
    parser.add_argument('--batchsize', type=int,
172
                        default=6, help='training batch size')
173
    
174
    parser.add_argument('--trainsize', type=int,
175
                        default=352, help='training dataset size')
176
    
177
    parser.add_argument('--clip', type=float,
178
                        default=0.5, help='gradient clipping margin')
179
    
180
    parser.add_argument('--decay_rate', type=float,
181
                        default=0.1, help='decay rate of learning rate')
182
    
183
    parser.add_argument('--decay_epoch', type=int,
184
                        default=50, help='every n epochs decay learning rate')
185
    
186
    parser.add_argument('--train_path', type=str,
187
                        default='/home/data/spleen_blood/CaraNet/TrainDataset/train/', help='path to train dataset')
188
    
189
    parser.add_argument('--test_path', type=str,
190
                        default='/home/data/spleen_blood/CaraNet/TestDataset/test/' , help='path to testing Kvasir dataset')
191
    
192
    parser.add_argument('--train_save', type=str,
193
                        default='')
194
    
195
    opt = parser.parse_args()
196
197
    # ---- build models ----
198
    torch.cuda.set_device(4)  # set your gpu device
199
    model = caranet().cuda()
200
    # ---- flops and params ----
201
202
    # from utils.utils import CalParams
203
    # x = torch.randn(1, 3, 352, 352).cuda()
204
    # CalParams(model, x)
205
206
    params = model.parameters()
207
    
208
    if opt.optimizer == 'Adam':
209
        optimizer = torch.optim.Adam(params, opt.lr)
210
    else:
211
        optimizer = torch.optim.SGD(params, opt.lr, weight_decay = 1e-4, momentum = 0.9)
212
        
213
    print(optimizer)
214
215
    image_root = '{}/image/'.format(opt.train_path)
216
    gt_root = '{}/mask/'.format(opt.train_path)
217
218
    train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize, augmentation = opt.augmentation)
219
    total_step = len(train_loader)
220
221
    print("#"*20, "Start Training", "#"*20)
222
223
    for epoch in range(1, opt.epoch):
224
        adjust_lr(optimizer, opt.lr, epoch, 0.1, 200)
225
        train(train_loader, model, optimizer, epoch, opt.test_path)
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251