# train_infer.py
""" Training augmented model """
2
import os
3
import torch
4
import torch.nn as nn
5
import numpy as np
6
from tensorboardX import SummaryWriter
7
from ptflops import get_model_complexity_info
8
import utils
9
import data_generator_3D as data_generator_3D
10
import time
11
import SimpleITK as sitk
12
import sys
13
from config import TrainConfig
14
from model import LCOVNet
15
from apex import amp
16
17
18
# ---------------------------------------------------------------------------
# Module-level run state shared by main(), train() and validate().
# NOTE: all of this executes at import time, not only when run as a script.
# ---------------------------------------------------------------------------
config = TrainConfig()

# CUDA device without an explicit index.
device = torch.device("cuda")

# tensorboard: events go under <config.path>/tb; the configuration is stored
# as a text entry at step 0.
_tb_dir = os.path.join(config.path, "tb")
writer = SummaryWriter(log_dir=_tb_dir)
writer.add_text('config', config.as_markdown(), 0)

# Logger is built from the path <config.path>/<config.name>.log, and the
# configuration parameters are echoed through it.
_log_path = os.path.join(config.path, "{}.log".format(config.name))
logger = utils.get_logger(_log_path)
config.print_params(logger.info)
28
29
def main():
    """Build the model and run the full train/validate schedule.

    All settings are read from the module-level ``config``. The function
    selects the GPU, seeds NumPy and PyTorch, creates the LCOVNet model,
    logs its complexity and size, sets up SGD with Apex mixed precision and
    a cosine learning-rate schedule, and then runs ``config.epochs`` epochs
    of ``train`` followed by ``validate``.

    After every epoch the model is handed to ``utils.save_checkpoint``
    together with a flag telling whether this epoch produced the best
    validation Dice so far. The best Dice is logged and passed to
    ``utils.save_results`` at the end.

    Returns:
        None
    """
    logger.info("Logger is set - training start")

    # set default gpu device id
    torch.cuda.set_device(config.gpus[0])

    # set seed
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)

    torch.backends.cudnn.benchmark = True

    criterion = utils.log_loss().to(device)
    model = LCOVNet(config.input_channels, config.n_classes).to(
        device=torch.device(type='cuda', index=config.gpus[0]))

    with torch.cuda.device(config.gpus[0]):
        # NOTE(review): the probe shape hard-codes 1 as its first entry while
        # the model is built with config.input_channels - confirm they agree.
        macs, params = get_model_complexity_info(model, (1, 240, 160, 48), as_strings=True,
                                                 print_per_layer_stat=True, verbose=True)
        logger.info("{:<30}  {:<8}".format('Computational complexity: ', macs))
        logger.info("{:<30}  {:<8}".format('Number of parameters: ', params))

    # model size
    mb_params = utils.param_size(model)
    logger.info("Model size = {:.3f} MB".format(mb_params))

    # weights optimizer
    optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum,
                                weight_decay=config.weight_decay)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    train_loader = data_generator_3D.Covid19TrainSet()
    valid_loader = data_generator_3D.Covid19EvalSet()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config.epochs)

    best_dice = 0.
    summ_writer = SummaryWriter(config.training_summary_dir)
    try:
        # training loop
        for epoch in range(config.epochs):
            # training
            train(train_loader, model, optimizer, criterion, epoch, summ_writer)
            lr_scheduler.step()

            # validation; best_dice is still the value from before this epoch,
            # which validate() uses to decide whether to write its own best.pt
            mean_dice = validate(valid_loader, model, criterion, epoch, summ_writer, best_dice)

            # save
            is_best = bool(best_dice < mean_dice)
            if is_best:
                best_dice = mean_dice
            utils.save_checkpoint(model, config.path, is_best)
            print("")

        logger.info("Final best Dice = {:.4%}".format(best_dice))
        utils.save_results(best_dice, config.path)
    finally:
        # Close the summary writer even if an epoch raises.
        summ_writer.close()
89
90
def train(train_loader, model, optimizer, criterion, epoch, summ_writer):
    """Run one training epoch and record its average loss and Dice score.

    Args:
        train_loader: sized iterable yielding ``(name, X, y)``; ``X`` and
            ``y`` are converted with ``torch.from_numpy``.
        model: network to optimise; put into training mode here.
        optimizer: optimizer used with ``amp.scale_loss`` for the backward
            pass.
        criterion: callable ``criterion(logits, y)`` returning the loss.
        epoch: zero-based epoch index.
        summ_writer: summary writer receiving the per-epoch ``loss`` and
            ``avg_dice`` scalars (logged at ``epoch + 1``).

    Returns:
        None. Besides the optimisation itself, the function logs per-step
        scalars to the module-level ``writer`` and, every 50th epoch, saves
        model/optimizer/amp state to
        ``<config.training_checkpoint_prefix>_<epoch+1>.pt``.
    """
    losses = utils.AverageMeter()
    n_steps = len(train_loader)
    cur_step = epoch * n_steps
    cur_lr = optimizer.param_groups[0]['lr']
    logger.info("Epoch {} LR {}".format(epoch, cur_lr))
    writer.add_scalar('train/lr', cur_lr, cur_step)
    model.train()

    all_dice = []
    for step, (name, X, y) in enumerate(train_loader):
        X = torch.from_numpy(X).to(device, non_blocking=True)
        y = torch.from_numpy(y).to(device, non_blocking=True)
        N = X.size(0)

        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)

        # Backward pass goes through Apex's loss scaling.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        # Read the loss back once and reuse it for the meter and the writer.
        loss_value = loss.item()
        losses.update(loss_value, N)

        if step % config.print_freq == 0 or step == n_steps - 1:
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {:.3f} ".format(
                    epoch+1, config.epochs, step, n_steps, losses.avg,
                    ))

        writer.add_scalar('train/loss', loss_value, cur_step)

        # Binarise at 0.5 on a detached copy rather than overwriting the
        # graph-attached model output in place; dtype is kept unchanged.
        binary = (logits.detach() >= 0.5).to(logits.dtype)
        predict = binary.cpu().numpy()
        target = y.detach().cpu().numpy()
        all_dice.append(utils.evaluate(predict, target))
        cur_step += 1

    train_avg_loss = losses.avg
    # An empty loader yields 0 rather than a division by zero.
    train_avg_dice = sum(all_dice) / len(all_dice) if all_dice else 0

    summ_writer.add_scalars('loss', {'train': train_avg_loss}, epoch + 1)
    summ_writer.add_scalars('avg_dice', {'train': train_avg_dice}, epoch + 1)

    # Periodic full checkpoint (model, optimizer and amp state) every 50 epochs.
    if (epoch+1) % 50 == 0:
        chpt_prefx = config.training_checkpoint_prefix
        save_dict = {'epoch': epoch + 1,
                     'model_state_dict': model.state_dict(),
                     'optimizer_state_dict': optimizer.state_dict(),
                     'amp': amp.state_dict()}
        save_name = "{0:}_{1:}.pt".format(chpt_prefx, epoch + 1)
        torch.save(save_dict, save_name)
    print("train_avg_loss", train_avg_loss)
    print("train_avg_dice", train_avg_dice)
154
155
def validate(valid_loader, model, criterion, epoch, summ_writer, best_dice):
    """Evaluate the model on the validation set and return the mean Dice.

    Each volume is predicted in windows of 48 slices along axis 4. When the
    depth is not a multiple of the window size, the last window is shifted
    back so that it ends at the final slice; slices covered twice keep the
    prediction of the later window. Predictions are binarised at 0.5 and
    scored against the label with ``utils.evaluate``.

    Args:
        valid_loader: iterable yielding ``(name, image, label)``; ``image``
            is converted with ``torch.from_numpy`` and indexed with five
            axes, ``label`` is used as a NumPy array.
        model: network to evaluate; put into eval mode here.
        criterion: unused; kept for interface compatibility.
        epoch: zero-based epoch index.
        summ_writer: summary writer receiving the per-epoch mean Dice.
        best_dice: best mean Dice of previous epochs; if this epoch beats
            it, model and amp state are saved to
            ``<config.validing_checkpoint_prefix>/best.pt``.

    Returns:
        Mean Dice over all validation samples.

    Raises:
        ValueError: if ``valid_loader`` yields no samples.
    """
    losses = utils.AverageMeter()

    model.eval()
    all_dice = []
    total_time = 0
    size_z = 48
    with torch.no_grad():
        for name, image, label in valid_loader:
            image = torch.from_numpy(image)
            predict = np.zeros(shape=label.shape, dtype=label.dtype)
            z = image.shape[4]
            # Number of windows needed to cover the depth (ceiling division).
            m = z // size_z if z % size_z == 0 else z // size_z + 1
            start_time = time.time()
            for k in range(m):
                max_z = min((k + 1) * size_z, z)
                # Clamp at 0: for volumes shallower than size_z a negative
                # start would wrap around and cover only the tail slices.
                min_z = max(max_z - size_z, 0)
                image_k = image[:, :, :, :, min_z:max_z].float().to(device, non_blocking=True)
                predict_k = model(image_k)
                predict_k[predict_k >= 0.5] = 1
                predict_k[predict_k < 0.5] = 0
                predict[:, :, :, :, min_z:max_z] = predict_k.cpu().detach().numpy()
            total_time = total_time + time.time() - start_time
            all_dice.append(utils.evaluate(predict, label))

    dice_len = len(all_dice)
    if dice_len == 0:
        # Fail with a clear message instead of NaN statistics followed by a
        # ZeroDivisionError in the timing average.
        raise ValueError("validate() received an empty validation loader")

    dice_np = np.asarray(all_dice, dtype=np.float64)
    for i, dice_i in enumerate(all_dice):
        logger.info("{}  dice: {:.4%} ".format(i, dice_i))
    epoch_mean = dice_np.mean()
    logger.info("mean: {}".format(epoch_mean))
    logger.info("std : {}".format(dice_np.std()))

    if best_dice < epoch_mean:
        save_dict = {'epoch': epoch + 1,
                     'model_state_dict': model.state_dict(),
                     'amp': amp.state_dict()}
        save_name = "{}/best.pt".format(config.validing_checkpoint_prefix)
        if os.path.isfile(save_name):
            os.remove(save_name)
        torch.save(save_dict, save_name)

    # The 'vadil' spelling is kept so existing summary tags stay continuous.
    summ_writer.add_scalars('vadil_avg_dice', {'vadil': epoch_mean}, epoch + 1)

    avg_time = total_time / dice_len
    logger.info("average testing time : {}".format(avg_time))

    mean_dice = np.mean(all_dice, axis = 0)
    writer.add_scalar('val/dice', mean_dice, epoch)
    # NOTE(review): `losses` is never updated in this function, so this logs
    # whatever a fresh AverageMeter reports - confirm whether a validation
    # loss should be computed here.
    writer.add_scalar('val/loss', losses.avg, epoch)
    logger.info("Valid: [{:2d}/{}] average dice: {:.4%} ".format(epoch+1, config.epochs, mean_dice))

    return mean_dice
217
218
219
220
def save_nd_array_as_image(data, image_name, reference_name = None):
    """
    save a 3D numpy array as a medical image volume
    inputs:
        data: a numpy array with shape [D, H, W] or [C, H, W]
        image_name: the output file name; must end with .nii.gz, .nii or .mha
        reference_name: optional reference image file name, forwarded to
                        save_array_as_nifty_volume
    outputs: None
    raises:
        ValueError: if data is not 2D/3D, if the file extension is not
                    supported, or if a volume format is requested for
                    non-3D data
    """
    data_dim = len(data.shape)
    # Explicit exceptions instead of assert, so the checks survive `python -O`.
    if data_dim not in (2, 3):
        raise ValueError(
            "data must be a 2D or 3D array, got {} dimension(s)".format(data_dim))
    if not image_name.endswith((".nii.gz", ".nii", ".mha")):
        # Only volume formats are implemented; refuse other names rather than
        # returning without having written anything.
        raise ValueError("unsupported image format: {}".format(image_name))
    if data_dim != 3:
        raise ValueError(
            "volume formats require a 3D array, got {} dimension(s)".format(data_dim))
    save_array_as_nifty_volume(data, image_name, reference_name)
234
235
def save_array_as_nifty_volume(data, image_name, reference_name = None):
    """
    write a numpy array to disk as an image file using SimpleITK
    inputs:
        data: a numpy array with shape [Depth, Height, Width]
        image_name: the output file name
        reference_name: optional file name of a reference image; when given,
                        the image is read and its information is copied onto
                        the output with CopyInformation before writing
    outputs: None
    """
    volume = sitk.GetImageFromArray(data)
    if reference_name is not None:
        reference = sitk.ReadImage(reference_name)
        volume.CopyInformation(reference)
    sitk.WriteImage(volume, image_name)
249
250
251
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()