|
a |
|
b/trainer.py |
|
|
1 |
import csv |
|
|
2 |
import copy |
|
|
3 |
import time |
|
|
4 |
from tqdm import tqdm |
|
|
5 |
import torch |
|
|
6 |
import numpy as np |
|
|
7 |
import os |
|
|
8 |
from datetime import datetime |
|
|
9 |
import pathlib |
|
|
10 |
import matplotlib.pyplot as plt |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def load_checkpoint(bpath):
    """Load the most recent training checkpoint stored under ``bpath``.

    Two files are looked up inside ``<bpath>/checkpoint``:

    * ``checkpoint.pth.tar`` -- the latest checkpoint (required).
    * ``best_weights_checkpoint.pth.tar`` -- the best checkpoint so far
      (optional).

    Both files are deserialised onto the CPU so that a checkpoint written on
    a CUDA machine can still be resumed on a host without a GPU.

    Args:
        bpath: Experiment directory that contains the ``checkpoint`` folder.

    Returns:
        A 6-tuple ``(epoch, state_dict, best_weights, optimizer_state,
        best_loss, best_pred)``. Every element is ``None`` when no latest
        checkpoint exists; ``best_weights`` alone is ``None`` when only the
        best-weights file is missing.
    """
    checkpoint_folder = os.path.join(bpath, 'checkpoint')
    checkpoint_filename = os.path.join(
        checkpoint_folder, 'checkpoint.pth.tar')
    bestweights_filename = os.path.join(
        checkpoint_folder, 'best_weights_checkpoint.pth.tar')

    # Nothing to resume from: signal it with an all-None tuple.
    if not pathlib.Path(checkpoint_filename).exists():
        return None, None, None, None, None, None

    # The best-weights file is only written once some epoch was flagged as
    # best, so it can legitimately be absent.
    best_weight = None
    if pathlib.Path(bestweights_filename).exists():
        best_state = torch.load(bestweights_filename, map_location='cpu')
        best_weight = best_state['state_dict']

    checkpoint = torch.load(checkpoint_filename, map_location='cpu')

    return (
        checkpoint['epoch'],
        checkpoint['state_dict'],
        best_weight,
        checkpoint['optimizer'],
        checkpoint['best_loss'],
        checkpoint['best_pred'],
    )
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def save_checkpoint(bpath, state, is_best=False):
    """Persist a training checkpoint under ``<bpath>/checkpoint``.

    ``state`` is always written to ``checkpoint.pth.tar``. When ``is_best``
    is true it is additionally written to ``best_weights_checkpoint.pth.tar``
    and its ``'best_pred'`` / ``'best_loss'`` entries are dumped as text to
    ``best_pred.txt`` / ``best_loss.txt``.

    Args:
        bpath: Experiment directory; the ``checkpoint`` sub-folder is created
            on demand.
        state: Object handed to ``torch.save``. It must be subscriptable with
            the keys ``'best_pred'`` and ``'best_loss'`` when ``is_best`` is
            true.
        is_best: Whether this checkpoint is the best one seen so far.

    Raises:
        KeyError: If ``is_best`` is true and ``state`` lacks ``'best_pred'``
            or ``'best_loss'``.
    """
    checkpoint_folder = os.path.join(bpath, 'checkpoint')
    # The folder is not created anywhere else in this module; without this
    # the very first save into a fresh experiment directory fails.
    os.makedirs(checkpoint_folder, exist_ok=True)

    if is_best:
        # Human-readable copies of the best scores, one file per key.
        for key in ('best_pred', 'best_loss'):
            with open(os.path.join(checkpoint_folder, f'{key}.txt'), 'w') as f:
                f.write(str(state[key]))

        torch.save(state, os.path.join(checkpoint_folder,
                                       'best_weights_checkpoint.pth.tar'))

    torch.save(state, os.path.join(checkpoint_folder,
                                   'checkpoint.pth.tar'))
|
|
57 |
|
|
|
58 |
|
|
|
59 |
def train_model(model, criterion, dataloaders, optimizer, scheduler, metrics, bpath, num_epochs=3):
    """Train ``model`` with per-epoch checkpointing and return the best Dice scores.

    If ``load_checkpoint(bpath)`` finds a checkpoint, the model weights, the
    optimizer state and the best scores are restored and training continues
    with the epoch after the stored one. Each epoch runs a ``'Train'`` and a
    ``'Valid'`` phase, appends one row to ``<bpath>/log.csv`` and saves a
    checkpoint through ``save_checkpoint``. Whenever the validation Dice
    improves, the whole model is also pickled to
    ``<bpath>/weights_partial_diceval_epch<epoch>_<timestamp>.pt``. The model
    is left holding the best weights on return.

    Args:
        model: Network to train; called as ``model(inputs)`` and expected to
            return a tensor.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        dataloaders: Mapping with the keys ``'Train'`` and ``'Valid'``. Each
            value yields samples indexable with ``'image'`` and ``'mask'``.
            Axis 1 of both the model output and the mask is squeezed, so it
            must have size 1.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        metrics: Mapping from metric name to callable. Only the entries named
            ``'dice'`` and ``'dice_target'`` are evaluated, as
            ``metric(y_pred > 0.5, y_true > 0)``; a ``None`` result is
            skipped. Other entries only contribute (zero-valued) log columns.
        bpath: Output directory for the log, checkpoints and model dumps.
        num_epochs: Index of the last epoch to run (epochs are 1-based).

    Returns:
        Tuple ``(best_Train_dice, best_Valid_dice)``.
    """
    # Debug toggle kept from the original code: when True the best model is
    # selected by validation loss instead of validation Dice.
    SAVE_BESTLOSS_WEIGTH = False

    start_epoch, state_dict, bweights, optm, bloss, bpred = load_checkpoint(
        bpath)
    resuming = start_epoch is not None

    if resuming:
        print("")
        print("NEW CHECKPOINT FOUND! LAST EPOCH ", start_epoch)
        print("")
        model.load_state_dict(state_dict)
        start_epoch += 1

        # A best-weights file only exists once some epoch was flagged as
        # best; fall back to the resumed weights so the final
        # load_state_dict() below never receives None.
        if bweights is None:
            bweights = model.state_dict()
        best_model_wts = copy.deepcopy(bweights)
        best_loss = float(bloss)

        best_Train_dice = 1e-5
        best_Valid_dice = bpred
    else:
        start_epoch = 1
        best_model_wts = copy.deepcopy(model.state_dict())
        best_loss = 1e10

        best_Train_dice = 1e-5
        best_Valid_dice = 1e-5

    since = time.time()

    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Restore the optimizer state saved with the checkpoint. Done after the
    # model was moved so the state ends up on the parameters' device.
    if resuming and optm is not None:
        optimizer.load_state_dict(optm)

    # Initialize the log file for training and testing loss and metrics
    fieldnames = ['epoch', 'Train_loss', 'Valid_loss'] + \
        [f'Train_{m}' for m in metrics.keys()] + \
        [f'Valid_{m}' for m in metrics.keys()]
    log_path = os.path.join(bpath, 'log.csv')
    # When resuming, keep the rows already logged instead of truncating them.
    keep_existing_log = (resuming and os.path.isfile(log_path)
                         and os.path.getsize(log_path) > 0)
    if not keep_existing_log:
        with open(log_path, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()

    for epoch in range(start_epoch, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # Per-batch metric values of this epoch. The lists start empty so no
        # placeholder value leaks into the epoch mean.
        metric_values = {field: [] for field in fieldnames[3:]}
        batchsummary = {'epoch': epoch}

        # Each epoch has a training and validation phase
        for phase in ['Train', 'Valid']:
            if phase == 'Train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Iterate over data.
            for sample in tqdm(iter(dataloaders[phase])):
                inputs = sample['image'].to(device)
                masks = sample['mask'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # track history if only in train
                with torch.set_grad_enabled(phase == 'Train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, masks)

                    y_pred = outputs.data.cpu().numpy().squeeze(1)
                    y_true = masks.data.cpu().numpy().squeeze(1)

                    for name, metric in metrics.items():
                        # Only these metrics are evaluated; skipping the rest
                        # avoids recording a stale value of another metric.
                        if name not in ('dice', 'dice_target'):
                            continue
                        # Use a classification threshold of 0.5
                        val_metric = metric(y_pred > 0.5, y_true > 0)
                        if val_metric is not None:
                            metric_values[f'{phase}_{name}'].append(
                                val_metric)

                    # backward + optimize only if in training phase
                    if phase == 'Train':
                        loss.backward()
                        optimizer.step()

            # NOTE: the logged loss is the loss of the LAST batch of the
            # phase, not an average over the epoch.
            batchsummary[f'{phase}_loss'] = loss.item()
            print('{} Loss: {:.4f}'.format(phase, loss))

        print('New LR: ', scheduler.get_last_lr())
        scheduler.step()

        # Collapse the per-batch values into plain-float epoch means; a
        # metric without any recorded value is logged as 0.0.
        for field in fieldnames[3:]:
            values = metric_values[field]
            batchsummary[field] = float(np.mean(values)) if values else 0.0

        print(batchsummary)

        # NOTE(review): model selection was keyed on 'Valid_dice_tumor', but
        # no 'dice_tumor' metric is evaluated in the loop above. Fall back to
        # 'Valid_dice' when that column has no values -- confirm which Dice
        # variant should drive model selection.
        if metric_values.get('Valid_dice_tumor'):
            selection_field = 'Valid_dice_tumor'
        else:
            selection_field = 'Valid_dice'
        epoch_valid_dice = batchsummary[selection_field]
        valid_loss = batchsummary['Valid_loss']

        is_best = False
        with open(log_path, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow(batchsummary)

        if SAVE_BESTLOSS_WEIGTH:
            if valid_loss < best_loss:
                print('\nnew best loss: {:.4f} in epoch {}\n'.format(
                    valid_loss, epoch))
                best_loss = valid_loss
                # deep copy the model
                best_model_wts = copy.deepcopy(model.state_dict())
                str_datetime = datetime.now().strftime("%Y%m%d_%H_%M_%S")

                best_Train_dice = batchsummary.get(
                    'Train_dice', best_Train_dice)
                best_Valid_dice = batchsummary.get(
                    'Valid_dice', best_Valid_dice)

                torch.save(model, os.path.join(
                    bpath, 'weights_partial_epch{}_{}.pt'.format(epoch, str_datetime)))
        else:
            if epoch_valid_dice > best_Valid_dice:
                is_best = True
                print('\nNew valid dice: {:.4f} in epoch {}\n'.format(
                    epoch_valid_dice, epoch))
                best_loss = valid_loss
                # deep copy the model
                best_model_wts = copy.deepcopy(model.state_dict())
                str_datetime = datetime.now().strftime("%Y%m%d_%H_%M_%S")

                best_Train_dice = batchsummary.get(
                    'Train_dice', best_Train_dice)
                best_Valid_dice = epoch_valid_dice

                torch.save(model, os.path.join(
                    bpath, 'weights_partial_diceval_epch{}_{}.pt'.format(epoch, str_datetime)))

        # The scores are plain Python floats here, so the checkpoint holds
        # no NumPy scalar objects.
        save_checkpoint(bpath, {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_Valid_dice,
            'best_loss': best_loss
        }, is_best=is_best)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest by valid dice Loss: {:4f}'.format(best_loss))
    print('Max valid Dice: {:4f}'.format(best_Valid_dice))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return best_Train_dice, best_Valid_dice