"""
models.py
=======================
Houses all of the PyTorch models and the corresponding scikit-learn-like model trainer.
"""
from pathflowai.unet import UNet
# from pathflowai.unet2 import NestedUNet
# from pathflowai.unet4 import UNetSmall as UNet2
from pathflowai.fast_scnn import get_fast_scnn
import torch
import torchvision
from torchvision import models
from torchvision.models import segmentation as segmodels
from torch import nn
from torch.nn import functional as F
import pandas as pd, numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from pathflowai.schedulers import *
import pysnooper
from torch.autograd import Variable
import copy
from sklearn.metrics import roc_curve, confusion_matrix, classification_report, r2_score
sns.set()
from pathflowai.losses import GeneralizedDiceLoss, FocalLoss
from apex import amp
import time, os
|
|
class MLP(nn.Module):
    """Multi-layer perceptron model.

    Parameters
    ----------
    n_input:int
        Number of input dimensions.
    hidden_topology:list
        List of hidden layer sizes.
    dropout_p:float
        Amount of dropout.
    n_outputs:int
        Number of outputs.
    binary:bool
        Binary output with sigmoid transform.
    softmax:bool
        Whether to apply softmax on output.

    """
    def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True, softmax=False):
        super(MLP, self).__init__()
        self.topology = [n_input] + hidden_topology + [n_outputs]
        layers = [nn.Linear(self.topology[i], self.topology[i+1]) for i in range(len(self.topology)-2)]
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.layers = [nn.Sequential(layer, nn.LeakyReLU(), nn.Dropout(p=dropout_p)) for layer in layers]
        self.output_layer = nn.Linear(self.topology[-2], self.topology[-1])
        torch.nn.init.xavier_uniform_(self.output_layer.weight)
        if binary:
            output_transform = nn.Sigmoid()
        elif softmax:
            output_transform = nn.Softmax(dim=1)
        else:
            output_transform = nn.Dropout(p=0.)
        self.layers.append(nn.Sequential(self.output_layer, output_transform))
        self.mlp = nn.Sequential(*self.layers)

    def forward(self, x):
        return self.mlp(x)
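
# A minimal usage sketch (illustrative only; the dimensions below are made up):
#   mlp = MLP(n_input=512, hidden_topology=[100], dropout_p=0.1, n_outputs=2, binary=False, softmax=True)
#   probs = mlp(torch.randn(8, 512))  # -> (8, 2) row-normalized class probabilities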
|
|
class FixedSegmentationModule(nn.Module):
    """Special model modification for segmentation tasks: extracts the 'out' tensor from the dict that some models return from their forward pass.

    Parameters
    ----------
    segnet:nn.Module
        Segmentation network.
    """
    def __init__(self, segnet):
        super(FixedSegmentationModule, self).__init__()
        self.segnet = segnet

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x:Tensor
            Input.

        Returns
        -------
        Tensor
            Output from model.

        """
        return self.segnet(x)['out']
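
# torchvision's segmentation models (e.g., DeepLabV3, FCN) return a dict like
# {'out': ..., 'aux': ...}; this wrapper exposes only the 'out' tensor. A sketch:
#   seg = FixedSegmentationModule(segmodels.fcn_resnet50(pretrained=False, num_classes=4))
#   mask_logits = seg(torch.randn(1, 3, 224, 224))  # -> (1, 4, 224, 224)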
|
|
def generate_model(pretrain, architecture, num_classes, add_sigmoid=True, n_hidden=100, segmentation=False):
    """Generate a nn.Module for use.

    Parameters
    ----------
    pretrain:bool
        Pretrain using ImageNet?
    architecture:str
        See model_training for the list of all architectures you can train with.
    num_classes:int
        Number of classes to predict.
    add_sigmoid:bool
        Add sigmoid non-linearity at end.
    n_hidden:int
        Number of hidden fully connected layers (currently unused).
    segmentation:bool
        Whether this is a segmentation task.

    Returns
    -------
    nn.Module
        PyTorch model.

    """
    # to add: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/models/model_zoo.py
    model = None

    if architecture == 'unet':
        model = UNet(n_channels=3, n_classes=num_classes)
    elif architecture == 'unet2':
        print('Deprecated for now, defaulting to UNET.')
        model = UNet(n_channels=3, n_classes=num_classes)  # UNet2(3, num_classes)
    elif architecture == 'fast_scnn':
        model = get_fast_scnn(num_classes)
    elif architecture == 'nested_unet':
        print('Nested UNET is deprecated for now, defaulting to UNET.')
        model = UNet(n_channels=3, n_classes=num_classes)  # NestedUNet(3, num_classes)
    elif architecture.startswith('efficientnet'):
        from efficientnet_pytorch import EfficientNet
        if pretrain:
            model = EfficientNet.from_pretrained(architecture, override_params=dict(num_classes=num_classes))
        else:
            model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes))
        print(model)
    elif architecture.startswith('sqnxt'):
        from pytorchcv.model_provider import get_model as ptcv_get_model
        model = ptcv_get_model(architecture, pretrained=pretrain)
        num_ftrs = int(128 * int(architecture.split('_')[-1][1]))
        model.output = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
    else:
        # for models pretrained on ImageNet
        model_names = [m for m in dir(models) if not m.startswith('__')]
        segmentation_model_names = [m for m in dir(segmodels) if not m.startswith('__')]
        if architecture in model_names:
            model = getattr(models, architecture)(pretrained=pretrain)
        if segmentation:
            if architecture in segmentation_model_names:
                model = getattr(segmodels, architecture)(pretrained=pretrain)
            else:
                model = UNet(n_channels=3, n_classes=num_classes)
        if architecture.startswith('deeplab'):
            # replace the final classifier conv and unwrap the output dict
            model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
            model = FixedSegmentationModule(model)
        elif architecture.startswith('fcn'):
            model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
            model = FixedSegmentationModule(model)
        elif architecture.startswith('resnet') or architecture.startswith('inception'):
            # swap the final fully connected layer for an MLP head
            num_ftrs = model.fc.in_features
            model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
        elif architecture.startswith('alexnet') or architecture.startswith('vgg'):
            num_ftrs = model.classifier[6].in_features
            model.classifier[6] = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
        elif architecture.startswith('densenet'):
            # densenet exposes a single Linear classifier, not a Sequential
            num_ftrs = model.classifier.in_features
            model.classifier = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
    return model
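
# Usage sketches (hypothetical arguments for illustration):
#   classifier = generate_model(pretrain=True, architecture='resnet50', num_classes=2, add_sigmoid=False)
#   segmenter = generate_model(pretrain=False, architecture='unet', num_classes=4, segmentation=True)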
|
|
# @pysnooper.snoop("dice_loss.log")
def dice_loss(logits, true, eps=1e-7):
    """https://github.com/kevinzakka/pytorch-goodies
    Computes the Sørensen–Dice loss.

    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the Dice coefficient, so we
    return one minus the Dice coefficient.

    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.

    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice_loss)
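
# A quick sanity check (a sketch; shapes are arbitrary):
#   logits = torch.randn(2, 3, 8, 8)          # [B, C, H, W] raw model output
#   true = torch.randint(0, 3, (2, 1, 8, 8))  # [B, 1, H, W] integer class labels
#   loss = dice_loss(logits, true)            # scalar in [0, 1]; 0 means perfect overlap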
|
|
class ModelTrainer:
    """Trainer for the neural network model that wraps it into a scikit-learn-like interface.

    Parameters
    ----------
    model:nn.Module
        Deep learning pytorch model.
    n_epoch:int
        Number of training epochs.
    validation_dataloader:DataLoader
        Dataloader of validation dataset.
    optimizer_opts:dict
        Options for optimizer.
    scheduler_opts:dict
        Options for learning rate scheduler.
    loss_fn:str
        String to call a particular loss function for model.
    reduction:str
        Mean or sum reduction of loss.
    num_train_batches:int
        Number of training batches per epoch.
    seg_out_class:int
        If nonnegative, emit the probability map for only this segmentation class at test time.
    apex_opt_level:str
        NVIDIA apex mixed-precision optimization level.
    checkpointing:bool
        Save model checkpoints to ./checkpoints when a new best validation loss is reached.
    """
    def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam', lr=1e-3, weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1, apex_opt_level="O2", checkpointing=False):

        self.model = model
        optimizers = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
        loss_functions = {'bce': nn.BCEWithLogitsLoss(reduction=reduction), 'ce': nn.CrossEntropyLoss(reduction=reduction), 'mse': nn.MSELoss(reduction=reduction), 'nll': nn.NLLLoss(reduction=reduction), 'dice': dice_loss, 'focal': FocalLoss(num_class=2), 'gdl': GeneralizedDiceLoss(add_softmax=True)}
        loss_functions['dice+ce'] = (lambda y_pred, y_true: dice_loss(y_pred, y_true) + loss_functions['ce'](y_pred, y_true))
        if 'name' not in optimizer_opts:
            optimizer_opts['name'] = 'adam'
        self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(), **optimizer_opts)
        if torch.cuda.is_available():
            # mixed-precision training via NVIDIA apex
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=apex_opt_level)
            self.cuda = True
        else:
            self.cuda = False
        self.scheduler = Scheduler(optimizer=self.optimizer, opts=scheduler_opts)
        self.n_epoch = n_epoch
        self.validation_dataloader = validation_dataloader
        self.loss_fn = loss_functions[loss_fn]
        self.loss_fn_name = loss_fn
        self.bce = (self.loss_fn_name == 'bce' or self.validation_dataloader.dataset.mt_bce)
        self.sigmoid = nn.Sigmoid()
        self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.num_train_batches = num_train_batches
        self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.seg_out_class = seg_out_class
        self.checkpointing = checkpointing
        self.checkpoint_dir = './checkpoints'
        if self.checkpointing:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
|
|
    def save_model(self, model=None, epoch=0):
        """Save model state to the checkpoint directory; falls back to self.model when no model is passed in."""
        torch.save((self.model if model is None else model).state_dict(), os.path.join(self.checkpoint_dir, f'checkpoint.{epoch}.pth'))
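
    # A saved checkpoint can be restored later with, e.g. (a sketch; the epoch
    # number in the filename is whatever was passed to save_model):
    #   trainer.model.load_state_dict(torch.load('./checkpoints/checkpoint.0.pth'))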
|
|
    def calc_loss(self, y_pred, y_true):
        """Calculates the loss supplied in the init statement, as modified by any reweighting.

        Parameters
        ----------
        y_pred:tensor
            Predictions.
        y_true:tensor
            True values.

        Returns
        -------
        loss

        """
        return self.loss_fn(y_pred, y_true)

    def calc_val_loss(self, y_pred, y_true):
        """Calculates the loss supplied in the init statement on the validation set.

        Parameters
        ----------
        y_pred:tensor
            Predictions.
        y_true:tensor
            True values.

        Returns
        -------
        val_loss

        """
        return self.val_loss_fn(y_pred, y_true)

    def reset_loss_fn(self):
        """Resets loss to the originally specified loss."""
        self.loss_fn = self.original_loss_fn
|
|
    def add_class_balance_loss(self, dataset, custom_weights=''):
        """Updates the loss function to handle class imbalance by weighting classes inversely to their frequency.

        Parameters
        ----------
        dataset:DynamicImageDataset
            Dataset to balance by.
        custom_weights:str
            Optional comma-delimited custom class weights; overrides the dataset-derived weights.

        """
        self.class_weights = dataset.get_class_weights() if not custom_weights else np.array(list(map(float, custom_weights.split(','))))
        if custom_weights:
            self.class_weights = self.class_weights / sum(self.class_weights)
        print('Weights:', self.class_weights)
        self.original_loss_fn = copy.deepcopy(self.loss_fn)
        weight = torch.tensor(self.class_weights, dtype=torch.float)
        if torch.cuda.is_available():
            weight = weight.cuda()
        if self.loss_fn_name == 'ce':
            self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        elif self.loss_fn_name == 'nll':
            self.loss_fn = nn.NLLLoss(weight=weight)
        else:  # modify below for multi-target
            self.loss_fn = lambda y_pred, y_true: sum([self.class_weights[i] * self.original_loss_fn(y_pred[y_true == i], y_true[y_true == i]) if sum(y_true == i) else 0. for i in range(2)])
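
    # For example, with 'ce' loss and dataset-derived weights favoring the minority
    # class, the update amounts to (hypothetical weights for illustration):
    #   self.loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9]))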
|
|
    def calc_best_confusion(self, y_pred, y_true):
        """Calculate the confusion matrix on the validation set for classification/segmentation tasks, optimizing the positive-class decision threshold.

        Parameters
        ----------
        y_pred:array
            Predictions.
        y_true:array
            Ground truth.

        Returns
        -------
        float
            Optimized threshold to use on the test set.
        dataframe
            Confusion matrix.

        """
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        # pick the threshold whose (fpr, tpr) point lies closest to the ideal ROC corner (0, 1)
        threshold = thresholds[np.argmin(np.sum((np.array([0, 1]) - np.vstack((fpr, tpr)).T)**2, axis=1)**.5)]
        y_pred = (y_pred > threshold).astype(int)
        return threshold, pd.DataFrame(confusion_matrix(y_true, y_pred), index=['F', 'T'], columns=['-', '+']).iloc[::-1, ::-1].T
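
    # For example (hypothetical values): with fpr=[0., .2, 1.] and tpr=[0., .9, 1.],
    # the point (.2, .9) lies closest to the corner (0, 1), so its threshold is
    # returned and used to binarize the predictions.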
|
|
    def loss_backward(self, loss):
        """Backprop using mixed precision for added speed boost.

        Parameters
        ----------
        loss:loss
            Torch loss calculated.

        """
        if self.cuda:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
|
|
    # @pysnooper.snoop('train_loop.log')
    def train_loop(self, epoch, train_dataloader):
        """One training epoch: calculate predictions and loss, then backpropagate.

        Parameters
        ----------
        epoch:int
            Current epoch.
        train_dataloader:DataLoader
            Training data.

        Returns
        -------
        float
            Training loss for epoch.

        """
        self.model.train(True)
        running_loss = 0.
        n_batch = len(train_dataloader.dataset) // train_dataloader.batch_size if self.num_train_batches is None else self.num_train_batches
        for i, batch in enumerate(train_dataloader):
            starttime = time.time()
            if i == n_batch:
                break
            X = Variable(batch[0], requires_grad=True)
            y_true = Variable(batch[1])
            if not train_dataloader.dataset.segmentation and self.loss_fn_name == 'ce' and y_true.shape[1] > 1:
                # convert one-hot targets to class indices for cross-entropy
                y_true = y_true.argmax(1).long()
            if train_dataloader.dataset.segmentation and self.loss_fn_name != 'dice':
                y_true = y_true.squeeze(1)
            if torch.cuda.is_available():
                X = X.cuda()
                y_true = y_true.cuda()
            y_pred = self.model(X)
            loss = self.calc_loss(y_pred, y_true)
            train_loss = loss.item()
            running_loss += train_loss
            self.optimizer.zero_grad()
            self.loss_backward(loss)
            self.optimizer.step()
            endtime = time.time()
            print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch, i, n_batch, round(endtime - starttime, 3), train_loss))
        self.scheduler.step()
        running_loss /= n_batch
        return running_loss
|
|
    def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True):
        """Calculate loss over the validation set.

        Parameters
        ----------
        epoch:int
            Current epoch.
        val_dataloader:DataLoader
            Validation iterator.
        print_val_confusion:bool
            Calculate and print the confusion matrix.
        save_predictions:bool
            Collect predictions for validation reporting.

        Returns
        -------
        float
            Validation loss for epoch.
        """
        self.model.train(False)
        n_batch = len(val_dataloader.dataset) // val_dataloader.batch_size
        running_loss = 0.
        Y = {'pred': [], 'true': []}
        with torch.no_grad():
            for i, batch in enumerate(val_dataloader):
                X = Variable(batch[0], requires_grad=False)
                y_true = Variable(batch[1])
                if not val_dataloader.dataset.segmentation and self.loss_fn_name == 'ce' and y_true.shape[1] > 1:
                    y_true = y_true.argmax(1).long()
                if val_dataloader.dataset.segmentation and self.loss_fn_name != 'dice':
                    y_true = y_true.squeeze(1)
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true = y_true.cuda()
                y_pred = self.model(X)
                if save_predictions:
                    if val_dataloader.dataset.segmentation:
                        Y['true'].append(torch.flatten(y_true).detach().cpu().numpy().astype(int).flatten())
                        Y['pred'].append((y_pred.detach().cpu().numpy().argmax(axis=1)).astype(int).flatten())
                    else:
                        Y['true'].append(y_true.detach().cpu().numpy().astype(int).flatten())
                        y_pred_numpy = ((y_pred if not self.bce else self.sigmoid(y_pred)).detach().cpu().numpy()).astype(float)
                        if len(y_pred_numpy) > 1 and y_pred_numpy.shape[1] > 1 and not val_dataloader.dataset.mt_bce:
                            y_pred_numpy = y_pred_numpy.argmax(axis=1)
                        Y['pred'].append(y_pred_numpy.flatten())
                loss = self.calc_val_loss(y_pred, y_true)
                val_loss = loss.item()
                running_loss += val_loss
                print("Epoch {}[{}/{}] Val Loss:{}".format(epoch, i, n_batch, val_loss))
        if print_val_confusion and save_predictions:
            y_pred, y_true = np.hstack(Y['pred']), np.hstack(Y['true'])
            if not val_dataloader.dataset.segmentation:
                if self.loss_fn_name in ['bce', 'mse'] and not val_dataloader.dataset.mt_bce:
                    threshold, best_confusion = self.calc_best_confusion(y_pred, y_true)
                    print("Epoch {} Val Confusion, Threshold {}:".format(epoch, threshold))
                    print(best_confusion)
                    y_true = y_true.astype(int)
                    y_pred = (y_pred >= threshold).astype(int)
                elif val_dataloader.dataset.mt_bce:
                    n_targets = len(val_dataloader.dataset.targets)
                    # report R2 only on observed, non-NaN targets
                    y_pred = y_pred[y_true > 0]
                    y_true = y_true[y_true > 0]
                    y_true = y_true[~np.isnan(y_pred)]
                    y_pred = y_pred[~np.isnan(y_pred)]
                    # disabled: reshape flattened outputs back to [n_samples, n_targets]
                    # if n_targets > 1:
                    #     n_row = len(y_true) / n_targets
                    #     y_pred = y_pred.reshape(int(n_row), n_targets)
                    #     y_true = y_true.reshape(int(n_row), n_targets)
                    print("Epoch {} Val Regression, R2 Score {}".format(epoch, str(r2_score(y_true, y_pred))))
                else:
                    print(classification_report(y_true, y_pred))

        running_loss /= n_batch
        return running_loss
|
|
    # @pysnooper.snoop("test_loop.log")
    def test_loop(self, test_dataloader):
        """Calculate final predictions on the test set.

        Parameters
        ----------
        test_dataloader:DataLoader
            Test dataset.

        Returns
        -------
        array
            Predictions or embeddings.
        """
        # self.model.train(False) KEEP DROPOUT? and BATCH NORM??
        y_pred = []
        with torch.no_grad():
            for i, (X, y_test) in enumerate(test_dataloader):
                if torch.cuda.is_available():
                    X = X.cuda()
                if test_dataloader.dataset.segmentation:
                    prediction = self.model(X).detach().cpu().numpy()
                    if self.seg_out_class >= 0:
                        # emit the map for a single output class
                        prediction = prediction[:, self.seg_out_class, ...]
                    else:
                        prediction = prediction.argmax(axis=1).astype(int)
                    y_pred.append(prediction)
                else:
                    prediction = self.model(X)
                    if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets) - 1) or self.bce):
                        prediction = self.sigmoid(prediction)
                    elif test_dataloader.dataset.classify_annotations:
                        prediction = F.softmax(prediction, dim=1)
                    y_pred.append(prediction.detach().cpu().numpy())
        y_pred = np.concatenate(y_pred, axis=0)

        return y_pred
|
|
    def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=False, plot_save_file=None, print_val_confusion=True, save_val_predictions=True):
        """Fits the segmentation or classification model to the patches, saving the model with the lowest validation loss.

        Parameters
        ----------
        train_dataloader:DataLoader
            Training dataset.
        verbose:bool
            Print training and validation loss?
        print_every:int
            Number of epochs between prints.
        save_model:bool
            Whether to save the model when reaching the lowest validation loss.
        plot_training_curves:bool
            Plot training curves over epochs.
        plot_save_file:str
            File to save training curves to.
        print_val_confusion:bool
            Print validation confusion matrix.
        save_val_predictions:bool
            Print validation results.

        Returns
        -------
        self
            Trainer.
        float
            Minimum val loss.
        int
            Best validation epoch with lowest loss.

        """
        # keep the model achieving the lowest validation loss
        self.train_losses = []
        self.val_losses = []
        min_val_loss, best_epoch = np.inf, 0
        for epoch in range(self.n_epoch):
            start_time = time.time()
            train_loss = self.train_loop(epoch, train_dataloader)
            current_time = time.time()
            train_time = current_time - start_time
            self.train_losses.append(train_loss)
            val_loss = self.val_loop(epoch, self.validation_dataloader, print_val_confusion=print_val_confusion, save_predictions=save_val_predictions)
            val_time = time.time() - current_time
            self.val_losses.append(val_loss)
            if verbose and not (epoch % print_every):
                if plot_training_curves:
                    self.plot_train_val_curves(plot_save_file)
                print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(epoch, train_loss, val_loss, train_time, val_time))
            if val_loss <= min(self.val_losses) and save_model:
                min_val_loss = val_loss
                best_epoch = epoch
                best_model = copy.deepcopy(self.model)
                if self.checkpointing:
                    self.save_model(best_model, epoch)
        if save_model:
            self.model = best_model
        return self, min_val_loss, best_epoch
|
|
    def plot_train_val_curves(self, save_file=None):
        """Plots training and validation loss curves.

        Parameters
        ----------
        save_file:str
            File to save to.

        """
        plt.figure()
        sns.lineplot(x='epoch', y='value', hue='variable',
                     data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)), self.train_losses, self.val_losses)).T,
                                       columns=['epoch', 'train', 'val']).melt(id_vars=['epoch'], value_vars=['train', 'val']))
        if save_file is not None:
            plt.savefig(save_file, dpi=300)
|
|
    def predict(self, test_dataloader):
        """Make classification/segmentation predictions on testing data.

        Parameters
        ----------
        test_dataloader:DataLoader
            Test data.

        Returns
        -------
        array
            Predictions.

        """
        y_pred = self.test_loop(test_dataloader)
        return y_pred

    def fit_predict(self, train_dataloader, test_dataloader):
        """Fit model to training data and make classification/segmentation predictions on testing data.

        Parameters
        ----------
        train_dataloader:DataLoader
            Train data.
        test_dataloader:DataLoader
            Test data.

        Returns
        -------
        array
            Predictions.

        """
        return self.fit(train_dataloader)[0].predict(test_dataloader)

    def return_model(self):
        """Returns the PyTorch model."""
        return self.model
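
# A minimal end-to-end sketch (the dataloaders are hypothetical; pathflowai builds
# them from its DynamicImageDataset):
#   model = generate_model(pretrain=True, architecture='resnet34', num_classes=2, add_sigmoid=False)
#   trainer = ModelTrainer(model, n_epoch=10, validation_dataloader=val_loader, loss_fn='ce')
#   trainer, min_val_loss, best_epoch = trainer.fit(train_loader, verbose=True)
#   y_pred = trainer.predict(test_loader)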