a b/semseg/train.py
1
import os
2
import time
3
4
import numpy as np
5
import torch
6
7
from semseg.loss import get_multi_dice_loss
8
from config.config import LEARNING_RATE_REDUCTION_FACTOR
9
from semseg.utils import multi_dice_coeff
10
11
12
def train_model(net, optimizer, train_data, config, device=None, logs_folder=None):
    """Train ``net`` on ``train_data`` using the multi-class Dice loss.

    Args:
        net: Model to train; it is moved to ``device`` before training.
        optimizer: Optimizer whose ``param_groups`` carry the learning rate.
        train_data: Iterable of batches. Each batch is indexed as
            ``batch['t1']['data']`` (inputs) and ``batch['label']['data']``
            (labels).
        config: Object providing ``epochs``, ``low_lr_epoch``, ``lr``,
            ``cuda`` and ``val_epochs``. ``config.lr`` is overwritten when
            the learning rate is lowered.
        device: Device handed to ``net.to`` and to ``get_multi_dice_loss``.
        logs_folder: Folder that receives checkpoints; nothing is saved when
            it is ``None``.

    Returns:
        The trained network. It is in eval mode once at least one epoch ran.
    """
    print('Start training...')
    net = net.to(device)

    # train loop
    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        running_loss = 0.0
        num_batches = 0

        # lower learning rate
        if epoch == config.low_lr_epoch:
            # Reduce exactly once and share the value across all groups.
            # Dividing inside the loop would shrink the rate once per
            # parameter group and leave the groups with different rates.
            config.lr = config.lr / LEARNING_RATE_REDUCTION_FACTOR
            for param_group in optimizer.param_groups:
                param_group['lr'] = config.lr

        # switch to train mode
        net.train()

        for data in train_data:
            inputs, labels = data['t1']['data'], data['label']['data']
            if config.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # forward pass
            outputs = net(inputs)

            # get multi dice loss
            loss = get_multi_dice_loss(outputs, labels, device=device)

            # empty gradients, perform backward pass and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # accumulate a plain Python float rather than a tensor
            running_loss += loss.item()
            num_batches += 1

        epoch_elapsed_time = time.time() - epoch_start_time

        # An empty loader has no average loss; report NaN instead of relying
        # on a loop variable that was never bound.
        mean_loss = running_loss / num_batches if num_batches else float('nan')

        # print statistics
        print('  [Epoch {:04d}] - Train dice loss: {:.4f} - Time: {:.1f}'
              .format(epoch + 1, mean_loss, epoch_elapsed_time))

        # switch to eval mode
        net.eval()

        # save a checkpoint every 'val_epochs' epochs (no validation runs here)
        if epoch % config.val_epochs == 0 and logs_folder is not None:
            # NOTE(review): '.pht' looks like a typo for '.pth'; left unchanged
            # because code that loads checkpoints may rely on this name.
            checkpoint_path = os.path.join(logs_folder, 'model_epoch_{:04d}.pht'.format(epoch))
            torch.save(net.state_dict(), checkpoint_path)

    print('Training ended!')
    return net
69
70
71
def val_model(net, val_data, config, device=None):
    """Evaluate ``net`` on ``val_data`` and report multi-class Dice statistics.

    Args:
        net: Model to evaluate; it is moved to ``device`` and put in eval mode.
        val_data: Sized iterable of batches. Each batch is indexed as
            ``batch['t1']['data']`` (inputs) and ``batch['label']['data']``
            (labels).
        config: Object providing ``cuda`` and ``num_outs``.
        device: Device handed to ``net.to``.

    Returns:
        A tuple ``(multi_dices, mean, std)``: the list of per-batch results of
        ``multi_dice_coeff`` and the mean and standard deviation taken over
        all of them.
    """
    print("Start Validation...")
    net = net.to(device)
    net.eval()

    batch_scores = []

    # gradients are not needed for evaluation
    with torch.no_grad():
        for step, batch in enumerate(val_data, start=1):
            print("Iter {} on {}".format(step, len(val_data)))

            inputs = batch['t1']['data']
            labels = batch['label']['data']
            if config.cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            # forward pass, then pick the highest-scoring class along dim 1
            predictions = torch.argmax(net(inputs), dim=1)

            # the original comments give these shapes: predictions are
            # B x Z x Y x X and labels are B x 1 x Z x Y x X, so index 0 of
            # axis 1 is taken from the labels to match
            predictions_np = predictions.data.cpu().numpy()
            targets_np = labels.data.cpu().numpy()[:, 0]

            batch_scores.append(
                multi_dice_coeff(targets_np, predictions_np, config.num_outs)
            )

    scores = np.array(batch_scores)
    mean_multi_dice = np.mean(scores)
    std_multi_dice = np.std(scores)
    print("Multi-Dice: {:.4f} +/- {:.4f}".format(mean_multi_dice, std_multi_dice))
    return batch_scores, mean_multi_dice, std_multi_dice