"""
models.py
=======================
Houses the PyTorch models available for training and a corresponding scikit-learn-style model trainer.
"""
from pathflowai.unet import UNet
# from pathflowai.unet2 import NestedUNet
# from pathflowai.unet4 import UNetSmall as UNet2
from pathflowai.fast_scnn import get_fast_scnn
import torch
import torchvision
from torchvision import models
from torchvision.models import segmentation as segmodels
from torch import nn
from torch.nn import functional as F
import pandas as pd, numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from pathflowai.schedulers import *
import pysnooper
from torch.autograd import Variable
import copy
from sklearn.metrics import roc_curve, confusion_matrix, classification_report, r2_score
sns.set()
from pathflowai.losses import GeneralizedDiceLoss, FocalLoss
from apex import amp
import time, os
class MLP(nn.Module):
"""Multi-layer perceptron model.
Parameters
----------
    n_input:int
        Number of input dimensions.
    hidden_topology:list
        List of hidden layer sizes.
    dropout_p:float
        Dropout probability applied after each hidden layer.
    n_outputs:int
        Number of outputs.
    binary:bool
        Apply a sigmoid transform to the output.
    softmax:bool
        Apply a softmax transform to the output (ignored when binary=True).
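    Examples
    --------
    A minimal usage sketch; sizes are illustrative:
    >>> mlp = MLP(n_input=512, hidden_topology=[256, 128], dropout_p=0.2, n_outputs=1)
    >>> out = mlp(torch.randn(4, 512))  # shape (4, 1), sigmoid-activated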
"""
def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True, softmax=False):
super(MLP,self).__init__()
        self.topology = [n_input]+hidden_topology+[n_outputs]
        # hidden layers: Linear -> LeakyReLU -> Dropout for each consecutive pair of sizes, excluding the output layer
        layers = [nn.Linear(self.topology[i],self.topology[i+1]) for i in range(len(self.topology)-2)]
        for layer in layers:
            torch.nn.init.xavier_uniform_(layer.weight)
        self.layers = [nn.Sequential(layer,nn.LeakyReLU(),nn.Dropout(p=dropout_p)) for layer in layers]
self.output_layer = nn.Linear(self.topology[-2],self.topology[-1])
torch.nn.init.xavier_uniform_(self.output_layer.weight)
        if binary:
            output_transform = nn.Sigmoid()
        elif softmax:
            output_transform = nn.Softmax(dim=1)  # explicit dim avoids the implicit-dim deprecation warning
        else:
            output_transform = nn.Dropout(p=0.)  # no-op; passes raw outputs through unchanged
self.layers.append(nn.Sequential(self.output_layer,output_transform))
self.mlp = nn.Sequential(*self.layers)
def forward(self,x):
return self.mlp(x)
class FixedSegmentationModule(nn.Module):
"""Special model modification for segmentation tasks. Gets output from some of the models' forward loops.
Parameters
----------
segnet:nn.Module
Segmentation network
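    Examples
    --------
    Illustrative wrapping of a torchvision segmentation model:
    >>> segnet = segmodels.fcn_resnet50(pretrained=False, num_classes=2)
    >>> model = FixedSegmentationModule(segnet)
    >>> model(torch.randn(1, 3, 224, 224)).shape
    torch.Size([1, 2, 224, 224])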
"""
def __init__(self, segnet):
super(FixedSegmentationModule, self).__init__()
self.segnet=segnet
def forward(self, x):
"""Forward pass.
Parameters
----------
x:Tensor
Input
Returns
-------
Tensor
Output from model.
"""
return self.segnet(x)['out']
def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden=100, segmentation=False):
"""Generate a nn.Module for use.
Parameters
----------
pretrain:bool
Pretrain using ImageNet?
architecture:str
See model_training for list of all architectures you can train with.
num_classes:int
Number of classes to predict.
    add_sigmoid:bool
        Add sigmoid non-linearity at end.
    n_hidden:int
        Number of hidden fully connected units (currently unused).
    segmentation:bool
        Whether this is a segmentation task.
Returns
-------
nn.Module
Pytorch model.
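    Examples
    --------
    Illustrative calls; architecture names follow torchvision:
    >>> clf = generate_model(pretrain=True, architecture='resnet50', num_classes=2)
    >>> seg = generate_model(pretrain=False, architecture='deeplabv3_resnet101', num_classes=4, segmentation=True)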
"""
# to add: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/models/model_zoo.py
#architecture = 'resnet' + str(num_layers)
model = None
if architecture =='unet':
model = UNet(n_channels=3, n_classes=num_classes)
elif architecture =='unet2':
print('Deprecated for now, defaulting to UNET.')
model = UNet(n_channels=3, n_classes=num_classes)#UNet2(3,num_classes)
elif architecture == 'fast_scnn':
model = get_fast_scnn(num_classes)
elif architecture == 'nested_unet':
print('Nested UNET is deprecated for now, defaulting to UNET.')
model = UNet(n_channels=3, n_classes=num_classes)#NestedUNet(3, num_classes)
elif architecture.startswith('efficientnet'):
from efficientnet_pytorch import EfficientNet
if pretrain:
model = EfficientNet.from_pretrained(architecture, override_params=dict(num_classes=num_classes))
else:
model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes))
print(model)
elif architecture.startswith('sqnxt'):
from pytorchcv.model_provider import get_model as ptcv_get_model
model = ptcv_get_model(architecture, pretrained=pretrain)
num_ftrs=int(128*int(architecture.split('_')[-1][1]))
model.output=MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp
else:
#for pretrained on imagenet
model_names = [m for m in dir(models) if not m.startswith('__')]
segmentation_model_names = [m for m in dir(segmodels) if not m.startswith('__')]
if architecture in model_names:
model = getattr(models, architecture)(pretrained=pretrain)
if segmentation:
if architecture in segmentation_model_names:
model = getattr(segmodels, architecture)(pretrained=pretrain)
else:
model = UNet(n_channels=3, n_classes=num_classes)
if architecture.startswith('deeplab'):
model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1))
model = FixedSegmentationModule(model)
elif architecture.startswith('fcn'):
model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
model = FixedSegmentationModule(model)
elif architecture.startswith('resnet') or architecture.startswith('inception'):
num_ftrs = model.fc.in_features
#linear_layer = nn.Linear(num_ftrs, num_classes)
#torch.nn.init.xavier_uniform(linear_layer.weight)
model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else [])))
elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
num_ftrs = model.classifier[6].in_features
#linear_layer = nn.Linear(num_ftrs, num_classes)
#torch.nn.init.xavier_uniform(linear_layer.weight)
model.classifier[6] = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else [])))
return model
#@pysnooper.snoop("dice_loss.log")
def dice_loss(logits, true, eps=1e-7):
"""https://github.com/kevinzakka/pytorch-goodies
Computes the Sørensen–Dice loss.
Note that PyTorch optimizers minimize a loss. In this
case, we would like to maximize the dice loss so we
return the negated dice loss.
Args:
true: a tensor of shape [B, 1, H, W].
logits: a tensor of shape [B, C, H, W]. Corresponds to
the raw output or logits of the model.
eps: added to the denominator for numerical stability.
Returns:
dice_loss: the Sørensen–Dice loss.
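    Example (shapes are illustrative):
        >>> logits = torch.randn(2, 3, 8, 8)  # B=2, C=3
        >>> true = torch.randint(0, 3, (2, 1, 8, 8))
        >>> loss = dice_loss(logits, true)  # scalar in [0, 1]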
"""
#true=true.long()
num_classes = logits.shape[1]
if num_classes == 1:
true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)]
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
true_1_hot_f = true_1_hot[:, 0:1, :, :]
true_1_hot_s = true_1_hot[:, 1:2, :, :]
true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1)
pos_prob = torch.sigmoid(logits)
neg_prob = 1 - pos_prob
probas = torch.cat([pos_prob, neg_prob], dim=1)
else:
true_1_hot = torch.eye(num_classes)[true.squeeze(1)]
#print(true_1_hot.size())
true_1_hot = true_1_hot.permute(0, 3, 1, 2).float()
probas = F.softmax(logits, dim=1)
true_1_hot = true_1_hot.type(logits.type())
dims = (0,) + tuple(range(2, true.ndimension()))
intersection = torch.sum(probas * true_1_hot, dims)
cardinality = torch.sum(probas + true_1_hot, dims)
dice_loss = (2. * intersection / (cardinality + eps)).mean()
return (1 - dice_loss)
class ModelTrainer:
"""Trainer for the neural network model that wraps it into a scikit-learn like interface.
Parameters
----------
model:nn.Module
Deep learning pytorch model.
    n_epoch:int
        Number of training epochs.
    validation_dataloader:DataLoader
        Dataloader of validation dataset.
    optimizer_opts:dict
        Options for optimizer ('name' selects 'adam' or 'sgd'; remaining keys are passed to the optimizer).
    scheduler_opts:dict
        Options for learning rate scheduler.
    loss_fn:str
        Loss function to train with: one of 'bce', 'ce', 'mse', 'nll', 'dice', 'focal', 'gdl', or 'dice+ce'.
    reduction:str
        Mean or sum reduction of loss.
    num_train_batches:int
        Number of training batches per epoch.
    seg_out_class:int
        If nonnegative, test-time segmentation predictions return the output channel for this class only.
    apex_opt_level:str
        NVIDIA Apex mixed-precision optimization level.
    checkpointing:bool
        Save a checkpoint to ./checkpoints whenever the validation loss reaches a new minimum.
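    Examples
    --------
    A minimal sketch, assuming DataLoaders built over pathflowai datasets:
    >>> trainer = ModelTrainer(model, n_epoch=50, validation_dataloader=val_loader, loss_fn='ce')
    >>> trainer, min_val_loss, best_epoch = trainer.fit(train_loader, verbose=True)
    >>> predictions = trainer.predict(test_loader)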
"""
def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1, apex_opt_level="O2", checkpointing=False):
self.model = model
optimizers = {'adam':torch.optim.Adam, 'sgd':torch.optim.SGD}
loss_functions = {'bce':nn.BCEWithLogitsLoss(reduction=reduction), 'ce':nn.CrossEntropyLoss(reduction=reduction), 'mse':nn.MSELoss(reduction=reduction), 'nll':nn.NLLLoss(reduction=reduction), 'dice':dice_loss, 'focal':FocalLoss(num_class=2), 'gdl':GeneralizedDiceLoss(add_softmax=True)}
loss_functions['dice+ce']=(lambda y_pred, y_true: dice_loss(y_pred,y_true)+loss_functions['ce'](y_pred,y_true))
if 'name' not in list(optimizer_opts.keys()):
optimizer_opts['name']='adam'
self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(),**optimizer_opts)
if torch.cuda.is_available():
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=apex_opt_level)
self.cuda=True
else:
self.cuda=False
self.scheduler = Scheduler(optimizer=self.optimizer,opts=scheduler_opts)
self.n_epoch = n_epoch
self.validation_dataloader = validation_dataloader
self.loss_fn = loss_functions[loss_fn]
self.loss_fn_name = loss_fn
        self.bce=(self.loss_fn_name=='bce' or (self.validation_dataloader is not None and self.validation_dataloader.dataset.mt_bce)) # guard against validation_dataloader=None
self.sigmoid = nn.Sigmoid()
self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
self.num_train_batches = num_train_batches
self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
self.seg_out_class=seg_out_class
self.checkpointing=checkpointing
self.checkpoint_dir='./checkpoints'
if self.checkpointing:
os.makedirs(self.checkpoint_dir,exist_ok=True)
def save_model(self, model=None, epoch=0):
        torch.save((self.model if model is None else model).state_dict(),os.path.join(self.checkpoint_dir,f'checkpoint.{epoch}.pth'))
def calc_loss(self, y_pred, y_true):
"""Calculates loss supplied in init statement and modified by reweighting.
Parameters
----------
y_pred:tensor
Predictions.
y_true:tensor
True values.
Returns
-------
loss
"""
return self.loss_fn(y_pred, y_true)
def calc_val_loss(self, y_pred, y_true):
"""Calculates loss supplied in init statement on validation set.
Parameters
----------
y_pred:tensor
Predictions.
y_true:tensor
True values.
Returns
-------
val_loss
"""
return self.val_loss_fn(y_pred, y_true)
def reset_loss_fn(self):
"""Resets loss to original specified loss."""
self.loss_fn = self.original_loss_fn
def add_class_balance_loss(self, dataset, custom_weights=''):
"""Updates loss function to handle class imbalance by weighting inverse to class appearance.
Parameters
----------
        dataset:DynamicImageDataset
            Dataset to balance by.
        custom_weights:str
            Comma-delimited class weights; overrides weights derived from the dataset.
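        Examples
        --------
        Weights can be derived from the dataset or supplied directly (values illustrative):
        >>> trainer.add_class_balance_loss(train_dataset)
        >>> trainer.add_class_balance_loss(train_dataset, custom_weights='0.3,0.7')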
"""
self.class_weights = dataset.get_class_weights() if not custom_weights else np.array(list(map(float,custom_weights.split(','))))
if custom_weights:
self.class_weights=self.class_weights/sum(self.class_weights)
print('Weights:',self.class_weights)
self.original_loss_fn = copy.deepcopy(self.loss_fn)
weight=torch.tensor(self.class_weights,dtype=torch.float)
if torch.cuda.is_available():
weight=weight.cuda()
if self.loss_fn_name=='ce':
self.loss_fn = nn.CrossEntropyLoss(weight=weight)
elif self.loss_fn_name=='nll':
self.loss_fn = nn.NLLLoss(weight=weight)
        else: # weighted sum of per-class losses; currently binary-only, modify below for multi-target
            self.loss_fn = lambda y_pred,y_true: sum([self.class_weights[i]*self.original_loss_fn(y_pred[y_true==i],y_true[y_true==i]) if sum(y_true==i) else 0. for i in range(2)])
def calc_best_confusion(self, y_pred, y_true):
"""Calculate confusion matrix on validation set for classification/segmentation tasks, optimize threshold where positive.
Parameters
----------
y_pred:array
Predictions.
y_true:array
Ground truth.
Returns
-------
float
Optimized threshold to use on test set.
dataframe
Confusion matrix.
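        Examples
        --------
        Illustrative values:
        >>> threshold, confusion = trainer.calc_best_confusion(np.array([0.1, 0.4, 0.35, 0.8]), np.array([0, 0, 1, 1]))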
"""
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        # choose the threshold whose (fpr, tpr) point is closest in Euclidean distance to the perfect classifier at (0, 1)
        threshold=thresholds[np.argmin(np.sum((np.array([0,1])-np.vstack((fpr, tpr)).T)**2,axis=1)**.5)]
y_pred = (y_pred>threshold).astype(int)
return threshold, pd.DataFrame(confusion_matrix(y_true,y_pred),index=['F','T'],columns=['-','+']).iloc[::-1,::-1].T
def loss_backward(self,loss):
"""Backprop using mixed precision for added speed boost.
Parameters
----------
loss:loss
Torch loss calculated.
"""
if self.cuda:
with amp.scale_loss(loss,self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# @pysnooper.snoop('train_loop.log')
def train_loop(self, epoch, train_dataloader):
"""One training epoch, calculate predictions, loss, backpropagate.
Parameters
----------
epoch:int
Current epoch.
train_dataloader:DataLoader
Training data.
Returns
-------
float
Training loss for epoch
"""
self.model.train(True)
running_loss = 0.
        n_batch = len(train_dataloader.dataset)//train_dataloader.batch_size if self.num_train_batches is None else self.num_train_batches
for i, batch in enumerate(train_dataloader):
starttime=time.time()
if i == n_batch:
break
X = Variable(batch[0], requires_grad=True)
y_true = Variable(batch[1])
if not train_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1:
y_true=y_true.argmax(1).long()
if train_dataloader.dataset.segmentation and self.loss_fn_name!='dice':
y_true=y_true.squeeze(1)
if torch.cuda.is_available():
X = X.cuda()
y_true=y_true.cuda()
y_pred = self.model(X)
#sizes=(y_pred.size(),y_true.size())
#print(y_true)
loss = self.calc_loss(y_pred,y_true)
train_loss=loss.item()
running_loss += train_loss
self.optimizer.zero_grad()
self.loss_backward(loss)#loss.backward()
self.optimizer.step()
endtime=time.time()
print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch,i,n_batch,round(endtime-starttime,3),train_loss))
self.scheduler.step()
running_loss/=n_batch
return running_loss
def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True):
"""Calculate loss over validation set.
Parameters
----------
epoch:int
Current epoch.
val_dataloader:DataLoader
Validation iterator.
        print_val_confusion:bool
            Calculate and print the confusion matrix or classification report.
        save_predictions:bool
            Store predictions for metric reporting.
Returns
-------
float
Validation loss for epoch.
"""
self.model.train(False)
n_batch = len(val_dataloader.dataset)//val_dataloader.batch_size
running_loss = 0.
Y = {'pred':[],'true':[]}
with torch.no_grad():
for i, batch in enumerate(val_dataloader):
X = Variable(batch[0],requires_grad=False)
y_true = Variable(batch[1])
if not val_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1:
y_true=y_true.argmax(1).long()
if val_dataloader.dataset.segmentation and self.loss_fn_name!='dice':
y_true=y_true.squeeze(1)
if torch.cuda.is_available():
X = X.cuda()
y_true=y_true.cuda()
y_pred = self.model(X)
if save_predictions:
if val_dataloader.dataset.segmentation:
                        Y['true'].append(torch.flatten(y_true).detach().cpu().numpy().astype(int).flatten())
Y['pred'].append((y_pred.detach().cpu().numpy().argmax(axis=1)).astype(int).flatten())
else:
Y['true'].append(y_true.detach().cpu().numpy().astype(int).flatten())
y_pred_numpy=((y_pred if not self.bce else self.sigmoid(y_pred)).detach().cpu().numpy()).astype(float)
if len(y_pred_numpy)>1 and y_pred_numpy.shape[1]>1 and not val_dataloader.dataset.mt_bce:
y_pred_numpy=y_pred_numpy.argmax(axis=1)
Y['pred'].append(y_pred_numpy.flatten())
loss = self.calc_val_loss(y_pred,y_true)
val_loss=loss.item()
running_loss += val_loss
print("Epoch {}[{}/{}] Val Loss:{}".format(epoch,i,n_batch,val_loss))
if print_val_confusion and save_predictions:
y_pred,y_true = np.hstack(Y['pred']),np.hstack(Y['true'])
if not val_dataloader.dataset.segmentation:
if self.loss_fn_name in ['bce','mse'] and not val_dataloader.dataset.mt_bce:
threshold, best_confusion = self.calc_best_confusion(y_pred,y_true)
print("Epoch {} Val Confusion, Threshold {}:".format(epoch,threshold))
print(best_confusion)
y_true = y_true.astype(int)
y_pred = (y_pred>=threshold).astype(int)
elif val_dataloader.dataset.mt_bce:
n_targets = len(val_dataloader.dataset.targets)
y_pred=y_pred[y_true>0]
y_true=y_true[y_true>0]
y_true=y_true[np.isnan(y_pred)==False]
y_pred=y_pred[np.isnan(y_pred)==False]
                if 0 and n_targets > 1: # disabled: would reshape flattened predictions back to (n_samples, n_targets)
n_row=len(y_true)/n_targets
y_pred=y_pred.reshape(int(n_row),n_targets)
y_true=y_true.reshape(int(n_row),n_targets)
print("Epoch {} Val Regression, R2 Score {}".format(epoch, str(r2_score(y_true, y_pred))))
else:
print(classification_report(y_true,y_pred))
running_loss/=n_batch
return running_loss
#@pysnooper.snoop("test_loop.log")
def test_loop(self, test_dataloader):
"""Calculate final predictions on loss.
Parameters
----------
test_dataloader:DataLoader
Test dataset.
Returns
-------
array
Predictions or embeddings.
"""
#self.model.train(False) KEEP DROPOUT? and BATCH NORM??
y_pred = []
running_loss = 0.
with torch.no_grad():
for i, (X,y_test) in enumerate(test_dataloader):
#X = Variable(batch[0],requires_grad=False)
if torch.cuda.is_available():
X = X.cuda()
if test_dataloader.dataset.segmentation:
prediction=self.model(X).detach().cpu().numpy()
if self.seg_out_class>=0:
prediction=prediction[:,self.seg_out_class,...]
else:
prediction=prediction.argmax(axis=1).astype(int)
pred_size=prediction.shape#size()
#pred_mean=prediction[0].mean(axis=0)
y_pred.append(prediction)
else:
prediction=self.model(X)
                    if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets)-1) or self.bce): # multi-target or BCE classification: apply sigmoid to logits
prediction=self.sigmoid(prediction)
elif test_dataloader.dataset.classify_annotations:
prediction=F.softmax(prediction,dim=1)
y_pred.append(prediction.detach().cpu().numpy())
y_pred = np.concatenate(y_pred,axis=0)#torch.cat(y_pred,0)
return y_pred
def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=False, plot_save_file=None, print_val_confusion=True, save_val_predictions=True):
"""Fits the segmentation or classification model to the patches, saving the model with the lowest validation score.
Parameters
----------
train_dataloader:DataLoader
Training dataset.
verbose:bool
Print training and validation loss?
print_every:int
Number of epochs until print?
save_model:bool
Whether to save model when reaching lowest validation loss.
plot_training_curves:bool
Plot training curves over epochs.
plot_save_file:str
File to save training curves.
print_val_confusion:bool
Print validation confusion matrix.
        save_val_predictions:bool
            Store validation predictions for metric reporting.
Returns
-------
self
Trainer.
float
Minimum val loss.
int
Best validation epoch with lowest loss.
"""
        # track the model with the best (lowest) validation loss
        self.train_losses = []
        self.val_losses = []
        min_val_loss, best_epoch, best_model = np.inf, 0, self.model # defaults in case save_model=False
for epoch in range(self.n_epoch):
start_time=time.time()
train_loss = self.train_loop(epoch,train_dataloader)
current_time=time.time()
train_time=current_time-start_time
self.train_losses.append(train_loss)
val_loss = self.val_loop(epoch,self.validation_dataloader, print_val_confusion=print_val_confusion, save_predictions=save_val_predictions)
val_time=time.time()-current_time
self.val_losses.append(val_loss)
if verbose and not (epoch % print_every):
if plot_training_curves:
self.plot_train_val_curves(plot_save_file)
print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(epoch,train_loss,val_loss,train_time,val_time))
if val_loss <= min(self.val_losses) and save_model:
min_val_loss = val_loss
best_epoch = epoch
best_model = copy.deepcopy(self.model)
if self.checkpointing:
self.save_model(best_model,epoch)
if save_model:
self.model = best_model
return self, min_val_loss, best_epoch
def plot_train_val_curves(self, save_file=None):
"""Plots training and validation curves.
Parameters
----------
save_file:str
File to save to.
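        Examples
        --------
        >>> trainer.plot_train_val_curves('training_curves.png')  # file name illustrative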
"""
plt.figure()
        sns.lineplot(x='epoch',y='value',hue='variable',
            data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)),self.train_losses,self.val_losses)).T,
                            columns=['epoch','train','val']).melt(id_vars=['epoch'],value_vars=['train','val']))
if save_file is not None:
plt.savefig(save_file, dpi=300)
def predict(self, test_dataloader):
"""Make classification segmentation predictions on testing data.
Parameters
----------
test_dataloader:DataLoader
Test data.
Returns
-------
array
Predictions.
"""
y_pred = self.test_loop(test_dataloader)
return y_pred
def fit_predict(self, train_dataloader, test_dataloader):
"""Fit model to training data and make classification segmentation predictions on testing data.
Parameters
----------
train_dataloader:DataLoader
Train data.
test_dataloader:DataLoader
Test data.
Returns
-------
array
Predictions.
"""
return self.fit(train_dataloader)[0].predict(test_dataloader)
def return_model(self):
"""Returns pytorch model.
"""
return self.model