|
a |
|
b/train.py |
|
|
1 |
import numpy as np |
|
|
2 |
import os |
|
|
3 |
from tqdm import tqdm |
|
|
4 |
import time |
|
|
5 |
import torch |
|
|
6 |
import torch.nn as nn |
|
|
7 |
import torch.optim as optim |
|
|
8 |
from torch.autograd import Variable |
|
|
9 |
from data.dataset import get_train_val_loader, inverse_normalize, get_test_loader |
|
|
10 |
from model import UNet |
|
|
11 |
from utils.Config import opt |
|
|
12 |
from utils.vis_tool import Visualizer |
|
|
13 |
from utils.eval_tool import compute_iou, save_pred_result |
|
|
14 |
import utils.array_tool as at |
|
|
15 |
|
|
|
def train(model, train_loader, criterion, epoch, vis):
    """Run one training epoch and log/plot the mean batch loss.

    Args:
        model: network to train; called as ``model(data)``.
        train_loader: yields dicts with 'image' and 'mask' tensors.
        criterion: loss applied to ``(output, target)``.
        epoch: current epoch number (used only for logging).
        vis: Visualizer used for image and loss plotting.

    NOTE(review): relies on a module-level ``optimizer`` created in the
    ``__main__`` section -- consider passing it in explicitly.
    """
    model.train()
    batch_loss = 0
    for batch_idx, sample_batched in enumerate(train_loader):
        data = sample_batched['image']
        target = sample_batched['mask']
        data, target = Variable(data.type(opt.dtype)), Variable(target.type(opt.dtype))
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # FIX: loss.data[0] raises on PyTorch >= 0.5 (loss is a 0-dim
        # tensor); .item() is the supported scalar extraction.
        batch_loss += loss.item()
        if (batch_idx + 1) % opt.plot_every == 0:
            # visualize the first sample of the current batch
            ori_img_ = inverse_normalize(at.tonumpy(data[0]))
            target_ = at.tonumpy(target[0])
            pred_ = at.tonumpy(output[0])
            vis.img('gt_img', ori_img_)
            vis.img('gt_mask', target_)
            # binarize the (presumably sigmoid) output at 0.5 for display
            vis.img('pred_mask', (pred_ >= 0.5).astype(np.float32))

    batch_loss /= (batch_idx + 1)
    print('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss))
    with open('logs.txt', 'a') as file:
        file.write('epoch: ' + str(epoch) + ', train loss: ' + str(batch_loss) + '\n')
    vis.plot('train loss', batch_loss)
def val(model, val_loader, criterion, epoch, vis):
    """Evaluate on the validation set; log loss/IoU and return the IoU.

    Args:
        model: network to evaluate.
        val_loader: yields dicts with 'image' and 'mask' tensors.
        criterion: loss applied to ``(output, target)``.
        epoch: current epoch number (used only for logging).
        vis: Visualizer used for loss/IoU plotting.

    Returns:
        Average IoU over the validation dataset.
    """
    model.eval()
    batch_loss = 0
    avg_iou = 0
    # FIX: Variable(..., volatile=True) was removed from PyTorch;
    # torch.no_grad() is the supported way to disable autograd here.
    with torch.no_grad():
        for batch_idx, sample_batched in enumerate(val_loader):
            data = sample_batched['image']
            target = sample_batched['mask']
            data, target = data.type(opt.dtype), target.type(opt.dtype)
            output = model(data)
            loss = criterion(output, target)
            # FIX: .item() replaces loss.data[0], which raises on
            # PyTorch >= 0.5 for 0-dim tensors.
            batch_loss += loss.item()
            avg_iou += compute_iou(pred_masks=at.tonumpy(output >= 0.5).astype(np.float32), gt_masks=target)

    batch_loss /= (batch_idx + 1)
    # NOTE(review): the running total is divided by the dataset size, so
    # compute_iou presumably returns a per-batch *sum* of IoUs -- verify
    # against utils.eval_tool.
    avg_iou /= len(val_loader.dataset)

    print('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss), ', avg_iou: ', avg_iou)
    with open('logs.txt', 'a') as file:
        file.write('epoch: ' + str(epoch) + ', validation loss: ' + str(batch_loss) + ', avg_iou: ' + str(avg_iou) + '\n')

    vis.plot('val loss', batch_loss)
    vis.plot('validation average IOU', avg_iou)
    return avg_iou
# train and validation
def run(model, train_loader, val_loader, criterion, vis):
    """Train for ``opt.epochs`` epochs, checkpointing on IoU improvement.

    Saves a checkpoint under ./checkpoints/ whenever the validation IoU
    improves, and one final snapshot after training completes.
    """
    best_iou = 0
    for epoch in range(1, opt.epochs + 1):
        train(model, train_loader, criterion, epoch, vis)
        avg_iou = val(model, val_loader, criterion, epoch, vis)
        if avg_iou > best_iou:
            best_iou = avg_iou
            if opt.save_model:
                save_path = './checkpoints/RSNA_UNet_' + str(round(best_iou, 3)) + '_' + time.strftime('%m%d%H%M')
                torch.save(model.state_dict(), save_path)

    if opt.save_model:
        # FIX: the original reused the best-IoU filename here, so the final
        # epoch's (possibly worse) weights could silently overwrite the best
        # checkpoint when both saves landed in the same minute; tag the
        # final snapshot explicitly so it never collides.
        save_path = './checkpoints/RSNA_UNet_' + str(round(best_iou, 3)) + '_' + time.strftime('%m%d%H%M') + '_final'
        torch.save(model.state_dict(), save_path)
# make prediction
def run_test(model, test_loader):
    """Predict binary masks for every image in the test set.

    Args:
        model: trained network; its output is thresholded at 0.5.
        test_loader: yields dicts with 'image' tensors and 'img_id' ids.

    Returns:
        (img_ids, images, pred_masks): parallel lists of image ids,
        de-normalized input images, and 0/1 float32 masks.
    """
    # FIX: switch to eval mode so dropout/batch-norm behave correctly at
    # inference time (the original left the model in whatever mode it had).
    model.eval()
    pred_masks = []
    img_ids = []
    images = []
    # FIX: torch.no_grad() replaces the removed volatile=True flag.
    with torch.no_grad():
        for batch_idx, sample_batched in tqdm(enumerate(test_loader)):
            data, img_id = sample_batched['image'], sample_batched['img_id']
            data = data.type(opt.dtype)
            output = model(data)
            output = at.tonumpy(output)
            for i in range(output.shape[0]):
                pred_mask = np.squeeze(output[i])
                image_id = img_id[i]  # renamed from `id` (shadowed builtin)
                pred_mask = (pred_mask >= 0.5).astype(np.float32)
                pred_masks.append(pred_mask)
                img_ids.append(image_id)
                ori_img_ = inverse_normalize(at.tonumpy(data[i]))
                images.append(ori_img_)

    return img_ids, images, pred_masks
if __name__ == '__main__':
    # Entry point: train a UNet, or load a checkpoint and predict on the
    # test set, depending on opt.is_train.
    model = UNet(input_channels=1, nclasses=1)
    if opt.is_train:
        # split all data to train and validation, set split = True
        train_loader, val_loader = get_train_val_loader(opt.root_dir, batch_size=opt.batch_size, val_ratio=0.15,
                                                        shuffle=True, num_workers=4, pin_memory=False)

        # NOTE: `optimizer` is read as a module-level global by train()
        optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
        criterion = nn.BCELoss()
        vis = Visualizer(env=opt.env)

        if opt.is_cuda:
            model.cuda()
            criterion.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)

        run(model, train_loader, val_loader, criterion, vis)
    else:
        if opt.is_cuda:
            model.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)
        test_loader = get_test_loader(batch_size=20, shuffle=True,
                                      num_workers=opt.num_workers,
                                      pin_memory=opt.pin_memory)
        # load the model and run test
        # FIX: map_location lets a GPU-saved checkpoint load on a CPU-only
        # machine; without it torch.load raises when CUDA is unavailable.
        state_dict = torch.load(os.path.join(opt.checkpoint_dir, 'RSNA_UNet_0.895_09210122'),
                                map_location=None if opt.is_cuda else 'cpu')
        # NOTE(review): if the model was wrapped in DataParallel above, the
        # checkpoint keys need the 'module.' prefix -- verify the checkpoint
        # was saved from a matching (wrapped/unwrapped) model.
        model.load_state_dict(state_dict)

        img_ids, images, pred_masks = run_test(model, test_loader)

        save_pred_result(img_ids, images, pred_masks)