--- /dev/null
+++ b/run_experiment.py
@@ -0,0 +1,330 @@
+#run_experiment.py
+#Copyright (c) 2020 Rachel Lea Ballantyne Draelos
+
+#MIT License
+
+#Permission is hereby granted, free of charge, to any person obtaining a copy
+#of this software and associated documentation files (the "Software"), to deal
+#in the Software without restriction, including without limitation the rights
+#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+#copies of the Software, and to permit persons to whom the Software is
+#furnished to do so, subject to the following conditions:
+
+#The above copyright notice and this permission notice shall be included in all
+#copies or substantial portions of the Software.
+
+#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+#SOFTWARE.
+
+import os
+import timeit
+import datetime
+import numpy as np
+import pandas as pd
+
+from torch.utils.data import Dataset, DataLoader
+import torch, torch.nn as nn, torch.nn.functional as F
+import torchvision
+from torchvision import transforms, models, utils
+
+import evaluate
+from load_dataset import custom_datasets
+
+#Set seeds
+np.random.seed(0)
+torch.manual_seed(0)
+torch.cuda.manual_seed(0)
+torch.cuda.manual_seed_all(0)
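+
+#Illustrative usage (a sketch only: AxialNet, CTDataset, and all argument
+#values below are hypothetical placeholders, not defaults defined in this
+#file; dataset_args must supply at least the 'label_meanings',
+#'label_type_ld', and 'crop_type' keys that the methods below read):
+#    DukeCTModel(descriptor='axialnet_experiment', custom_net=AxialNet,
+#        custom_net_args={'n_outputs':80}, loss='bce', loss_args={},
+#        num_epochs=100, patience=15, batch_size=2, device=0,
+#        data_parallel=False, use_test_set=False, task='train_eval',
+#        old_params_dir='', dataset_class=CTDataset,
+#        dataset_args={'label_meanings':'all', 'label_type_ld':'disease',
+#                      'crop_type':'single'})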
+
+class DukeCTModel(object):
+    def __init__(self, descriptor, custom_net, custom_net_args,
+                 loss, loss_args, num_epochs, patience, batch_size, device, data_parallel,
+                 use_test_set, task, old_params_dir, dataset_class, dataset_args):
+        """Variables:
+        <descriptor>: string describing the experiment
+        <custom_net>: class defining a model
+        <custom_net_args>: dictionary where keys correspond to custom net
+            input arguments, and values are the desired values
+        <loss>: 'bce' for binary cross entropy
+        <loss_args>: arguments to pass to the loss function if any
+        <num_epochs>: int for the maximum number of epochs to train
+        <patience>: number of epochs for which loss must fail to improve to
+            cause early stopping
+        <batch_size>: int for number of examples per batch
+        <device>: int specifying which device to use, or 'all' for all devices
+        <data_parallel>: if True then parallelize across available GPUs.
+        <use_test_set>: if True, then run model on the test set. If False, use
+            only the training and validation sets.
+        <task>:
+            'train_eval': train and evaluate a new model. 'evaluate' will
+                always imply use of the validation set. If <use_test_set> is
+                True, then 'evaluate' also includes calculation of test set
+                performance for the best validation epoch.
+            'predict_on_test': load a trained model and make predictions on
+                the test set using that model.
+        <old_params_dir>: this is only needed if <task>=='predict_on_test'. This
+            is the path to the parameters that will be loaded into the model.
+        <dataset_class>: CT Dataset class for preprocessing the data
+        <dataset_args>: arguments for the dataset class specifying how
+            the data should be prepared."""
+        self.descriptor = descriptor
+        self.set_up_results_dirs()
+        self.custom_net = custom_net
+        self.custom_net_args = custom_net_args
+        self.loss = loss
+        self.loss_args = loss_args
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+        print('self.batch_size=',self.batch_size)
+        #num_workers is number of threads to use for data loading
+        self.num_workers = int(batch_size*4) #batch_size 1 = num_workers 4. batch_size 2 = num_workers 8. batch_size 4 = num_workers 16.
+        print('self.num_workers=',self.num_workers)
+        if self.num_workers == 1:
+            print('Warning: Using only one worker will slow down data loading')
+
+        #Set Device and Data Parallelism
+        if device in [0,1,2,3]: #i.e. if a GPU number was specified:
+            self.device = torch.device('cuda:'+str(device))
+            print('using device:',str(self.device),'\ndescriptor: ',self.descriptor)
+        elif device == 'all':
+            self.device = torch.device('cuda')
+        self.data_parallel = data_parallel
+        if self.data_parallel:
+            assert device == 'all' #use all devices when running data parallel
+
+        #Set Task
+        self.use_test_set = use_test_set
+        self.task = task
+        assert self.task in ['train_eval','predict_on_test']
+        if self.task == 'predict_on_test':
+            #overwrite the params dir that was created in the call to
+            #set_up_results_dirs() with the dir you want to load from
+            self.params_dir = old_params_dir
+
+        #Data and Labels
+        self.CTDatasetClass = dataset_class
+        self.dataset_args = dataset_args
+        #Get label meanings, a list of descriptive strings (list elements must
+        #be strings found in the column headers of the labels file)
+        self.set_up_label_meanings(self.dataset_args['label_meanings'])
+        if self.task == 'train_eval':
+            self.dataset_train = self.CTDatasetClass(setname = 'train', **self.dataset_args)
+            self.dataset_valid = self.CTDatasetClass(setname = 'valid', **self.dataset_args)
+        if self.use_test_set:
+            self.dataset_test = self.CTDatasetClass(setname = 'test', **self.dataset_args)
+
+        #Tracking losses and evaluation results
+        self.train_loss = np.zeros((self.num_epochs))
+        self.valid_loss = np.zeros((self.num_epochs))
+        self.eval_results_valid, self.eval_results_test = evaluate.initialize_evaluation_dfs(self.label_meanings, self.num_epochs)
+
+        #For early stopping
+        self.initial_patience = patience
+        self.patience_remaining = patience
+        self.best_valid_epoch = 0
+        self.min_val_loss = np.inf
+
+        #Run everything
+        self.run_model()
+
+    ### Methods ###
+    def set_up_label_meanings(self,label_meanings):
+        if label_meanings == 'all': #get full list of all available labels
+            temp = custom_datasets.read_in_labels(self.dataset_args['label_type_ld'], 'valid')
+            self.label_meanings = temp.columns.values.tolist()
+        else: #use the label meanings that were passed in
+            self.label_meanings = label_meanings
+        print('label meanings ('+str(len(self.label_meanings))+' labels total):',self.label_meanings)
+
+    def set_up_results_dirs(self):
+        if not os.path.isdir('results'):
+            os.mkdir('results')
+        self.results_dir = os.path.join('results',datetime.datetime.today().strftime('%Y-%m-%d')+'_'+self.descriptor)
+        if not os.path.isdir(self.results_dir):
+            os.mkdir(self.results_dir)
+        self.params_dir = os.path.join(self.results_dir,'params')
+        if not os.path.isdir(self.params_dir):
+            os.mkdir(self.params_dir)
+        self.backup_dir = os.path.join(self.results_dir,'backup')
+        if not os.path.isdir(self.backup_dir):
+            os.mkdir(self.backup_dir)
+
+    def run_model(self):
+        if self.data_parallel:
+            self.model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device)
+        else:
+            self.model = self.custom_net(**self.custom_net_args).to(self.device)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.set_up_loss_function()
+
+        momentum = 0.99
+        print('Running with optimizer lr=1e-3, momentum='+str(round(momentum,2))+' and weight_decay=1e-7')
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = 1e-3, momentum=momentum, weight_decay=1e-7)
+
+        if self.task == 'train_eval':
+            #the train and valid datasets are only defined for the 'train_eval'
+            #task, so their dataloaders are created inside this branch
+            train_dataloader = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers = self.num_workers)
+            valid_dataloader = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=False, num_workers = self.num_workers)
+            for epoch in range(self.num_epochs):
+                t0 = timeit.default_timer()
+                self.train(train_dataloader, epoch)
+                self.valid(valid_dataloader, epoch)
+                self.save_evals(epoch)
+                if self.patience_remaining <= 0:
+                    print('No more patience (',self.initial_patience,') left at epoch',epoch)
+                    print('--> Implementing early stopping. Best epoch was:',self.best_valid_epoch)
+                    break
+                t1 = timeit.default_timer()
+                self.back_up_model_every_ten(epoch)
+                print('Epoch',epoch,'time:',round((t1 - t0)/60.0,2),'minutes')
+        if self.use_test_set: self.test(DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers = self.num_workers))
+        self.save_final_summary()
+
+    def set_up_loss_function(self):
+        if self.loss == 'bce':
+            self.loss_func = nn.BCEWithLogitsLoss() #includes application of sigmoid for numerical stability
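+
+    #For reference: for logits x and targets y, nn.BCEWithLogitsLoss()(x, y)
+    #computes the same quantity as nn.BCELoss()(torch.sigmoid(x), y), but in a
+    #single numerically stable operation. The model output is therefore passed
+    #to the loss as raw logits; self.sigmoid is applied separately, in
+    #iterate_through_batches(), only when storing predicted probabilities.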
+
+    def train(self, dataloader, epoch):
+        model = self.model.train()
+        epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=True)
+        self.train_loss[epoch] = epoch_loss
+        self.plot_roc_and_pr_curves('train', epoch, pred_epoch, gr_truth_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Train Loss', epoch_loss))
+
+    def valid(self, dataloader, epoch):
+        model = self.model.eval()
+        with torch.no_grad():
+            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
+        self.valid_loss[epoch] = epoch_loss
+        self.eval_results_valid = evaluate.evaluate_all(self.eval_results_valid, epoch,
+            self.label_meanings, gr_truth_epoch, pred_epoch)
+        self.early_stopping_check(epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Valid Loss', epoch_loss))
+
+    def early_stopping_check(self, epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch):
+        """Check whether criteria for early stopping are met and update
+        counters accordingly"""
+        val_loss = self.valid_loss[epoch]
+        if (val_loss < self.min_val_loss) or epoch==0: #then save parameters
+            self.min_val_loss = val_loss
+            check_point = {'params': self.model.state_dict(),
+                           'optimizer': self.optimizer.state_dict()}
+            torch.save(check_point, os.path.join(self.params_dir, self.descriptor))
+            self.best_valid_epoch = epoch
+            self.patience_remaining = self.initial_patience
+            print('model saved, val loss',val_loss)
+            self.plot_roc_and_pr_curves('valid', epoch, val_pred_epoch, val_gr_truth_epoch)
+            self.save_all_pred_probs('valid', epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch)
+        else:
+            self.patience_remaining -= 1
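+
+    #Worked example of the patience logic (numbers illustrative): with
+    #patience=15, if the validation loss last improved at epoch 20, then
+    #patience_remaining is decremented on each of epochs 21 through 35,
+    #reaching 0 at epoch 35; run_model() then breaks out of the training loop,
+    #and the parameters saved at epoch 20 (best_valid_epoch) are the ones
+    #loaded for any test set predictions.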
+
+    def back_up_model_every_ten(self, epoch):
+        """Back up the model parameters every 10 epochs"""
+        if epoch % 10 == 0:
+            check_point = {'params': self.model.state_dict(),
+                           'optimizer': self.optimizer.state_dict()}
+            torch.save(check_point, os.path.join(self.backup_dir, self.descriptor+'_ep_'+str(epoch)))
+
+    def test(self, dataloader):
+        epoch = self.best_valid_epoch
+        if self.data_parallel:
+            model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device).eval()
+        else:
+            model = self.custom_net(**self.custom_net_args).to(self.device).eval()
+        params_path = os.path.join(self.params_dir,self.descriptor)
+        print('For test set predictions, loading model params from params_path=',params_path)
+        check_point = torch.load(params_path)
+        model.load_state_dict(check_point['params'])
+        with torch.no_grad():
+            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
+        self.eval_results_test = evaluate.evaluate_all(self.eval_results_test, epoch,
+            self.label_meanings, gr_truth_epoch, pred_epoch)
+        self.plot_roc_and_pr_curves('test', epoch, pred_epoch, gr_truth_epoch)
+        self.save_all_pred_probs('test', epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Test Loss', epoch_loss))
+
+    def iterate_through_batches(self, model, dataloader, epoch, training):
+        epoch_loss = 0
+
+        #Initialize numpy arrays for storing results. examples x labels
+        #Do NOT use concatenation, or else you will have memory fragmentation.
+        num_examples = len(dataloader.dataset)
+        num_labels = len(self.label_meanings)
+        pred_epoch = np.zeros([num_examples,num_labels])
+        gr_truth_epoch = np.zeros([num_examples,num_labels])
+        volume_accs_epoch = np.empty(num_examples,dtype='U32') #need to use U32 to allow strings of length 32
+
+        for batch_idx, batch in enumerate(dataloader):
+            data, gr_truth = self.move_data_to_device(batch)
+            self.optimizer.zero_grad()
+            if training:
+                out = model(data)
+            else:
+                with torch.set_grad_enabled(False):
+                    out = model(data)
+            loss = self.loss_func(out, gr_truth)
+            if training:
+                loss.backward()
+                self.optimizer.step()
+
+            epoch_loss += loss.item()
+            torch.cuda.empty_cache()
+
+            #Save predictions and ground truth across batches
+            pred = self.sigmoid(out).detach().cpu().numpy()
+            gr_truth = gr_truth.detach().cpu().numpy()
+
+            start_row = batch_idx*self.batch_size
+            stop_row = min(start_row + self.batch_size, num_examples)
+            pred_epoch[start_row:stop_row,:] = pred #pred_epoch is e.g. [25355,80] and pred is e.g. [1,80] for a batch size of 1
+            gr_truth_epoch[start_row:stop_row,:] = gr_truth #gr_truth_epoch has the same shape as pred_epoch
+            volume_accs_epoch[start_row:stop_row] = batch['volume_acc'] #volume_accs_epoch stores the volume accessions in the order they were used
+
+            #the following line to empty the cache is necessary in order to
+            #reduce memory usage and avoid an OOM error:
+            torch.cuda.empty_cache()
+        return epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch
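+
+    #Row bookkeeping in iterate_through_batches(), illustrated with small
+    #numbers: with batch_size=2 and num_examples=5, successive batches fill
+    #rows [0:2), [2:4), and [4:5) of the preallocated arrays; the min() in
+    #stop_row truncates the final slice when the last batch is partial.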
+
+    def move_data_to_device(self, batch):
+        """Move data and ground truth to device."""
+        assert self.dataset_args['crop_type'] == 'single'
+        if self.dataset_args['crop_type'] == 'single':
+            data = batch['data'].to(self.device)
+
+        #Ground truth to device
+        gr_truth = batch['gr_truth'].to(self.device)
+        return data, gr_truth
+
+    def plot_roc_and_pr_curves(self, setname, epoch, pred_epoch, gr_truth_epoch):
+        outdir = os.path.join(self.results_dir,'curves')
+        if not os.path.isdir(outdir):
+            os.mkdir(outdir)
+        evaluate.plot_roc_curve_multi_class(label_meanings=self.label_meanings,
+            y_test=gr_truth_epoch, y_score=pred_epoch,
+            outdir = outdir, setname = setname, epoch = epoch)
+        evaluate.plot_pr_curve_multi_class(label_meanings=self.label_meanings,
+            y_test=gr_truth_epoch, y_score=pred_epoch,
+            outdir = outdir, setname = setname, epoch = epoch)
+
+    def save_all_pred_probs(self, setname, epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch):
+        outdir = os.path.join(self.results_dir,'pred_probs')
+        if not os.path.isdir(outdir):
+            os.mkdir(outdir)
+        (pd.DataFrame(pred_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_predprob_ep'+str(epoch)+'.csv'))
+        (pd.DataFrame(gr_truth_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_grtruth_ep'+str(epoch)+'.csv'))
+
+    def save_evals(self, epoch):
+        evaluate.save(self.eval_results_valid, self.results_dir, self.descriptor+'_valid')
+        if self.use_test_set: evaluate.save(self.eval_results_test, self.results_dir, self.descriptor+'_test')
+        evaluate.plot_learning_curves(self.train_loss, self.valid_loss, self.results_dir, self.descriptor)
+
+    def save_final_summary(self):
+        evaluate.save_final_summary(self.eval_results_valid, self.best_valid_epoch, 'valid', self.results_dir)
+        if self.use_test_set: evaluate.save_final_summary(self.eval_results_test, self.best_valid_epoch, 'test', self.results_dir)
+        evaluate.clean_up_output_files(self.best_valid_epoch, self.results_dir)
\ No newline at end of file