Diff of /trainner_baseline.py [000000] .. [f77492]

Switch to unified view

a b/trainner_baseline.py
1
from model.Models import *
2
from dataprocess import *
3
import loss.losses as losses
4
from metrics import *
5
import torch.optim as optim
6
import time
7
import numpy as np
8
import os
9
import torch
10
from config import Config
11
import shutil
12
from tqdm import tqdm
13
import imageio
14
import math
15
from bisect import bisect_right
16
17
# ---------------------------------------------------------------------------
# Module-level setup: configuration, device, model, optimizer, data and loss.
# Everything below is read as a global by adjust_lr(), train() and test().
# ---------------------------------------------------------------------------
config = Config()

torch.cuda.set_device(config.gpu)

# `config.arch` is only used to name the log and checkpoint files.
model_name = config.arch

# exist_ok avoids the check-then-create race of isdir() + mkdir().
os.makedirs('result', exist_ok=True)

# A fresh (non-resumed) run starts from an empty log file. The test uses
# truthiness so that it is the exact complement of `if config.resume:` below.
if not config.resume:
    # Mode 'w' creates the file if needed and truncates it to zero length.
    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'w'):
        pass

model = U_Net()
model.cuda()
best_dice = 0  # best validation dice seen so far
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.5, 0.999))
dataloader, dataloader_val = get_dataloader(config, batchsize=config.batch_size, mode='row')   # 64
criterion = losses.init_loss('BCE_logit').cuda()

if config.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    if config.evaluate:
        checkpoint_path = './checkpoint/' + str(model_name) + '_best.pth.tar'
    else:
        checkpoint_path = './checkpoint/' + str(model_name) + '.pth.tar'
    # Deserialize onto the CPU first so the checkpoint loads regardless of
    # which GPU index it was saved from; load_state_dict() then copies the
    # values into the CUDA-resident model / optimizer state.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_dice = checkpoint['dice']
    # With start_epoch == config.epochs the training range is empty; the
    # __main__ block only runs a single test() pass when resuming.
    start_epoch = config.epochs
47
48
def adjust_lr(optimizer, epoch, eta_max=0.0001, eta_min=0.):
    """Set the learning rate of every param group for the given epoch.

    The schedule is selected by ``config.lr_type``:

    * ``'SGDR'``: cosine annealing with warm restarts. The first cycle lasts
      ``config.sgdr_t`` epochs and each following cycle is twice as long; the
      rate decays from ``eta_max`` to ``eta_min`` within a cycle.
    * ``'multistep'``: ``config.learning_rate`` multiplied by 0.1 for every
      entry of ``config.milestones`` that is <= ``epoch``.
    * anything else: the constant ``config.learning_rate``. (Previously an
      unrecognised type silently set the learning rate to 0.)

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: current epoch index.
        eta_max: upper bound of the SGDR cycle.
        eta_min: lower bound of the SGDR cycle.

    Returns:
        The learning rate that was written into the param groups.
    """
    if config.lr_type == 'SGDR':
        # Index of the current restart cycle (cycle lengths double each time).
        i = int(math.log2(epoch / config.sgdr_t + 1))
        # Epochs elapsed since the start of the current cycle.
        T_cur = epoch - config.sgdr_t * (2 ** (i) - 1)
        # Length of the current cycle.
        T_i = (config.sgdr_t * 2 ** i)

        cur_lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * T_cur / T_i))

    elif config.lr_type == 'multistep':
        cur_lr = config.learning_rate * 0.1 ** bisect_right(config.milestones, epoch)

    else:
        # Unknown schedule: keep the configured rate instead of zeroing it,
        # which would silently stop all learning.
        cur_lr = config.learning_rate

    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr
    return cur_lr
63
64
def train(epoch):
    """Run one training epoch over ``dataloader`` and log a summary.

    For every batch the model is fed ``medias``, the sigmoid of its output is
    compared against ``targets_u`` with the module-level ``criterion``, and
    one optimizer step is taken. Per-batch progress is printed; at the end the
    epoch summary is printed and appended (without a trailing newline, which
    test() adds) to the result log file.

    Args:
        epoch: epoch index, used only for logging.
    """
    model.train()
    train_loss = 0
    num_batches = len(dataloader)

    start_time = time.time()
    # The loader yields six tensors per batch, but only `medias` (network
    # input) and `targets_u` (loss target) are used here, so the other four
    # are no longer copied to the GPU.
    for batch_idx, (_inputs, _lungs, medias, targets_u, _targets_i, _targets_s) in enumerate(dataloader):
        iter_start_time = time.time()
        medias = medias.cuda()
        targets_u = targets_u.cuda()

        outputs = model(medias)

        outputs_sig = torch.sigmoid(outputs)

        # NOTE(review): the criterion is created as init_loss('BCE_logit').
        # If that is a with-logits BCE, it applies its own sigmoid and should
        # be given `outputs` instead of `outputs_sig` - confirm against
        # loss.losses.init_loss before changing.
        loss_all = criterion(outputs_sig, targets_u)

        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()

        batch_loss = loss_all.item()
        train_loss += batch_loss

        print('Epoch:{}\t batch_idx:{}/All_batch:{}\t duration:{:.3f}\t loss_all:{:.3f}'
          .format(epoch, batch_idx, num_batches, time.time()-iter_start_time, batch_loss))

    # Guard against an empty loader so the summary does not divide by zero.
    mean_loss = train_loss / max(num_batches, 1)
    print('Epoch:{0}\t duration:{1:.3f}\ttrain_loss:{2:.6f}'.format(epoch, time.time()-start_time, mean_loss))

    # Log the rate the optimizer actually used, which stays correct even if a
    # scheduler modifies the param groups.
    current_lr = optimizer.param_groups[0]['lr']
    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
        f.write('Epoch:{0}\t duration:{1:.3f}\t learning_rate:{2:.6f}\t train_loss:{3:.4f}'
          .format(epoch, time.time()-start_time, current_lr, mean_loss))
100
101
def test(epoch):
    """Evaluate the model on ``dataloader_val`` and optionally checkpoint it.

    For every validation batch the sigmoid of the model output on ``medias``
    is scored against ``targets_u`` with ``meandice``, ``meandIoU`` and
    ``meanNSD``; each of these receives the running list of scores and
    returns the updated list. Running means are printed per batch and the
    final means are appended to the result log, terminating the line that
    train() started.

    When not resuming, the current model/optimizer state is saved to
    ``./checkpoint/<model_name>.pth.tar`` and copied to ``..._best.pth.tar``
    whenever the mean dice exceeds the global ``best_dice``.

    Args:
        epoch: epoch index, used for logging and stored in the checkpoint.
    """
    global best_dice
    model.eval()
    dices_all = []
    ious_all = []
    nsds_all = []
    with torch.no_grad():
        # Only `medias` and `targets_u` are used, so the remaining tensors of
        # each batch are not transferred to the GPU.
        for batch_idx, (_inputs, _lungs, medias, targets_u, _targets_i, _targets_s) in enumerate(dataloader_val):
            medias = medias.cuda()
            targets_u = targets_u.cuda()

            outputs = model(medias)

            outputs_final_sig = torch.sigmoid(outputs)

            dices_all = meandice(outputs_final_sig, targets_u, dices_all)
            ious_all = meandIoU(outputs_final_sig, targets_u, ious_all)
            nsds_all = meanNSD(outputs_final_sig, targets_u, nsds_all)

            print('Epoch:{}\tbatch_idx:{}/All_batch:{}\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}'
            .format(epoch, batch_idx, len(dataloader_val), np.mean(np.array(dices_all)), np.mean(np.array(ious_all)), np.mean(np.array(nsds_all))))

    # Final means, computed once and reused for the log and the checkpoint.
    mean_dice = np.mean(np.array(dices_all))
    mean_iou = np.mean(np.array(ious_all))
    mean_nsd = np.mean(np.array(nsds_all))

    with open('result/' + str(os.path.basename(__file__).split('.')[0]) + '_' + model_name + '.txt', 'a+') as f:
        f.write('\tdice:{:.4f}\tiou:{:.4f}\tnsd:{:.4f}'.format(mean_dice, mean_iou, mean_nsd)+'\n')

    # Save checkpoint (training runs only; uses truthiness to match the
    # `if config.resume:` checks elsewhere in the script).
    if not config.resume:
        # Store a plain Python float: a numpy scalar pickles as a NumPy
        # global, which torch.load may refuse in restricted (weights_only)
        # mode.
        dice = float(mean_dice)
        print('Test accuracy: ', dice)
        state = {
            'model': model.state_dict(),
            'dice': dice,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/'+str(model_name)+'.pth.tar')

        if best_dice < dice:
            best_dice = dice
            # Keep a separate copy of the best-scoring checkpoint.
            shutil.copyfile('./checkpoint/' + str(model_name) + '.pth.tar',
                            './checkpoint/' + str(model_name) + '_best.pth.tar')
        print('Save Successfully')
        print('------------------------------------------------------------------------')
152
153
if __name__ == '__main__':
    # Script entry point: a fresh run alternates one training epoch with one
    # validation pass; a resumed run only evaluates the loaded checkpoint.
    if not config.resume:
        epoch_range = range(start_epoch, config.epochs)
        for epoch in tqdm(epoch_range):
            train(epoch)
            test(epoch)
    else:
        test(start_epoch)