Diff of /trainer.py [000000] .. [72db80]

Switch to unified view

a b/trainer.py
1
import csv
2
import copy
3
import time
4
from tqdm import tqdm
5
import torch
6
import numpy as np
7
import os
8
from datetime import datetime
9
import pathlib
10
import matplotlib.pyplot as plt
11
12
13
def load_checkpoint(bpath):
14
15
    checkpoint_folder = os.path.join(bpath, 'checkpoint')
16
    checkpoint_filename = os.path.join(
17
        checkpoint_folder, 'checkpoint.pth.tar')
18
19
    bestweights_filename = os.path.join(
20
        checkpoint_folder, 'best_weights_checkpoint.pth.tar')
21
22
    file = pathlib.Path(checkpoint_filename)
23
24
    if not file.exists():
25
        return None, None, None, None, None, None
26
27
    file = pathlib.Path(bestweights_filename)
28
29
    best_weight = None
30
    if file.exists():
31
        best_weight = torch.load(bestweights_filename)
32
        best_weight = best_weight['state_dict']
33
34
    checkpoint = torch.load(checkpoint_filename)
35
36
    return checkpoint['epoch'], checkpoint['state_dict'], best_weight, checkpoint['optimizer'], checkpoint['best_loss'], checkpoint['best_pred']
37
38
39
def save_checkpoint(bpath, state, is_best=False):
40
41
    checkpoint_folder = os.path.join(bpath, 'checkpoint')
42
43
    if is_best:
44
        best_pred = state['best_pred']
45
        with open(os.path.join(checkpoint_folder, 'best_pred.txt'), 'w') as f:
46
            f.write(str(best_pred))
47
48
        best_pred = state['best_loss']
49
        with open(os.path.join(checkpoint_folder, 'best_loss.txt'), 'w') as f:
50
            f.write(str(best_pred))
51
52
        torch.save(state, os.path.join(checkpoint_folder,
53
                                       'best_weights_checkpoint.pth.tar'))
54
55
    torch.save(state, os.path.join(checkpoint_folder,
56
                                   'checkpoint.pth.tar'))
57
58
59
def train_model(model, criterion, dataloaders, optimizer, scheduler, metrics, bpath, num_epochs=3):
60
61
    start_epoch, state_dict, bweights, optm, bloss, bpred = load_checkpoint(
62
        bpath)
63
64
    if start_epoch is not None:
65
        print("")
66
        print("NEW CHECKPOINT FOUND! LAST EPOCH ", start_epoch)
67
        print("")
68
        model.load_state_dict(state_dict)
69
        start_epoch += 1
70
71
        best_model_wts = copy.deepcopy(bweights)
72
        best_loss = float(bloss)
73
74
        best_Train_dice = 1e-5
75
        best_Valid_dice = bpred
76
    else:
77
        start_epoch = 1
78
        best_model_wts = copy.deepcopy(model.state_dict())
79
        best_loss = 1e10
80
81
        best_Train_dice = 1e-5
82
        best_Valid_dice = 1e-5
83
84
    since = time.time()
85
86
    # Use gpu if available
87
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
88
    model.to(device)
89
    # Initialize the log file for training and testing loss and metrics
90
    fieldnames = ['epoch', 'Train_loss', 'Valid_loss'] + \
91
        [f'Train_{m}' for m in metrics.keys()] + \
92
        [f'Valid_{m}' for m in metrics.keys()]
93
    with open(os.path.join(bpath, 'log.csv'), 'w', newline='') as csvfile:
94
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
95
        writer.writeheader()
96
97
    for epoch in range(start_epoch, num_epochs+1):
98
        print('Epoch {}/{}'.format(epoch, num_epochs))
99
        print('-' * 10)
100
        # Each epoch has a training and validation phase
101
        # Initialize batch summary
102
        batchsummary = {a: [0] for a in fieldnames}
103
104
        for phase in ['Train', 'Valid']:
105
            if phase == 'Train':
106
                model.train()  # Set model to training mode
107
            else:
108
                model.eval()   # Set model to evaluate mode
109
110
            # Iterate over data.
111
112
            for sample in tqdm(iter(dataloaders[phase])):
113
114
                inputs = sample['image'].to(device)
115
                masks = sample['mask'].to(device)
116
117
                # zero the parameter gradients
118
                optimizer.zero_grad()
119
120
                # track history if only in train
121
                with torch.set_grad_enabled(phase == 'Train'):
122
                    outputs = model(inputs)
123
                    # loss = criterion(outputs['out'], masks)
124
                    loss = criterion(outputs, masks)
125
126
                    # y_pred = outputs['out'].data.cpu().numpy().squeeze(1)
127
                    y_pred = outputs.data.cpu().numpy().squeeze(1)
128
                    y_true = masks.data.cpu().numpy().squeeze(1)
129
130
                    for name, metric in metrics.items():
131
                        if name == 'dice' or name == 'dice_target':
132
                            # Use a classification threshold of 0.5
133
                            val_metric = metric(y_pred > 0.5, y_true > 0)
134
135
                            if val_metric is not None:
136
                                batchsummary[f'{phase}_{name}'].append(
137
                                    val_metric)
138
139
                    # backward + optimize only if in training phase
140
                    if phase == 'Train':
141
                        loss.backward()
142
                        optimizer.step()
143
144
            batchsummary['epoch'] = epoch
145
            epoch_loss = loss
146
            batchsummary[f'{phase}_loss'] = epoch_loss.item()
147
            print('{} Loss: {:.4f}'.format(phase, loss))
148
149
        print('New LR: ', scheduler.get_last_lr())
150
        scheduler.step()
151
152
        for field in fieldnames[3:]:
153
            batchsummary[field] = np.mean(batchsummary[field])
154
155
        print(batchsummary)
156
157
        epoch_valid_dice = np.mean(batchsummary['Valid_dice_tumor'])
158
        is_best = False
159
        with open(os.path.join(bpath, 'log.csv'), 'a', newline='') as csvfile:
160
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
161
            writer.writerow(batchsummary)
162
163
            SAVE_BESTLOSS_WEIGTH = False
164
            if SAVE_BESTLOSS_WEIGTH:
165
                # deep copy the model
166
                if phase == 'Valid' and loss < best_loss:
167
                    print('\nnew best loss: {:.4f} in epoch {}\n'.format(
168
                        loss, epoch))
169
                    best_loss = loss
170
                    best_model_wts = copy.deepcopy(model.state_dict())
171
                    now = datetime.now()
172
                    str_datetime = now.strftime("%Y%m%d_%H_%M_%S")
173
174
                    best_Train_dice = np.mean(batchsummary['Train_dice'])
175
                    best_Valid_dice = np.mean(batchsummary['Valid_dice'])
176
177
                    torch.save(model, os.path.join(
178
                        bpath, 'weights_partial_epch{}_{}.pt'.format(epoch, str_datetime)))
179
            else:
180
                # deep copy the model
181
                if phase == 'Valid' and epoch_valid_dice > best_Valid_dice:
182
                    is_best = True
183
                    print('\nNew valid dice: {:.4f} in epoch {}\n'.format(
184
                        epoch_valid_dice, epoch))
185
                    best_loss = loss.item()
186
                    best_model_wts = copy.deepcopy(model.state_dict())
187
                    now = datetime.now()
188
                    str_datetime = now.strftime("%Y%m%d_%H_%M_%S")
189
190
                    best_Train_dice = np.mean(batchsummary['Train_dice'])
191
                    best_Valid_dice = epoch_valid_dice
192
193
                    torch.save(model, os.path.join(
194
                        bpath, 'weights_partial_diceval_epch{}_{}.pt'.format(epoch, str_datetime)))
195
196
                    # torch.save(model, os.path.join(
197
                    #     bpath, 'model_weights_partial.pt'))
198
199
            save_checkpoint(bpath, {
200
                'epoch': epoch,
201
                'state_dict': model.state_dict(),
202
                'optimizer': optimizer.state_dict(),
203
                'best_pred': best_Valid_dice,
204
                'best_loss': best_loss
205
            }, is_best=is_best)
206
207
    time_elapsed = time.time() - since
208
    print('Training complete in {:.0f}m {:.0f}s'.format(
209
        time_elapsed // 60, time_elapsed % 60))
210
    print('Lowest by valid dice Loss: {:4f}'.format(best_loss))
211
    print('Max valid Dice: {:4f}'.format(best_Valid_dice))
212
213
    # load best model weights
214
    model.load_state_dict(best_model_wts)
215
216
    return best_Train_dice, best_Valid_dice