# semseg/train.py
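"""Training and validation loops for the semantic segmentation network.

Note: `config` is assumed to expose the attributes read below (`epochs`,
`lr`, `low_lr_epoch`, `val_epochs`, `cuda`, `num_outs`); the concrete
config class lives elsewhere in the project.
"""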
|
|
import os
import time

import numpy as np
import torch

from semseg.loss import get_multi_dice_loss
from config.config import LEARNING_RATE_REDUCTION_FACTOR
from semseg.utils import multi_dice_coeff


def train_model(net, optimizer, train_data, config, device=None, logs_folder=None):

    print('Start training...')
    net = net.to(device)

    # train loop
    for epoch in range(config.epochs):

        epoch_start_time = time.time()
        running_loss = 0.0

        # reduce the learning rate once at the scheduled epoch
        # (dividing inside the param_group loop would shrink it repeatedly
        # when the optimizer has more than one param group)
        if epoch == config.low_lr_epoch:
            config.lr = config.lr / LEARNING_RATE_REDUCTION_FACTOR
            for param_group in optimizer.param_groups:
                param_group['lr'] = config.lr

        # switch to train mode
        net.train()

        for i, data in enumerate(train_data):

            inputs, labels = data['t1']['data'], data['label']['data']
            if config.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # forward pass
            outputs = net(inputs)

            # get multi dice loss
            loss = get_multi_dice_loss(outputs, labels, device=device)

            # empty gradients, perform backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # accumulate the scalar loss for the epoch statistics
            # (.item() avoids keeping a graph-attached CUDA tensor around)
            running_loss += loss.item()

        epoch_end_time = time.time()
        epoch_elapsed_time = epoch_end_time - epoch_start_time

        # print epoch statistics
        print(' [Epoch {:04d}] - Train dice loss: {:.4f} - Time: {:.1f} s'
              .format(epoch + 1, running_loss / (i + 1), epoch_elapsed_time))

        # switch to eval mode before saving a checkpoint
        net.eval()

        # save a checkpoint every 'val_epochs' epochs
        if epoch % config.val_epochs == 0:
            if logs_folder is not None:
                checkpoint_path = os.path.join(logs_folder, 'model_epoch_{:04d}.pth'.format(epoch))
                torch.save(net.state_dict(), checkpoint_path)

    print('Training ended!')
    return net


def val_model(net, val_data, config, device=None):

    print("Start validation...")
    net = net.to(device)

    # validation loop
    multi_dices = list()
    with torch.no_grad():
        net.eval()
        for i, data in enumerate(val_data):
            print("Iter {} of {}".format(i + 1, len(val_data)))

            inputs, labels = data['t1']['data'], data['label']['data']
            if config.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # forward pass
            outputs = net(inputs)
            outputs = torch.argmax(outputs, dim=1)  # B x Z x Y x X
            outputs_np = outputs.cpu().numpy()      # B x Z x Y x X
            labels_np = labels.cpu().numpy()        # B x 1 x Z x Y x X
            labels_np = labels_np[:, 0]             # B x Z x Y x X

            # per-batch multi-class Dice coefficient
            multi_dice = multi_dice_coeff(labels_np, outputs_np, config.num_outs)
            multi_dices.append(multi_dice)

    multi_dices_np = np.array(multi_dices)
    mean_multi_dice = np.mean(multi_dices_np)
    std_multi_dice = np.std(multi_dices_np)
    print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice, std_multi_dice))
    return multi_dices, mean_multi_dice, std_multi_dice
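

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept commented out): `SegNet`,
# `train_loader`, `val_loader` and the `config` object are hypothetical
# stand-ins; only train_model/val_model above and the torch calls are real.
# ---------------------------------------------------------------------------
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# net = SegNet(num_outs=config.num_outs)  # hypothetical model class
# optimizer = torch.optim.Adam(net.parameters(), lr=config.lr)
# net = train_model(net, optimizer, train_loader, config,
#                   device=device, logs_folder='logs')
# val_model(net, val_loader, config, device=device)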