--- a +++ b/pathflowai/models.py @@ -0,0 +1,654 @@ +""" +models.py +======================= +Houses all of the PyTorch models to access and the corresponding Scikit-Learn like model trainer. +""" +from pathflowai.unet import UNet +# from pathflowai.unet2 import NestedUNet +# from pathflowai.unet4 import UNetSmall as UNet2 +from pathflowai.fast_scnn import get_fast_scnn +import torch +import torchvision +from torchvision import models +from torchvision.models import segmentation as segmodels +from torch import nn +from torch.nn import functional as F +import pandas as pd, numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import seaborn as sns +from pathflowai.schedulers import * +import pysnooper +from torch.autograd import Variable +import copy +from sklearn.metrics import roc_curve, confusion_matrix, classification_report, r2_score +sns.set() +from pathflowai.losses import GeneralizedDiceLoss, FocalLoss +from apex import amp +from torch.nn import functional as F +import time, os + +class MLP(nn.Module): + """Multi-layer perceptron model. + + Parameters + ---------- + n_input:int + Number input dimensions. + hidden_topology:list + List of hidden topology + dropout_p:float + Amount dropout. + n_outputs:int + Number outputs. + binary:bool + Binary output with sigmoid transform. + softmax:bool + Whether to apply softmax on output. + + """ + def __init__(self, n_input, hidden_topology, dropout_p, n_outputs=1, binary=True, softmax=False): + super(MLP,self).__init__() + self.topology = [n_input]+hidden_topology+[n_outputs] + layers = [nn.Linear(self.topology[i],self.topology[i+1]) for i in range(len(self.topology)-2)] + for layer in layers: + torch.nn.init.xavier_uniform_(layer.weight) + self.layers = [nn.Sequential(layer,nn.LeakyReLU(),nn.Dropout(p=dropout_p)) for layer in layers] + self.output_layer = nn.Linear(self.topology[-2],self.topology[-1]) + torch.nn.init.xavier_uniform_(self.output_layer.weight) + if binary: + output_transform = nn.Sigmoid() + elif softmax: + output_transform = nn.Softmax() + else: + output_transform = nn.Dropout(p=0.) + self.layers.append(nn.Sequential(self.output_layer,output_transform)) + self.mlp = nn.Sequential(*self.layers) + + def forward(self,x): + return self.mlp(x) + +class FixedSegmentationModule(nn.Module): + """Special model modification for segmentation tasks. Gets output from some of the models' forward loops. + + Parameters + ---------- + segnet:nn.Module + Segmentation network + """ + def __init__(self, segnet): + super(FixedSegmentationModule, self).__init__() + self.segnet=segnet + + def forward(self, x): + """Forward pass. + + Parameters + ---------- + x:Tensor + Input + + Returns + ------- + Tensor + Output from model. + + """ + return self.segnet(x)['out'] + +def generate_model(pretrain,architecture,num_classes, add_sigmoid=True, n_hidden=100, segmentation=False): + """Generate a nn.Module for use. + + Parameters + ---------- + pretrain:bool + Pretrain using ImageNet? + architecture:str + See model_training for list of all architectures you can train with. + num_classes:int + Number of classes to predict. + add_sigmoid:type + Add sigmoid non-linearity at end. + n_hidden:int + Number of hidden fully connected layers. + segmentation:bool + Whether segment task? + + Returns + ------- + nn.Module + Pytorch model. + + """ + # to add: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/models/model_zoo.py + #architecture = 'resnet' + str(num_layers) + model = None + + if architecture =='unet': + model = UNet(n_channels=3, n_classes=num_classes) + elif architecture =='unet2': + print('Deprecated for now, defaulting to UNET.') + model = UNet(n_channels=3, n_classes=num_classes)#UNet2(3,num_classes) + elif architecture == 'fast_scnn': + model = get_fast_scnn(num_classes) + elif architecture == 'nested_unet': + print('Nested UNET is deprecated for now, defaulting to UNET.') + model = UNet(n_channels=3, n_classes=num_classes)#NestedUNet(3, num_classes) + elif architecture.startswith('efficientnet'): + from efficientnet_pytorch import EfficientNet + if pretrain: + model = EfficientNet.from_pretrained(architecture, override_params=dict(num_classes=num_classes)) + else: + model = EfficientNet.from_name(architecture, override_params=dict(num_classes=num_classes)) + print(model) + elif architecture.startswith('sqnxt'): + from pytorchcv.model_provider import get_model as ptcv_get_model + model = ptcv_get_model(architecture, pretrained=pretrain) + num_ftrs=int(128*int(architecture.split('_')[-1][1])) + model.output=MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp + else: + #for pretrained on imagenet + model_names = [m for m in dir(models) if not m.startswith('__')] + segmentation_model_names = [m for m in dir(segmodels) if not m.startswith('__')] + if architecture in model_names: + model = getattr(models, architecture)(pretrained=pretrain) + if segmentation: + if architecture in segmentation_model_names: + model = getattr(segmodels, architecture)(pretrained=pretrain) + else: + model = UNet(n_channels=3, n_classes=num_classes) + if architecture.startswith('deeplab'): + model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=(1, 1), stride=(1, 1)) + model = FixedSegmentationModule(model) + elif architecture.startswith('fcn'): + model.classifier[4] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + model = FixedSegmentationModule(model) + elif architecture.startswith('resnet') or architecture.startswith('inception'): + num_ftrs = model.fc.in_features + #linear_layer = nn.Linear(num_ftrs, num_classes) + #torch.nn.init.xavier_uniform(linear_layer.weight) + model.fc = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else []))) + elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'): + num_ftrs = model.classifier[6].in_features + #linear_layer = nn.Linear(num_ftrs, num_classes) + #torch.nn.init.xavier_uniform(linear_layer.weight) + model.classifier[6] = MLP(num_ftrs, [1000], dropout_p=0., n_outputs=num_classes, binary=add_sigmoid, softmax=False).mlp#nn.Sequential(*([linear_layer]+([nn.Sigmoid()] if (add_sigmoid) else []))) + return model + +#@pysnooper.snoop("dice_loss.log") +def dice_loss(logits, true, eps=1e-7): + """https://github.com/kevinzakka/pytorch-goodies + Computes the Sørensen–Dice loss. + + Note that PyTorch optimizers minimize a loss. In this + case, we would like to maximize the dice loss so we + return the negated dice loss. + + Args: + true: a tensor of shape [B, 1, H, W]. + logits: a tensor of shape [B, C, H, W]. Corresponds to + the raw output or logits of the model. + eps: added to the denominator for numerical stability. + + Returns: + dice_loss: the Sørensen–Dice loss. + """ + #true=true.long() + num_classes = logits.shape[1] + if num_classes == 1: + true_1_hot = torch.eye(num_classes + 1)[true.squeeze(1)] + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + true_1_hot_f = true_1_hot[:, 0:1, :, :] + true_1_hot_s = true_1_hot[:, 1:2, :, :] + true_1_hot = torch.cat([true_1_hot_s, true_1_hot_f], dim=1) + pos_prob = torch.sigmoid(logits) + neg_prob = 1 - pos_prob + probas = torch.cat([pos_prob, neg_prob], dim=1) + else: + true_1_hot = torch.eye(num_classes)[true.squeeze(1)] + #print(true_1_hot.size()) + true_1_hot = true_1_hot.permute(0, 3, 1, 2).float() + probas = F.softmax(logits, dim=1) + true_1_hot = true_1_hot.type(logits.type()) + dims = (0,) + tuple(range(2, true.ndimension())) + intersection = torch.sum(probas * true_1_hot, dims) + cardinality = torch.sum(probas + true_1_hot, dims) + dice_loss = (2. * intersection / (cardinality + eps)).mean() + return (1 - dice_loss) + +class ModelTrainer: + """Trainer for the neural network model that wraps it into a scikit-learn like interface. + + Parameters + ---------- + model:nn.Module + Deep learning pytorch model. + n_epoch:int + Number training epochs. + validation_dataloader:DataLoader + Dataloader of validation dataset. + optimizer_opts:dict + Options for optimizer. + scheduler_opts:dict + Options for learning rate scheduler. + loss_fn:str + String to call a particular loss function for model. + reduction:str + Mean or sum reduction of loss. + num_train_batches:int + Number of training batches for epoch. + """ + def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam',lr=1e-3,weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts',lr_scheduler_decay=0.5,T_max=10,eta_min=5e-8,T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, seg_out_class=-1, apex_opt_level="O2", checkpointing=False): + + self.model = model + optimizers = {'adam':torch.optim.Adam, 'sgd':torch.optim.SGD} + loss_functions = {'bce':nn.BCEWithLogitsLoss(reduction=reduction), 'ce':nn.CrossEntropyLoss(reduction=reduction), 'mse':nn.MSELoss(reduction=reduction), 'nll':nn.NLLLoss(reduction=reduction), 'dice':dice_loss, 'focal':FocalLoss(num_class=2), 'gdl':GeneralizedDiceLoss(add_softmax=True)} + loss_functions['dice+ce']=(lambda y_pred, y_true: dice_loss(y_pred,y_true)+loss_functions['ce'](y_pred,y_true)) + if 'name' not in list(optimizer_opts.keys()): + optimizer_opts['name']='adam' + self.optimizer = optimizers[optimizer_opts.pop('name')](self.model.parameters(),**optimizer_opts) + if torch.cuda.is_available(): + self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=apex_opt_level) + self.cuda=True + else: + self.cuda=False + self.scheduler = Scheduler(optimizer=self.optimizer,opts=scheduler_opts) + self.n_epoch = n_epoch + self.validation_dataloader = validation_dataloader + self.loss_fn = loss_functions[loss_fn] + self.loss_fn_name = loss_fn + self.bce=(self.loss_fn_name=='bce' or self.validation_dataloader.dataset.mt_bce) + self.sigmoid = nn.Sigmoid() + self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn]) + self.num_train_batches = num_train_batches + self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn]) + self.seg_out_class=seg_out_class + self.checkpointing=checkpointing + self.checkpoint_dir='./checkpoints' + if self.checkpointing: + os.makedirs(self.checkpoint_dir,exist_ok=True) + + def save_model(self, model=None, epoch=0): + torch.save((model if isinstance(model,type(None)) else self.model).state_dict(),os.path.join(self.checkpoint_dir,f'checkpoint.{epoch}.pth')) + + def calc_loss(self, y_pred, y_true): + """Calculates loss supplied in init statement and modified by reweighting. + + Parameters + ---------- + y_pred:tensor + Predictions. + y_true:tensor + True values. + + Returns + ------- + loss + + """ + + return self.loss_fn(y_pred, y_true) + + def calc_val_loss(self, y_pred, y_true): + """Calculates loss supplied in init statement on validation set. + + Parameters + ---------- + y_pred:tensor + Predictions. + y_true:tensor + True values. + + Returns + ------- + val_loss + + """ + + return self.val_loss_fn(y_pred, y_true) + + def reset_loss_fn(self): + """Resets loss to original specified loss.""" + self.loss_fn = self.original_loss_fn + + def add_class_balance_loss(self, dataset, custom_weights=''): + """Updates loss function to handle class imbalance by weighting inverse to class appearance. + + Parameters + ---------- + dataset:DynamicImageDataset + Dataset to balance by. + + """ + self.class_weights = dataset.get_class_weights() if not custom_weights else np.array(list(map(float,custom_weights.split(',')))) + if custom_weights: + self.class_weights=self.class_weights/sum(self.class_weights) + print('Weights:',self.class_weights) + self.original_loss_fn = copy.deepcopy(self.loss_fn) + weight=torch.tensor(self.class_weights,dtype=torch.float) + if torch.cuda.is_available(): + weight=weight.cuda() + if self.loss_fn_name=='ce': + self.loss_fn = nn.CrossEntropyLoss(weight=weight) + elif self.loss_fn_name=='nll': + self.loss_fn = nn.NLLLoss(weight=weight) + else: # modify below for multi-target + self.loss_fn = lambda y_pred,y_true: sum([self.class_weights[i]*self.original_loss_fn(y_pred[y_true==i],y_true[y_true==i]) if sum(y_true==i) else 0. for i in range(2)]) + + def calc_best_confusion(self, y_pred, y_true): + """Calculate confusion matrix on validation set for classification/segmentation tasks, optimize threshold where positive. + + Parameters + ---------- + y_pred:array + Predictions. + y_true:array + Ground truth. + + Returns + ------- + float + Optimized threshold to use on test set. + dataframe + Confusion matrix. + + """ + fpr, tpr, thresholds = roc_curve(y_true, y_pred) + threshold=thresholds[np.argmin(np.sum((np.array([0,1])-np.vstack((fpr, tpr)).T)**2,axis=1)**.5)] + y_pred = (y_pred>threshold).astype(int) + return threshold, pd.DataFrame(confusion_matrix(y_true,y_pred),index=['F','T'],columns=['-','+']).iloc[::-1,::-1].T + + def loss_backward(self,loss): + """Backprop using mixed precision for added speed boost. + + Parameters + ---------- + loss:loss + Torch loss calculated. + + """ + if self.cuda: + with amp.scale_loss(loss,self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + # @pysnooper.snoop('train_loop.log') + def train_loop(self, epoch, train_dataloader): + """One training epoch, calculate predictions, loss, backpropagate. + + Parameters + ---------- + epoch:int + Current epoch. + train_dataloader:DataLoader + Training data. + + Returns + ------- + float + Training loss for epoch + + """ + self.model.train(True) + running_loss = 0. + n_batch = len(train_dataloader.dataset)//train_dataloader.batch_size if self.num_train_batches == None else self.num_train_batches + for i, batch in enumerate(train_dataloader): + starttime=time.time() + if i == n_batch: + break + X = Variable(batch[0], requires_grad=True) + y_true = Variable(batch[1]) + if not train_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1: + y_true=y_true.argmax(1).long() + if train_dataloader.dataset.segmentation and self.loss_fn_name!='dice': + y_true=y_true.squeeze(1) + if torch.cuda.is_available(): + X = X.cuda() + y_true=y_true.cuda() + y_pred = self.model(X) + #sizes=(y_pred.size(),y_true.size()) + #print(y_true) + loss = self.calc_loss(y_pred,y_true) + train_loss=loss.item() + running_loss += train_loss + self.optimizer.zero_grad() + self.loss_backward(loss)#loss.backward() + self.optimizer.step() + endtime=time.time() + print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch,i,n_batch,round(endtime-starttime,3),train_loss)) + self.scheduler.step() + running_loss/=n_batch + return running_loss + + def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True): + """Calculate loss over validation set. + + Parameters + ---------- + epoch:int + Current epoch. + val_dataloader:DataLoader + Validation iterator. + print_val_confusion:bool + Calculate confusion matrix and plot. + save_predictions:int + Print validation results. + + Returns + ------- + float + Validation loss for epoch. + """ + self.model.train(False) + n_batch = len(val_dataloader.dataset)//val_dataloader.batch_size + running_loss = 0. + Y = {'pred':[],'true':[]} + with torch.no_grad(): + for i, batch in enumerate(val_dataloader): + X = Variable(batch[0],requires_grad=False) + y_true = Variable(batch[1]) + if not val_dataloader.dataset.segmentation and self.loss_fn_name=='ce' and y_true.shape[1]>1: + y_true=y_true.argmax(1).long() + if val_dataloader.dataset.segmentation and self.loss_fn_name!='dice': + y_true=y_true.squeeze(1) + if torch.cuda.is_available(): + X = X.cuda() + y_true=y_true.cuda() + y_pred = self.model(X) + if save_predictions: + if val_dataloader.dataset.segmentation: + Y['true'].append(torch.flatten(y_true if not val_dataloader.dataset.gdl else y_true).detach().cpu().numpy().astype(int).flatten()) # .argmax(axis=1) + Y['pred'].append((y_pred.detach().cpu().numpy().argmax(axis=1)).astype(int).flatten()) + else: + Y['true'].append(y_true.detach().cpu().numpy().astype(int).flatten()) + y_pred_numpy=((y_pred if not self.bce else self.sigmoid(y_pred)).detach().cpu().numpy()).astype(float) + if len(y_pred_numpy)>1 and y_pred_numpy.shape[1]>1 and not val_dataloader.dataset.mt_bce: + y_pred_numpy=y_pred_numpy.argmax(axis=1) + Y['pred'].append(y_pred_numpy.flatten()) + loss = self.calc_val_loss(y_pred,y_true) + val_loss=loss.item() + running_loss += val_loss + print("Epoch {}[{}/{}] Val Loss:{}".format(epoch,i,n_batch,val_loss)) + if print_val_confusion and save_predictions: + y_pred,y_true = np.hstack(Y['pred']),np.hstack(Y['true']) + if not val_dataloader.dataset.segmentation: + if self.loss_fn_name in ['bce','mse'] and not val_dataloader.dataset.mt_bce: + threshold, best_confusion = self.calc_best_confusion(y_pred,y_true) + print("Epoch {} Val Confusion, Threshold {}:".format(epoch,threshold)) + print(best_confusion) + y_true = y_true.astype(int) + y_pred = (y_pred>=threshold).astype(int) + elif val_dataloader.dataset.mt_bce: + n_targets = len(val_dataloader.dataset.targets) + y_pred=y_pred[y_true>0] + y_true=y_true[y_true>0] + y_true=y_true[np.isnan(y_pred)==False] + y_pred=y_pred[np.isnan(y_pred)==False] + if 0 and n_targets > 1: + n_row=len(y_true)/n_targets + y_pred=y_pred.reshape(int(n_row),n_targets) + y_true=y_true.reshape(int(n_row),n_targets) + print("Epoch {} Val Regression, R2 Score {}".format(epoch, str(r2_score(y_true, y_pred)))) + else: + print(classification_report(y_true,y_pred)) + + running_loss/=n_batch + return running_loss + + #@pysnooper.snoop("test_loop.log") + def test_loop(self, test_dataloader): + """Calculate final predictions on loss. + + Parameters + ---------- + test_dataloader:DataLoader + Test dataset. + + Returns + ------- + array + Predictions or embeddings. + """ + #self.model.train(False) KEEP DROPOUT? and BATCH NORM?? + y_pred = [] + running_loss = 0. + with torch.no_grad(): + for i, (X,y_test) in enumerate(test_dataloader): + #X = Variable(batch[0],requires_grad=False) + if torch.cuda.is_available(): + X = X.cuda() + if test_dataloader.dataset.segmentation: + prediction=self.model(X).detach().cpu().numpy() + if self.seg_out_class>=0: + prediction=prediction[:,self.seg_out_class,...] + else: + prediction=prediction.argmax(axis=1).astype(int) + pred_size=prediction.shape#size() + #pred_mean=prediction[0].mean(axis=0) + y_pred.append(prediction) + else: + prediction=self.model(X) + if self.loss_fn_name != 'mse' and ((len(test_dataloader.dataset.targets)-1) or self.bce): + prediction=self.sigmoid(prediction) + elif test_dataloader.dataset.classify_annotations: + prediction=F.softmax(prediction,dim=1) + y_pred.append(prediction.detach().cpu().numpy()) + y_pred = np.concatenate(y_pred,axis=0)#torch.cat(y_pred,0) + + return y_pred + + def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=False, plot_save_file=None, print_val_confusion=True, save_val_predictions=True): + """Fits the segmentation or classification model to the patches, saving the model with the lowest validation score. + + Parameters + ---------- + train_dataloader:DataLoader + Training dataset. + verbose:bool + Print training and validation loss? + print_every:int + Number of epochs until print? + save_model:bool + Whether to save model when reaching lowest validation loss. + plot_training_curves:bool + Plot training curves over epochs. + plot_save_file:str + File to save training curves. + print_val_confusion:bool + Print validation confusion matrix. + save_val_predictions:bool + Print validation results. + + Returns + ------- + self + Trainer. + float + Minimum val loss. + int + Best validation epoch with lowest loss. + + """ + # choose model with best f1 + self.train_losses = [] + self.val_losses = [] + for epoch in range(self.n_epoch): + start_time=time.time() + train_loss = self.train_loop(epoch,train_dataloader) + current_time=time.time() + train_time=current_time-start_time + self.train_losses.append(train_loss) + val_loss = self.val_loop(epoch,self.validation_dataloader, print_val_confusion=print_val_confusion, save_predictions=save_val_predictions) + val_time=time.time()-current_time + self.val_losses.append(val_loss) + if verbose and not (epoch % print_every): + if plot_training_curves: + self.plot_train_val_curves(plot_save_file) + print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(epoch,train_loss,val_loss,train_time,val_time)) + if val_loss <= min(self.val_losses) and save_model: + min_val_loss = val_loss + best_epoch = epoch + best_model = copy.deepcopy(self.model) + if self.checkpointing: + self.save_model(best_model,epoch) + if save_model: + self.model = best_model + return self, min_val_loss, best_epoch + + def plot_train_val_curves(self, save_file=None): + """Plots training and validation curves. + + Parameters + ---------- + save_file:str + File to save to. + + """ + plt.figure() + sns.lineplot('epoch','value',hue='variable', + data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)),self.train_losses,self.val_losses)).T, + columns=['epoch','train','val']).melt(id_vars=['epoch'],value_vars=['train','val'])) + if save_file is not None: + plt.savefig(save_file, dpi=300) + + def predict(self, test_dataloader): + """Make classification segmentation predictions on testing data. + + Parameters + ---------- + test_dataloader:DataLoader + Test data. + + Returns + ------- + array + Predictions. + + """ + y_pred = self.test_loop(test_dataloader) + return y_pred + + def fit_predict(self, train_dataloader, test_dataloader): + """Fit model to training data and make classification segmentation predictions on testing data. + + Parameters + ---------- + train_dataloader:DataLoader + Train data. + test_dataloader:DataLoader + Test data. + + Returns + ------- + array + Predictions. + + """ + return self.fit(train_dataloader)[0].predict(test_dataloader) + + def return_model(self): + """Returns pytorch model. + """ + return self.model