# train.py -- training / validation / inference driver for the RSNA UNet model.
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm

import utils.array_tool as at
from data.dataset import get_train_val_loader, inverse_normalize, get_test_loader
from model import UNet
from utils.Config import opt
from utils.eval_tool import compute_iou, save_pred_result
from utils.vis_tool import Visualizer

def train(model, train_loader, criterion, epoch, vis):
    """Run one training epoch and log/plot the mean per-batch loss.

    Uses the module-level ``optimizer`` created in the ``__main__`` block.
    NOTE(review): passing the optimizer in explicitly would make this
    function usable outside this script -- consider refactoring callers.

    Args:
        model: network to train (already moved to the right device).
        train_loader: yields dicts with 'image' and 'mask' tensors.
        criterion: loss applied to ``model(image)`` vs. mask.
        epoch: 1-based epoch index, used only for logging.
        vis: Visualizer used for periodic image/loss plots.
    """
    model.train()
    batch_loss = 0
    num_batches = 0  # BUG FIX: batch_idx is undefined if the loader is empty
    for batch_idx, sample_batched in enumerate(train_loader):
        data = sample_batched['image']
        target = sample_batched['mask']
        data = Variable(data.type(opt.dtype))
        target = Variable(target.type(opt.dtype))
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # BUG FIX: ``loss.data[0]`` was removed in PyTorch 0.4; ``.item()``
        # extracts the Python scalar without holding onto the graph.
        batch_loss += loss.item()
        num_batches += 1
        if (batch_idx + 1) % opt.plot_every == 0:
            # Periodically visualize the first sample of the batch: input
            # image, ground-truth mask, and thresholded prediction.
            ori_img_ = inverse_normalize(at.tonumpy(data[0]))
            target_ = at.tonumpy(target[0])
            pred_ = at.tonumpy(output[0])
            vis.img('gt_img', ori_img_)
            vis.img('gt_mask', target_)
            vis.img('pred_mask', (pred_ >= 0.5).astype(np.float32))

    batch_loss /= max(num_batches, 1)
    print('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss))
    with open('logs.txt', 'a') as file:
        file.write('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss) + '\n')
    vis.plot('train loss', batch_loss)
def val(model, val_loader, criterion, epoch, vis):
    """Evaluate on the validation set; log/plot loss and average IoU.

    Args:
        model: trained network (already on the right device).
        val_loader: yields dicts with 'image' and 'mask' tensors.
        criterion: loss applied to ``model(image)`` vs. mask.
        epoch: 1-based epoch index, used only for logging.
        vis: Visualizer used for the loss/IoU plots.

    Returns:
        Average IoU over the validation dataset.
    """
    model.eval()
    batch_loss = 0
    avg_iou = 0
    num_batches = 0  # BUG FIX: batch_idx is undefined if the loader is empty
    # ``torch.no_grad()`` replaces the removed ``volatile=True`` flag and
    # keeps evaluation from building autograd graphs.
    with torch.no_grad():
        for batch_idx, sample_batched in enumerate(val_loader):
            data = sample_batched['image'].type(opt.dtype)
            target = sample_batched['mask'].type(opt.dtype)
            # Call the model, not .forward(), so registered hooks still fire.
            output = model(data)
            loss = criterion(output, target)
            # BUG FIX: loss.data[0] was removed in PyTorch 0.4.
            batch_loss += loss.item()
            # NOTE(review): assumes compute_iou returns the *sum* of
            # per-image IoUs for the batch (it is divided by the dataset
            # size below, not the batch count) -- confirm against eval_tool.
            avg_iou += compute_iou(pred_masks=at.tonumpy(output >= 0.5).astype(np.float32), gt_masks=target)
            num_batches += 1

    batch_loss /= max(num_batches, 1)
    avg_iou /= len(val_loader.dataset)

    print('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss), ', avg_iou: ', avg_iou)
    with open('logs.txt', 'a') as file:
        file.write('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss) + ', avg_iou: ' + str(avg_iou) + '\n')

    vis.plot('val loss', batch_loss)
    vis.plot('validation average IOU', avg_iou)
    return avg_iou
# train and validation
def run(model, train_loader, val_loader, criterion, vis):
    """Train for ``opt.epochs`` epochs, checkpointing on validation-IoU improvement.

    Args:
        model: network to train.
        train_loader, val_loader: data loaders for the two splits.
        criterion: loss function shared by train() and val().
        vis: Visualizer passed through to train() and val().
    """
    best_iou = 0
    for epoch in range(1, opt.epochs + 1):
        train(model, train_loader, criterion, epoch, vis)
        avg_iou = val(model, val_loader, criterion, epoch, vis)
        if avg_iou > best_iou:
            best_iou = avg_iou
            if opt.save_model:
                _save_checkpoint(model, best_iou)

    # Also save the final model.  NOTE(review): these are the *last*
    # epoch's weights, yet the filename carries the best IoU -- confirm
    # this is intended rather than a leftover duplicate of the save above.
    if opt.save_model:
        _save_checkpoint(model, best_iou)


def _save_checkpoint(model, iou):
    # Persist the model weights under a name tagged with IoU and timestamp.
    save_path = './checkpoints/RSNA_UNet_' + str(round(iou, 3)) + '_' + time.strftime('%m%d%H%M')
    torch.save(model.state_dict(), save_path)
# make prediction
def run_test(model, test_loader):
    """Run inference over the test set and collect thresholded masks.

    Args:
        model: trained network with loaded weights.
        test_loader: yields dicts with 'image' tensors and 'img_id' labels.

    Returns:
        (img_ids, images, pred_masks): parallel lists with one entry per
        test image; masks are float32 arrays thresholded at 0.5.
    """
    # BUG FIX: switch to eval mode so dropout/batch-norm layers (if any)
    # behave deterministically during inference.
    model.eval()
    pred_masks = []
    img_ids = []
    images = []
    # torch.no_grad() replaces the removed ``volatile=True`` flag.
    with torch.no_grad():
        for batch_idx, sample_batched in tqdm(enumerate(test_loader)):
            data, batch_ids = sample_batched['image'], sample_batched['img_id']
            data = Variable(data.type(opt.dtype))
            # Call the model, not .forward(), so registered hooks still fire.
            output = at.tonumpy(model(data))
            for i in range(output.shape[0]):
                # Threshold the per-pixel probabilities into a binary mask.
                pred_mask = (np.squeeze(output[i]) >= 0.5).astype(np.float32)
                pred_masks.append(pred_mask)
                img_ids.append(batch_ids[i])  # renamed from ``id`` (shadowed builtin)
                images.append(inverse_normalize(at.tonumpy(data[i])))

    return img_ids, images, pred_masks
if __name__ == '__main__':
    # Entry point: train a UNet when opt.is_train is set, otherwise load a
    # checkpoint and generate predictions for the test set.
    model = UNet(input_channels=1, nclasses=1)
    if opt.is_train:
        # Split all data into train and validation (85/15).
        train_loader, val_loader = get_train_val_loader(opt.root_dir, batch_size=opt.batch_size, val_ratio=0.15,
                                                        shuffle=True, num_workers=4, pin_memory=False)

        # ``optimizer`` stays module-level because train() reads it globally.
        optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
        # BCELoss expects probabilities: the model presumably ends in a
        # sigmoid -- NOTE(review): confirm against model.UNet.
        criterion = nn.BCELoss()
        vis = Visualizer(env=opt.env)

        if opt.is_cuda:
            model.cuda()
            criterion.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)

        run(model, train_loader, val_loader, criterion, vis)
    else:
        if opt.is_cuda:
            model.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)
        test_loader = get_test_loader(batch_size=20, shuffle=True,
                                      num_workers=opt.num_workers,
                                      pin_memory=opt.pin_memory)
        # Load the trained weights and run the test pass.
        # BUG FIX: map_location keeps CUDA-saved checkpoints loadable on a
        # CPU-only machine (torch.load would otherwise try to restore to GPU).
        # NOTE(review): a checkpoint saved from a DataParallel model carries
        # 'module.'-prefixed keys; loading only matches when n_gpu agrees
        # with the training run -- verify if that configuration changes.
        checkpoint_path = os.path.join(opt.checkpoint_dir, 'RSNA_UNet_0.895_09210122')
        state_dict = torch.load(checkpoint_path, map_location=None if opt.is_cuda else 'cpu')
        model.load_state_dict(state_dict)

        img_ids, images, pred_masks = run_test(model, test_loader)

        save_pred_result(img_ids, images, pred_masks)