Diff of /pathflowai/models.py [000000] .. [e9500f]


a b/pathflowai/models.py
"""
models.py
=======================
Houses the PyTorch models available for training and the corresponding scikit-learn-like model trainer.
"""
from pathflowai.unet import UNet
# from pathflowai.unet2 import NestedUNet
# from pathflowai.unet4 import UNetSmall as UNet2
from pathflowai.fast_scnn import get_fast_scnn
import torch
import torchvision
from torchvision import models
from torchvision.models import segmentation as segmodels
from torch import nn
from torch.nn import functional as F
import pandas as pd, numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from pathflowai.schedulers import *
import pysnooper
from torch.autograd import Variable
import copy
from sklearn.metrics import roc_curve, confusion_matrix, classification_report, r2_score
sns.set()
from pathflowai.losses import GeneralizedDiceLoss, FocalLoss
from apex import amp
import time, os

class MLP(nn.Module):
    """Multi-layer perceptron model.

    Parameters
    ----------
    n_input:int
        Number of input dimensions.
    hidden_topology:list
        List of hidden layer sizes.
    dropout_p:float
        Dropout probability applied after each hidden layer.
    n_outputs:int
        Number of outputs.
    binary:bool
        Binary output with sigmoid transform.
    softmax:bool
        Whether to apply softmax on output.

    """
    def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True, softmax=False):
        super(MLP,self).__init__()
        self.topology = [n_input]+hidden_topology+[n_outputs]
        layers = [nn.Linear(self.topology[i],self.topology[i+1]) for i in range(len(self.topology)-2)]
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.layers = [nn.Sequential(layer,nn.LeakyReLU(),nn.Dropout(p=dropout_p)) for layer in layers]
        self.output_layer = nn.Linear(self.topology[-2],self.topology[-1])
        torch.nn.init.xavier_uniform_(self.output_layer.weight)
        if binary:
            output_transform = nn.Sigmoid()
        elif softmax:
            output_transform = nn.Softmax(dim=1)
        else:
            output_transform = nn.Dropout(p=0.)
        self.layers.append(nn.Sequential(self.output_layer,output_transform))
        self.mlp = nn.Sequential(*self.layers)

    def forward(self,x):
        return self.mlp(x)

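# Usage sketch (illustrative, not part of the original module): this is the same kind of
# classification head that generate_model attaches to a ResNet backbone below; the input
# dimension and class count here are assumptions.
#
#   head = MLP(n_input=2048, hidden_topology=[1000], dropout_p=0., n_outputs=2,
#              binary=False, softmax=False).mlp
#   logits = head(torch.randn(8, 2048))   # -> tensor of shape (8, 2)
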
class FixedSegmentationModule(nn.Module):
    """Wrapper for segmentation models whose forward pass returns a dict; extracts the 'out' tensor.

    Parameters
    ----------
    segnet:nn.Module
        Segmentation network
    """
    def __init__(self, segnet):
        super(FixedSegmentationModule, self).__init__()
        self.segnet=segnet

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x:Tensor
            Input

        Returns
        -------
        Tensor
            Output from model.

        """
        return self.segnet(x)['out']

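# Usage sketch (illustrative): torchvision segmentation models return an OrderedDict, and
# this wrapper exposes only the 'out' tensor. The architecture and class count are assumptions;
# the classifier replacement mirrors the deeplab branch in generate_model below.
#
#   seg = segmodels.deeplabv3_resnet50(pretrained=False)
#   seg.classifier[4] = nn.Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))
#   seg = FixedSegmentationModule(seg)
#   mask_logits = seg(torch.randn(1, 3, 224, 224))   # -> tensor of shape (1, 4, 224, 224)
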
def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden=100, segmentation=False):
    """Generate a nn.Module for use.

    Parameters
    ----------
    pretrain:bool
        Pretrain using ImageNet?
    architecture:str
        See model_training for list of all architectures you can train with.
    num_classes:int
        Number of classes to predict.
    add_sigmoid:bool
        Add sigmoid non-linearity at end.
    n_hidden:int
        Number of hidden fully connected units.
    segmentation:bool
        Whether this is a segmentation task.

    Returns
    -------
    nn.Module
        Pytorch model.

    """
    # to add: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/models/model_zoo.py
    #architecture = 'resnet' + str(num_layers)
    model = None

    if architecture =='unet':
        model = UNet(n_channels=3, n_classes=num_classes)
    elif architecture =='unet2':
        print('Deprecated for now, defaulting to UNET.')
        model = UNet(n_channels=3, n_classes=num_classes)#UNet2(3,num_classes)
    elif architecture == 'fast_scnn':
        model = get_fast_scnn(num_classes)
    elif architecture == 'nested_unet':
        print('Nested UNET is deprecated for now, defaulting to UNET.')
        model = UNet(n_channels=3, n_classes=num_classes)#NestedUNet(3, num_classes)
    elif architecture.startswith('efficientnet'):
        from efficientnet_pytorch import EfficientNet
        if pretrain:
            model = EfficientNet.from_pretrained(architecture, override_params=dict(num_classes=num_classes))
        else:
            model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes))
        print(model)
    elif architecture.startswith('sqnxt'):
        from pytorchcv.model_provider import get_model as ptcv_get_model
        model = ptcv_get_model(architecture, pretrained=pretrain)
        num_ftrs=int(128*int(architecture.split('_')[-1][1]))
        model.output=MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
    else:
        # for architectures pretrained on ImageNet
        model_names = [m for m in dir(models) if not m.startswith('__')]
        segmentation_model_names = [m for m in dir(segmodels) if not m.startswith('__')]
        if architecture in model_names:
            model = getattr(models, architecture)(pretrained=pretrain)
        if segmentation:
            if architecture in segmentation_model_names:
                model = getattr(segmodels, architecture)(pretrained=pretrain)
            else:
                model = UNet(n_channels=3, n_classes=num_classes)
            if architecture.startswith('deeplab'):
                model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
                model = FixedSegmentationModule(model)
            elif architecture.startswith('fcn'):
                model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
                model = FixedSegmentationModule(model)
        elif architecture.startswith('resnet') or architecture.startswith('inception'):
            num_ftrs = model.fc.in_features
            #linear_layer = nn.Linear(num_ftrs, num_classes)
            #torch.nn.init.xavier_uniform(linear_layer.weight)
            model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else [])))
        elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
            num_ftrs = model.classifier[6].in_features
            #linear_layer = nn.Linear(num_ftrs, num_classes)
            #torch.nn.init.xavier_uniform(linear_layer.weight)
            model.classifier[6] = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else [])))
    return model

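# Usage sketch (illustrative; the architecture names and class counts are assumptions, see
# model_training for the architectures actually exposed):
#
#   clf = generate_model(pretrain=True, architecture='resnet50', num_classes=2,
#                        add_sigmoid=False, segmentation=False)
#   seg = generate_model(pretrain=False, architecture='deeplabv3_resnet101', num_classes=4,
#                        add_sigmoid=False, segmentation=True)
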
#@pysnooper.snoop("dice_loss.log")
def dice_loss(logits, true, eps=1e-7):
    """https://github.com/kevinzakka/pytorch-goodies
    Computes the Sørensen–Dice loss.

    Note that PyTorch optimizers minimize a loss. In this
    case, we would like to maximize the dice coefficient, so we
    return one minus the dice coefficient.

    Args:
        true: a tensor of shape [B, 1, H, W].
        logits: a tensor of shape [B, C, H, W]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.

    Returns:
        dice_loss: the Sørensen–Dice loss.
    """
    #true=true.long()
    num_classes = logits.shape[1]
    if num_classes == 1:
        true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        true_1_hot_f = true_1_hot[:, 0:1, :, :]
        true_1_hot_s = true_1_hot[:, 1:2, :, :]
        true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
        pos_prob = torch.sigmoid(logits)
        neg_prob = 1 - pos_prob
        probas = torch.cat([pos_prob, neg_prob], dim=1)
    else:
        true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
        #print(true_1_hot.size())
        true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
        probas = F.softmax(logits, dim=1)
    true_1_hot = true_1_hot.type(logits.type())
    dims = (0,) + tuple(range(2, true.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_loss = (2. * intersection / (cardinality + eps)).mean()
    return (1 - dice_loss)

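# Usage sketch (illustrative): dice_loss expects raw logits of shape [B, C, H, W] and
# integer ground-truth masks of shape [B, 1, H, W]; the sizes below are arbitrary.
#
#   logits = torch.randn(2, 4, 64, 64)               # 4-class segmentation logits
#   masks  = torch.randint(0, 4, (2, 1, 64, 64))     # ground-truth class indices
#   loss = dice_loss(logits, masks)                  # scalar in [0, 1]
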
class ModelTrainer:
    """Trainer for the neural network model that wraps it into a scikit-learn like interface.

    Parameters
    ----------
    model:nn.Module
        Deep learning pytorch model.
    n_epoch:int
        Number of training epochs.
    validation_dataloader:DataLoader
        Dataloader of validation dataset.
    optimizer_opts:dict
        Options for optimizer.
    scheduler_opts:dict
        Options for learning rate scheduler.
    loss_fn:str
        String to call a particular loss function for model.
    reduction:str
        Mean or sum reduction of loss.
    num_train_batches:int
        Number of training batches per epoch.
    seg_out_class:int
        If nonnegative, the segmentation class whose output map is returned at test time.
    apex_opt_level:str
        NVIDIA Apex AMP optimization level (e.g. "O2").
    checkpointing:bool
        Save a checkpoint whenever a new best validation loss is reached.
    """
    def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1, apex_opt_level="O2", checkpointing=False):

        self.model = model
        optimizers = {'adam':torch.optim.Adam, 'sgd':torch.optim.SGD}
        loss_functions = {'bce':nn.BCEWithLogitsLoss(reduction=reduction), 'ce':nn.CrossEntropyLoss(reduction=reduction), 'mse':nn.MSELoss(reduction=reduction), 'nll':nn.NLLLoss(reduction=reduction), 'dice':dice_loss, 'focal':FocalLoss(num_class=2), 'gdl':GeneralizedDiceLoss(add_softmax=True)}
        loss_functions['dice+ce']=(lambda y_pred, y_true: dice_loss(y_pred,y_true)+loss_functions['ce'](y_pred,y_true))
        if 'name' not in list(optimizer_opts.keys()):
            optimizer_opts['name']='adam'
        self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(),**optimizer_opts)
        if torch.cuda.is_available():
            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=apex_opt_level)
            self.cuda=True
        else:
            self.cuda=False
        self.scheduler = Scheduler(optimizer=self.optimizer,opts=scheduler_opts)
        self.n_epoch = n_epoch
        self.validation_dataloader = validation_dataloader
        self.loss_fn = loss_functions[loss_fn]
        self.loss_fn_name = loss_fn
        self.bce=(self.loss_fn_name=='bce' or self.validation_dataloader.dataset.mt_bce)
        self.sigmoid = nn.Sigmoid()
        self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.num_train_batches = num_train_batches
        self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.seg_out_class=seg_out_class
        self.checkpointing=checkpointing
        self.checkpoint_dir='./checkpoints'
        if self.checkpointing:
            os.makedirs(self.checkpoint_dir,exist_ok=True)

    def save_model(self, model=None, epoch=0):
        """Save a model checkpoint; defaults to the trainer's current model when none is supplied."""
        torch.save((self.model if model is None else model).state_dict(),os.path.join(self.checkpoint_dir,f'checkpoint.{epoch}.pth'))

    def calc_loss(self, y_pred, y_true):
        """Calculates the loss supplied in the init statement, possibly modified by class reweighting.

        Parameters
        ----------
        y_pred:tensor
            Predictions.
        y_true:tensor
            True values.

        Returns
        -------
        loss

        """

        return self.loss_fn(y_pred, y_true)

    def calc_val_loss(self, y_pred, y_true):
        """Calculates the loss supplied in the init statement on the validation set.

        Parameters
        ----------
        y_pred:tensor
            Predictions.
        y_true:tensor
            True values.

        Returns
        -------
        val_loss

        """

        return self.val_loss_fn(y_pred, y_true)

    def reset_loss_fn(self):
        """Resets loss to original specified loss."""
        self.loss_fn = self.original_loss_fn

    def add_class_balance_loss(self, dataset, custom_weights=''):
        """Updates the loss function to handle class imbalance by weighting classes inversely to their frequency.

        Parameters
        ----------
        dataset:DynamicImageDataset
            Dataset to balance by.
        custom_weights:str
            Optional comma-delimited string of class weights; overrides the weights derived from the dataset.

        """
        self.class_weights = dataset.get_class_weights() if not custom_weights else np.array(list(map(float,custom_weights.split(','))))
        if custom_weights:
            self.class_weights=self.class_weights/sum(self.class_weights)
        print('Weights:',self.class_weights)
        self.original_loss_fn = copy.deepcopy(self.loss_fn)
        weight=torch.tensor(self.class_weights,dtype=torch.float)
        if torch.cuda.is_available():
            weight=weight.cuda()
        if self.loss_fn_name=='ce':
            self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        elif self.loss_fn_name=='nll':
            self.loss_fn = nn.NLLLoss(weight=weight)
        else: # modify below for multi-target
            self.loss_fn = lambda y_pred,y_true: sum([self.class_weights[i]*self.original_loss_fn(y_pred[y_true==i],y_true[y_true==i]) if sum(y_true==i) else 0. for i in range(2)])

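    # Usage sketch (illustrative): supplying explicit weights for a two-class problem.
    # The comma-delimited string mirrors the custom_weights argument above; the values are
    # normalized to sum to one before being passed to the weighted CrossEntropyLoss.
    #
    #   trainer.add_class_balance_loss(train_dataloader.dataset, custom_weights='0.2,0.8')
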
    def calc_best_confusion(self, y_pred, y_true):
        """Calculate confusion matrix on the validation set for classification tasks, choosing the probability threshold at the ROC point closest to (FPR=0, TPR=1).

        Parameters
        ----------
        y_pred:array
            Predictions.
        y_true:array
            Ground truth.

        Returns
        -------
        float
            Optimized threshold to use on test set.
        dataframe
            Confusion matrix.

        """
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        threshold=thresholds[np.argmin(np.sum((np.array([0,1])-np.vstack((fpr, tpr)).T)**2,axis=1)**.5)]
        y_pred = (y_pred>threshold).astype(int)
        return threshold, pd.DataFrame(confusion_matrix(y_true,y_pred),index=['F','T'],columns=['-','+']).iloc[::-1,::-1].T

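    # Usage sketch (illustrative; toy arrays, hypothetical trainer instance): threshold
    # selection on held-out probabilities for a binary task.
    #
    #   y_true = np.array([0, 0, 1, 1])
    #   y_prob = np.array([0.1, 0.4, 0.35, 0.8])
    #   threshold, confusion = trainer.calc_best_confusion(y_prob, y_true)
    #   print(threshold)
    #   print(confusion)
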
    def loss_backward(self,loss):
        """Backpropagate, using Apex mixed precision for an added speed boost when CUDA is available.

        Parameters
        ----------
        loss:loss
            Torch loss calculated.

        """
        if self.cuda:
            with amp.scale_loss(loss,self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    # @pysnooper.snoop('train_loop.log')
    def train_loop(self, epoch, train_dataloader):
        """One training epoch: calculate predictions and loss, then backpropagate.

        Parameters
        ----------
        epoch:int
            Current epoch.
        train_dataloader:DataLoader
            Training data.

        Returns
        -------
        float
            Training loss for epoch

        """
        self.model.train(True)
        running_loss = 0.
        n_batch = len(train_dataloader.dataset)//train_dataloader.batch_size if self.num_train_batches is None else self.num_train_batches
        for i, batch in enumerate(train_dataloader):
            starttime=time.time()
            if i == n_batch:
                break
            X = Variable(batch[0], requires_grad=True)
            y_true = Variable(batch[1])
            if not train_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1:
                y_true=y_true.argmax(1).long()
            if train_dataloader.dataset.segmentation and self.loss_fn_name!='dice':
                y_true=y_true.squeeze(1)
            if torch.cuda.is_available():
                X = X.cuda()
                y_true=y_true.cuda()
            y_pred = self.model(X)
            #sizes=(y_pred.size(),y_true.size())
            #print(y_true)
            loss = self.calc_loss(y_pred,y_true)
            train_loss=loss.item()
            running_loss += train_loss
            self.optimizer.zero_grad()
            self.loss_backward(loss)#loss.backward()
            self.optimizer.step()
            endtime=time.time()
            print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch,i,n_batch,round(endtime-starttime,3),train_loss))
        self.scheduler.step()
        running_loss/=n_batch
        return running_loss

    def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True):
        """Calculate loss over validation set.

        Parameters
        ----------
        epoch:int
            Current epoch.
        val_dataloader:DataLoader
            Validation iterator.
        print_val_confusion:bool
            Calculate and print confusion matrix / classification report.
        save_predictions:bool
            Accumulate validation predictions for reporting.

        Returns
        -------
        float
            Validation loss for epoch.
        """
        self.model.train(False)
        n_batch = len(val_dataloader.dataset)//val_dataloader.batch_size
        running_loss = 0.
        Y = {'pred':[],'true':[]}
        with torch.no_grad():
            for i, batch in enumerate(val_dataloader):
                X = Variable(batch[0],requires_grad=False)
                y_true = Variable(batch[1])
                if not val_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1:
                    y_true=y_true.argmax(1).long()
                if val_dataloader.dataset.segmentation and self.loss_fn_name!='dice':
                    y_true=y_true.squeeze(1)
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true=y_true.cuda()
                y_pred = self.model(X)
                if save_predictions:
                    if val_dataloader.dataset.segmentation:
                        Y['true'].append(torch.flatten(y_true).detach().cpu().numpy().astype(int).flatten()) # .argmax(axis=1)
                        Y['pred'].append((y_pred.detach().cpu().numpy().argmax(axis=1)).astype(int).flatten())
                    else:
                        Y['true'].append(y_true.detach().cpu().numpy().astype(int).flatten())
                        y_pred_numpy=((y_pred if not self.bce else self.sigmoid(y_pred)).detach().cpu().numpy()).astype(float)
                        if len(y_pred_numpy)>1 and y_pred_numpy.shape[1]>1 and not val_dataloader.dataset.mt_bce:
                            y_pred_numpy=y_pred_numpy.argmax(axis=1)
                        Y['pred'].append(y_pred_numpy.flatten())
                loss = self.calc_val_loss(y_pred,y_true)
                val_loss=loss.item()
                running_loss += val_loss
                print("Epoch {}[{}/{}] Val Loss:{}".format(epoch,i,n_batch,val_loss))
        if print_val_confusion and save_predictions:
            y_pred,y_true = np.hstack(Y['pred']),np.hstack(Y['true'])
            if not val_dataloader.dataset.segmentation:
                if self.loss_fn_name in ['bce','mse'] and not val_dataloader.dataset.mt_bce:
                    threshold, best_confusion = self.calc_best_confusion(y_pred,y_true)
                    print("Epoch {} Val Confusion, Threshold {}:".format(epoch,threshold))
                    print(best_confusion)
                    y_true = y_true.astype(int)
                    y_pred = (y_pred>=threshold).astype(int)
                elif val_dataloader.dataset.mt_bce:
                    n_targets = len(val_dataloader.dataset.targets)
                    y_pred=y_pred[y_true>0]
                    y_true=y_true[y_true>0]
                    y_true=y_true[np.isnan(y_pred)==False]
                    y_pred=y_pred[np.isnan(y_pred)==False]
                    if 0 and n_targets > 1:
                        n_row=len(y_true)/n_targets
                        y_pred=y_pred.reshape(int(n_row),n_targets)
                        y_true=y_true.reshape(int(n_row),n_targets)
                    print("Epoch {} Val Regression, R2 Score {}".format(epoch, str(r2_score(y_true, y_pred))))
            else:
                print(classification_report(y_true,y_pred))

        running_loss/=n_batch
        return running_loss

    #@pysnooper.snoop("test_loop.log")
    def test_loop(self, test_dataloader):
        """Calculate final predictions on the test set.

        Parameters
        ----------
        test_dataloader:DataLoader
            Test dataset.

        Returns
        -------
        array
            Predictions or embeddings.
        """
        #self.model.train(False) KEEP DROPOUT? and BATCH NORM??
        y_pred = []
        running_loss = 0.
        with torch.no_grad():
            for i, (X,y_test) in enumerate(test_dataloader):
                #X = Variable(batch[0],requires_grad=False)
                if torch.cuda.is_available():
                    X = X.cuda()
                if test_dataloader.dataset.segmentation:
                    prediction=self.model(X).detach().cpu().numpy()
                    if self.seg_out_class>=0:
                        prediction=prediction[:,self.seg_out_class,...]
                    else:
                        prediction=prediction.argmax(axis=1).astype(int)
                    pred_size=prediction.shape#size()
                    #pred_mean=prediction[0].mean(axis=0)
                    y_pred.append(prediction)
                else:
                    prediction=self.model(X)
                    if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets)-1) or self.bce):
                        prediction=self.sigmoid(prediction)
                    elif test_dataloader.dataset.classify_annotations:
                        prediction=F.softmax(prediction,dim=1)
                    y_pred.append(prediction.detach().cpu().numpy())
        y_pred = np.concatenate(y_pred,axis=0)#torch.cat(y_pred,0)

        return y_pred

    def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=False, plot_save_file=None, print_val_confusion=True, save_val_predictions=True):
        """Fits the segmentation or classification model to the patches, keeping the model with the lowest validation loss.

        Parameters
        ----------
        train_dataloader:DataLoader
            Training dataset.
        verbose:bool
            Print training and validation loss?
        print_every:int
            Number of epochs between prints.
        save_model:bool
            Whether to keep the model from the epoch with the lowest validation loss.
        plot_training_curves:bool
            Plot training curves over epochs.
        plot_save_file:str
            File to save training curves.
        print_val_confusion:bool
            Print validation confusion matrix.
        save_val_predictions:bool
            Accumulate validation predictions for reporting.

        Returns
        -------
        self
            Trainer.
        float
            Minimum val loss.
        int
            Best validation epoch with lowest loss.

        """
        # choose model with best f1
        self.train_losses = []
        self.val_losses = []
        for epoch in range(self.n_epoch):
            start_time=time.time()
            train_loss = self.train_loop(epoch,train_dataloader)
            current_time=time.time()
            train_time=current_time-start_time
            self.train_losses.append(train_loss)
            val_loss = self.val_loop(epoch,self.validation_dataloader, print_val_confusion=print_val_confusion, save_predictions=save_val_predictions)
            val_time=time.time()-current_time
            self.val_losses.append(val_loss)
            if verbose and not (epoch % print_every):
                if plot_training_curves:
                    self.plot_train_val_curves(plot_save_file)
                print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(epoch,train_loss,val_loss,train_time,val_time))
            if val_loss <= min(self.val_losses) and save_model:
                min_val_loss = val_loss
                best_epoch = epoch
                best_model = copy.deepcopy(self.model)
                if self.checkpointing:
                    self.save_model(best_model,epoch)
        if save_model:
            self.model = best_model
        return self, min_val_loss, best_epoch

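    # End-to-end usage sketch (illustrative; dataloader names and hyperparameters are
    # assumptions, see model_training in this package for the real entry point):
    #
    #   model = generate_model(pretrain=True, architecture='resnet50', num_classes=2,
    #                          add_sigmoid=False)
    #   trainer = ModelTrainer(model, n_epoch=50, validation_dataloader=val_dataloader,
    #                          optimizer_opts=dict(name='adam', lr=1e-3, weight_decay=1e-4),
    #                          loss_fn='ce')
    #   trainer, min_val_loss, best_epoch = trainer.fit(train_dataloader, verbose=True)
    #   y_pred = trainer.predict(test_dataloader)
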
    def plot_train_val_curves(self, save_file=None):
        """Plots training and validation curves.

        Parameters
        ----------
        save_file:str
            File to save to.

        """
        plt.figure()
        sns.lineplot('epoch','value',hue='variable',
                     data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)),self.train_losses,self.val_losses)).T,
                                       columns=['epoch','train','val']).melt(id_vars=['epoch'],value_vars=['train','val']))
        if save_file is not None:
            plt.savefig(save_file, dpi=300)

    def predict(self, test_dataloader):
        """Make classification/segmentation predictions on testing data.

        Parameters
        ----------
        test_dataloader:DataLoader
            Test data.

        Returns
        -------
        array
            Predictions.

        """
        y_pred = self.test_loop(test_dataloader)
        return y_pred

    def fit_predict(self, train_dataloader, test_dataloader):
        """Fit model to training data and make classification/segmentation predictions on testing data.

        Parameters
        ----------
        train_dataloader:DataLoader
            Train data.
        test_dataloader:DataLoader
            Test data.

        Returns
        -------
        array
            Predictions.

        """
        return self.fit(train_dataloader)[0].predict(test_dataloader)

    def return_model(self):
        """Returns the pytorch model."""
        return self.model