Diff of /run_experiment.py [000000] .. [09c10e]

Switch to side-by-side view

--- a
+++ b/run_experiment.py
@@ -0,0 +1,330 @@
+#run_experiment.py
+#Copyright (c) 2020 Rachel Lea Ballantyne Draelos
+
+#MIT License
+
+#Permission is hereby granted, free of charge, to any person obtaining a copy
+#of this software and associated documentation files (the "Software"), to deal
+#in the Software without restriction, including without limitation the rights
+#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+#copies of the Software, and to permit persons to whom the Software is
+#furnished to do so, subject to the following conditions:
+
+#The above copyright notice and this permission notice shall be included in all
+#copies or substantial portions of the Software.
+
+#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+#SOFTWARE
+
+import os
+import timeit
+import datetime
+import numpy as np
+import pandas as pd
+
+from torch.utils.data import Dataset, DataLoader
+import torch, torch.nn as nn, torch.nn.functional as F
+import torchvision
+from torchvision import transforms, models, utils
+
+import evaluate
+from load_dataset import custom_datasets
+
+#Set seeds
+np.random.seed(0)
+torch.manual_seed(0)
+torch.cuda.manual_seed(0)
+torch.cuda.manual_seed_all(0)
+
+class DukeCTModel(object):
+    def __init__(self, descriptor, custom_net, custom_net_args,
+                 loss, loss_args, num_epochs, patience, batch_size, device, data_parallel,
+                 use_test_set, task, old_params_dir, dataset_class, dataset_args):
+        """Variables:
+        <descriptor>: string describing the experiment
+        <custom_net>: class defining a model
+        <custom_net_args>: dictionary where keys correspond to custom net
+            input arguments, and values are the desired values    
+        <loss>: 'bce' for binary cross entropy
+        <loss_args>: arguments to pass to the loss function if any
+        <num_epochs>: int for the maximum number of epochs to train
+        <patience>: number of epochs for which loss must fail to improve to
+            cause early stopping
+        <batch_size>: int for number of examples per batch
+        <device>: int specifying which device to use, or 'all' for all devices
+        <data_parallel>: if True then parallelize across available GPUs.
+        <use_test_set>: if True, then run model on the test set. If False, use
+            only the training and validation sets.
+        <task>:
+            'train_eval': train and evaluate a new model. 'evaluate' will
+                always imply use of the validation set. if <use_test_set> is
+                True, then 'evaluate' also includes calculation of test set
+                performance for the best validation epoch.
+            'predict_on_test': load a trained model and make predictions on
+                the test set using that model.
+        <old_params_dir>: this is only needed if <task>=='predict_on_test'. This
+            is the path to the parameters that will be loaded in to the model.
+        <dataset_class>: CT Dataset class for preprocessing the data
+        <dataset_args>: arguments for the dataset class specifying how
+            the data should be prepared."""
+        self.descriptor = descriptor
+        self.set_up_results_dirs()
+        self.custom_net = custom_net
+        self.custom_net_args = custom_net_args
+        self.loss = loss
+        self.loss_args = loss_args
+        self.num_epochs = num_epochs
+        self.batch_size = batch_size
+        print('self.batch_size=',self.batch_size)
+        #num_workers is number of threads to use for data loading
+        self.num_workers = int(batch_size*4) #batch_size 1 = num_workers 4. batch_size 2 = num workers 8. batch_size 4 = num_workers 16.
+        print('self.num_workers=',self.num_workers)
+        if self.num_workers == 1:
+            print('Warning: Using only one worker will slow down data loading')
+        
+        #Set Device and Data Parallelism
+        if device in [0,1,2,3]: #i.e. if a GPU number was specified:
+            self.device = torch.device('cuda:'+str(device))
+            print('using device:',str(self.device),'\ndescriptor: ',self.descriptor)
+        elif device == 'all':
+            self.device = torch.device('cuda')
+        self.data_parallel = data_parallel
+        if self.data_parallel:
+            assert device == 'all' #use all devices when running data parallel
+        
+        #Set Task
+        self.use_test_set = use_test_set
+        self.task = task
+        assert self.task in ['train_eval','predict_on_test']
+        if self.task == 'predict_on_test':
+            #overwrite the params dir that was created in the call to
+            #set_up_results_dirs() with the dir you want to load from
+            self.params_dir = old_params_dir
+        
+        #Data and Labels
+        self.CTDatasetClass = dataset_class
+        self.dataset_args = dataset_args
+        #Get label meanings, a list of descriptive strings (list elements must
+        #be strings found in the column headers of the labels file)
+        self.set_up_label_meanings(self.dataset_args['label_meanings'])
+        if self.task == 'train_eval':
+            self.dataset_train = self.CTDatasetClass(setname = 'train', **self.dataset_args)
+            self.dataset_valid = self.CTDatasetClass(setname = 'valid', **self.dataset_args)
+        if self.use_test_set:
+            self.dataset_test = self.CTDatasetClass(setname = 'test', **self.dataset_args)
+        
+        #Tracking losses and evaluation results
+        self.train_loss = np.zeros((self.num_epochs))
+        self.valid_loss = np.zeros((self.num_epochs))
+        self.eval_results_valid, self.eval_results_test = evaluate.initialize_evaluation_dfs(self.label_meanings, self.num_epochs)
+        
+        #For early stopping
+        self.initial_patience = patience
+        self.patience_remaining = patience
+        self.best_valid_epoch = 0
+        self.min_val_loss = np.inf
+        
+        #Run everything
+        self.run_model()
+    
+    ### Methods ###
+    def set_up_label_meanings(self,label_meanings):
+        if label_meanings == 'all': #get full list of all available labels
+            temp = custom_datasets.read_in_labels(self.dataset_args['label_type_ld'], 'valid')
+            self.label_meanings = temp.columns.values.tolist()
+        else: #use the label meanings that were passed in
+            self.label_meanings = label_meanings
+        print('label meanings ('+str(len(self.label_meanings))+' labels total):',self.label_meanings)
+        
+    def set_up_results_dirs(self):
+        if not os.path.isdir('results'):
+            os.mkdir('results')
+        self.results_dir = os.path.join('results',datetime.datetime.today().strftime('%Y-%m-%d')+'_'+self.descriptor)
+        if not os.path.isdir(self.results_dir):
+            os.mkdir(self.results_dir)
+        self.params_dir = os.path.join(self.results_dir,'params')
+        if not os.path.isdir(self.params_dir):
+            os.mkdir(self.params_dir)
+        self.backup_dir = os.path.join(self.results_dir,'backup')
+        if not os.path.isdir(self.backup_dir):
+            os.mkdir(self.backup_dir)
+        
+    def run_model(self):
+        if self.data_parallel:
+            self.model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device)
+        else:
+            self.model = self.custom_net(**self.custom_net_args).to(self.device)
+        self.sigmoid = torch.nn.Sigmoid()
+        self.set_up_loss_function()
+        
+        momentum = 0.99
+        print('Running with optimizer lr=1e-3, momentum='+str(round(momentum,2))+' and weight_decay=1e-7')
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr = 1e-3, momentum=momentum, weight_decay=1e-7)
+        
+        train_dataloader = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers = self.num_workers)
+        valid_dataloader = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=False, num_workers = self.num_workers)
+        
+        if self.task == 'train_eval':
+            for epoch in range(self.num_epochs):
+                t0 = timeit.default_timer()
+                self.train(train_dataloader, epoch)
+                self.valid(valid_dataloader, epoch)
+                self.save_evals(epoch)
+                if self.patience_remaining <= 0:
+                    print('No more patience (',self.initial_patience,') left at epoch',epoch)
+                    print('--> Implementing early stopping. Best epoch was:',self.best_valid_epoch)
+                    break
+                t1 = timeit.default_timer()
+                self.back_up_model_every_ten(epoch)
+                print('Epoch',epoch,'time:',round((t1 - t0)/60.0,2),'minutes')  
+        if self.use_test_set: self.test(DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers = self.num_workers))
+        self.save_final_summary()
+    
+    def set_up_loss_function(self):
+        if self.loss == 'bce': 
+            self.loss_func = nn.BCEWithLogitsLoss() #includes application of sigmoid for numerical stability
+    
+    def train(self, dataloader, epoch):
+        model = self.model.train()
+        epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=True)
+        self.train_loss[epoch] = epoch_loss
+        self.plot_roc_and_pr_curves('train', epoch, pred_epoch, gr_truth_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Train Loss', epoch_loss))
+        
+    def valid(self, dataloader, epoch):
+        model = self.model.eval()
+        with torch.no_grad():
+            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
+        self.valid_loss[epoch] = epoch_loss
+        self.eval_results_valid = evaluate.evaluate_all(self.eval_results_valid, epoch,
+            self.label_meanings, gr_truth_epoch, pred_epoch)
+        self.early_stopping_check(epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Valid Loss', epoch_loss))
+    
+    def early_stopping_check(self, epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch):
+        """Check whether criteria for early stopping are met and update
+        counters accordingly"""
+        val_loss = self.valid_loss[epoch]
+        if (val_loss < self.min_val_loss) or epoch==0: #then save parameters
+            self.min_val_loss = val_loss
+            check_point = {'params': self.model.state_dict(),                            
+                           'optimizer': self.optimizer.state_dict()}
+            torch.save(check_point, os.path.join(self.params_dir, self.descriptor))                                 
+            self.best_valid_epoch = epoch
+            self.patience_remaining = self.initial_patience
+            print('model saved, val loss',val_loss)
+            self.plot_roc_and_pr_curves('valid', epoch, val_pred_epoch, val_gr_truth_epoch)
+            self.save_all_pred_probs('valid', epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch)
+        else:
+            self.patience_remaining -= 1
+    
+    def back_up_model_every_ten(self, epoch):
+        """Back up the model parameters every 10 epochs"""
+        if epoch % 10 == 0:
+            check_point = {'params': self.model.state_dict(),                            
+                           'optimizer': self.optimizer.state_dict()}
+            torch.save(check_point, os.path.join(self.backup_dir, self.descriptor+'_ep_'+str(epoch)))   
+    
+    def test(self, dataloader):
+        epoch = self.best_valid_epoch
+        if self.data_parallel:
+            model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device).eval()
+        else:
+            model = self.custom_net(**self.custom_net_args).to(self.device).eval()
+        params_path = os.path.join(self.params_dir,self.descriptor)
+        print('For test set predictions, loading model params from params_path=',params_path)
+        check_point = torch.load(params_path)
+        model.load_state_dict(check_point['params'])
+        with torch.no_grad():
+            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
+        self.eval_results_test = evaluate.evaluate_all(self.eval_results_test, epoch,
+            self.label_meanings, gr_truth_epoch, pred_epoch)
+        self.plot_roc_and_pr_curves('test', epoch, pred_epoch, gr_truth_epoch)
+        self.save_all_pred_probs('test', epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
+        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Test Loss', epoch_loss))
+    
+    def iterate_through_batches(self, model, dataloader, epoch, training):
+        epoch_loss = 0
+        
+        #Initialize numpy arrays for storing results. examples x labels
+        #Do NOT use concatenation, or else you will have memory fragmentation.
+        num_examples = len(dataloader.dataset)
+        num_labels = len(self.label_meanings)
+        pred_epoch = np.zeros([num_examples,num_labels])
+        gr_truth_epoch = np.zeros([num_examples,num_labels])
+        volume_accs_epoch = np.empty(num_examples,dtype='U32') #need to use U32 to allow string of length 32
+        
+        for batch_idx, batch in enumerate(dataloader):
+            data, gr_truth = self.move_data_to_device(batch)
+            self.optimizer.zero_grad()
+            if training:
+                out = model(data)
+            else:
+                with torch.set_grad_enabled(False):
+                   out = model(data)
+            loss = self.loss_func(out, gr_truth)
+            if training:
+                loss.backward()
+                self.optimizer.step()   
+            
+            epoch_loss += loss.item()
+            torch.cuda.empty_cache()
+            
+            #Save predictions and ground truth across batches
+            pred = self.sigmoid(out.data).detach().cpu().numpy()
+            gr_truth = gr_truth.detach().cpu().numpy()
+            
+            start_row = batch_idx*self.batch_size
+            stop_row = min(start_row + self.batch_size, num_examples)
+            pred_epoch[start_row:stop_row,:] = pred #pred_epoch is e.g. [25355,80] and pred is e.g. [1,80] for a batch size of 1
+            gr_truth_epoch[start_row:stop_row,:] = gr_truth #gr_truth_epoch has same shape as pred_epoch
+            volume_accs_epoch[start_row:stop_row] = batch['volume_acc'] #volume_accs_epoch stores the volume accessions in the order they were used
+            
+            #the following line to empty the cache is necessary in order to
+            #reduce memory usage and avoid OOM error:
+            torch.cuda.empty_cache() 
+        return epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch
+    
+    def move_data_to_device(self, batch):
+        """Move data and ground truth to device."""
+        assert self.dataset_args['crop_type'] == 'single'
+        if self.dataset_args['crop_type'] == 'single':
+            data = batch['data'].to(self.device)
+        
+        #Ground truth to device
+        gr_truth = batch['gr_truth'].to(self.device)
+        return data, gr_truth
+    
+    def plot_roc_and_pr_curves(self, setname, epoch, pred_epoch, gr_truth_epoch):
+        outdir = os.path.join(self.results_dir,'curves')
+        if not os.path.isdir(outdir):
+            os.mkdir(outdir)
+        evaluate.plot_roc_curve_multi_class(label_meanings=self.label_meanings,
+                    y_test=gr_truth_epoch, y_score=pred_epoch,
+                    outdir = outdir, setname = setname, epoch = epoch)
+        evaluate.plot_pr_curve_multi_class(label_meanings=self.label_meanings,
+                    y_test=gr_truth_epoch, y_score=pred_epoch,
+                    outdir = outdir, setname = setname, epoch = epoch)
+    
+    def save_all_pred_probs(self, setname, epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch):
+        outdir = os.path.join(self.results_dir,'pred_probs')
+        if not os.path.isdir(outdir):
+            os.mkdir(outdir)
+        (pd.DataFrame(pred_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_predprob_ep'+str(epoch)+'.csv'))
+        (pd.DataFrame(gr_truth_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_grtruth_ep'+str(epoch)+'.csv'))
+        
+    def save_evals(self, epoch):
+        evaluate.save(self.eval_results_valid, self.results_dir, self.descriptor+'_valid')
+        if self.use_test_set: evaluate.save(self.eval_results_test, self.results_dir, self.descriptor+'_test')
+        evaluate.plot_learning_curves(self.train_loss, self.valid_loss, self.results_dir, self.descriptor)
+               
+    def save_final_summary(self):
+        evaluate.save_final_summary(self.eval_results_valid, self.best_valid_epoch, 'valid', self.results_dir)
+        if self.use_test_set: evaluate.save_final_summary(self.eval_results_test, self.best_valid_epoch, 'test', self.results_dir)
+        evaluate.clean_up_output_files(self.best_valid_epoch, self.results_dir)
+        
\ No newline at end of file