#run_experiment.py
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos

#MIT License

#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:

#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.

#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.

import os
import timeit
import datetime
import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models, utils

import evaluate
from load_dataset import custom_datasets

#Set seeds
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
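#Optional additional settings for exact run-to-run GPU reproducibility
#(left disabled here; deterministic cuDNN kernels can slow down training):
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False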

class DukeCTModel(object):
    def __init__(self, descriptor, custom_net, custom_net_args,
                 loss, loss_args, num_epochs, patience, batch_size, device, data_parallel,
                 use_test_set, task, old_params_dir, dataset_class, dataset_args):
        """Variables:
        <descriptor>: string describing the experiment
        <custom_net>: class defining a model
        <custom_net_args>: dictionary where keys correspond to custom net
            input arguments, and values are the desired values
        <loss>: 'bce' for binary cross entropy
        <loss_args>: arguments to pass to the loss function, if any
        <num_epochs>: int for the maximum number of epochs to train
        <patience>: number of epochs for which the validation loss must fail
            to improve in order to cause early stopping
        <batch_size>: int for the number of examples per batch
        <device>: int specifying which GPU to use, or 'all' for all devices
        <data_parallel>: if True, then parallelize across available GPUs
        <use_test_set>: if True, then run the model on the test set. If False,
            use only the training and validation sets.
        <task>:
            'train_eval': train and evaluate a new model. Evaluation always
                uses the validation set; if <use_test_set> is True, it also
                includes calculation of test set performance for the best
                validation epoch.
            'predict_on_test': load a trained model and make predictions on
                the test set using that model.
        <old_params_dir>: only needed if <task>=='predict_on_test'. This is
            the path to the parameters that will be loaded into the model.
        <dataset_class>: CT Dataset class for preprocessing the data
        <dataset_args>: arguments for the dataset class specifying how
            the data should be prepared."""
        self.descriptor = descriptor
        self.set_up_results_dirs()
        self.custom_net = custom_net
        self.custom_net_args = custom_net_args
        self.loss = loss
        self.loss_args = loss_args
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        print('self.batch_size=',self.batch_size)
        #num_workers is the number of subprocesses to use for data loading
        self.num_workers = int(batch_size*4) #e.g. batch_size 1 -> num_workers 4; batch_size 2 -> 8; batch_size 4 -> 16
        print('self.num_workers=',self.num_workers)
        if self.num_workers == 1:
            print('Warning: Using only one worker will slow down data loading')

        #Set Device and Data Parallelism
        if device in [0,1,2,3]: #i.e. if a GPU number was specified
            self.device = torch.device('cuda:'+str(device))
            print('using device:',str(self.device),'\ndescriptor: ',self.descriptor)
        elif device == 'all':
            self.device = torch.device('cuda')
        else:
            raise ValueError('Invalid device specification: '+str(device))
        self.data_parallel = data_parallel
        if self.data_parallel:
            assert device == 'all' #use all devices when running data parallel
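        #Note on nn.DataParallel (used in run_model): it replicates the model
        #on every visible GPU and splits each batch across them, so the
        #effective per-GPU batch size is roughly batch_size / num_gpus.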

        #Set Task
        self.use_test_set = use_test_set
        self.task = task
        assert self.task in ['train_eval','predict_on_test']
        if self.task == 'predict_on_test':
            #overwrite the params dir that was created in the call to
            #set_up_results_dirs() with the dir you want to load from
            self.params_dir = old_params_dir

        #Data and Labels
        self.CTDatasetClass = dataset_class
        self.dataset_args = dataset_args
        #Get label meanings, a list of descriptive strings (list elements must
        #be strings found in the column headers of the labels file)
        self.set_up_label_meanings(self.dataset_args['label_meanings'])
        if self.task == 'train_eval':
            self.dataset_train = self.CTDatasetClass(setname='train', **self.dataset_args)
            self.dataset_valid = self.CTDatasetClass(setname='valid', **self.dataset_args)
        if self.use_test_set:
            self.dataset_test = self.CTDatasetClass(setname='test', **self.dataset_args)

        #Tracking losses and evaluation results
        self.train_loss = np.zeros((self.num_epochs))
        self.valid_loss = np.zeros((self.num_epochs))
        self.eval_results_valid, self.eval_results_test = evaluate.initialize_evaluation_dfs(self.label_meanings, self.num_epochs)

        #For early stopping
        self.initial_patience = patience
        self.patience_remaining = patience
        self.best_valid_epoch = 0
        self.min_val_loss = np.inf

        #Run everything
        self.run_model()

    ### Methods ###
    def set_up_label_meanings(self, label_meanings):
        if label_meanings == 'all': #get the full list of all available labels
            temp = custom_datasets.read_in_labels(self.dataset_args['label_type_ld'], 'valid')
            self.label_meanings = temp.columns.values.tolist()
        else: #use the label meanings that were passed in
            self.label_meanings = label_meanings
        print('label meanings ('+str(len(self.label_meanings))+' labels total):',self.label_meanings)

    def set_up_results_dirs(self):
        if not os.path.isdir('results'):
            os.mkdir('results')
        self.results_dir = os.path.join('results', datetime.datetime.today().strftime('%Y-%m-%d')+'_'+self.descriptor)
        if not os.path.isdir(self.results_dir):
            os.mkdir(self.results_dir)
        self.params_dir = os.path.join(self.results_dir, 'params')
        if not os.path.isdir(self.params_dir):
            os.mkdir(self.params_dir)
        self.backup_dir = os.path.join(self.results_dir, 'backup')
        if not os.path.isdir(self.backup_dir):
            os.mkdir(self.backup_dir)
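        #Resulting layout, assuming e.g. descriptor='testrun' run on 2020-01-01:
        #results/2020-01-01_testrun/
        #    params/  <- best checkpoint, saved by early_stopping_check()
        #    backup/  <- periodic checkpoints from back_up_model_every_ten()
        #('curves' and 'pred_probs' subdirectories are created later by
        #plot_roc_and_pr_curves() and save_all_pred_probs() respectively)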

    def run_model(self):
        if self.data_parallel:
            self.model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device)
        else:
            self.model = self.custom_net(**self.custom_net_args).to(self.device)
        self.sigmoid = torch.nn.Sigmoid()
        self.set_up_loss_function()

        momentum = 0.99
        print('Running with optimizer lr=1e-3, momentum='+str(round(momentum,2))+' and weight_decay=1e-7')
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3, momentum=momentum, weight_decay=1e-7)

        if self.task == 'train_eval':
            #The train and valid datasets are only defined for 'train_eval'
            #(see __init__), so only create their DataLoaders here
            train_dataloader = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            valid_dataloader = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
            for epoch in range(self.num_epochs):
                t0 = timeit.default_timer()
                self.train(train_dataloader, epoch)
                self.valid(valid_dataloader, epoch)
                self.save_evals(epoch)
                if self.patience_remaining <= 0:
                    print('No more patience (',self.initial_patience,') left at epoch',epoch)
                    print('--> Implementing early stopping. Best epoch was:',self.best_valid_epoch)
                    break
                t1 = timeit.default_timer()
                self.back_up_model_every_ten(epoch)
                print('Epoch',epoch,'time:',round((t1 - t0)/60.0,2),'minutes')
        if self.use_test_set: self.test(DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers))
        self.save_final_summary()

    def set_up_loss_function(self):
        if self.loss == 'bce':
            self.loss_func = nn.BCEWithLogitsLoss() #includes application of the sigmoid for numerical stability
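            #For reference, BCEWithLogitsLoss computes the numerically stable
            #form max(x,0) - x*y + log(1 + exp(-|x|)) per element, rather than
            #evaluating -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]
            #directly, which can overflow for large-magnitude logits x.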

    def train(self, dataloader, epoch):
        model = self.model.train()
        epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=True)
        self.train_loss[epoch] = epoch_loss
        self.plot_roc_and_pr_curves('train', epoch, pred_epoch, gr_truth_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Train Loss', epoch_loss))

    def valid(self, dataloader, epoch):
        model = self.model.eval()
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.valid_loss[epoch] = epoch_loss
        self.eval_results_valid = evaluate.evaluate_all(self.eval_results_valid, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.early_stopping_check(epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Valid Loss', epoch_loss))

    def early_stopping_check(self, epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch):
        """Check whether criteria for early stopping are met and update
        counters accordingly"""
        val_loss = self.valid_loss[epoch]
        if (val_loss < self.min_val_loss) or epoch==0: #then save the parameters
            self.min_val_loss = val_loss
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.params_dir, self.descriptor))
            self.best_valid_epoch = epoch
            self.patience_remaining = self.initial_patience
            print('model saved, val loss',val_loss)
            self.plot_roc_and_pr_curves('valid', epoch, val_pred_epoch, val_gr_truth_epoch)
            self.save_all_pred_probs('valid', epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch)
        else:
            self.patience_remaining -= 1
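            #Example of the patience mechanism: with patience=15, training
            #stops after 15 consecutive epochs in which the validation loss
            #fails to improve on self.min_val_loss; any improvement resets
            #the counter (see run_model, which breaks out of the epoch loop
            #once patience_remaining reaches 0).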

    def back_up_model_every_ten(self, epoch):
        """Back up the model parameters every 10 epochs"""
        if epoch % 10 == 0:
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.backup_dir, self.descriptor+'_ep_'+str(epoch)))

    def test(self, dataloader):
        epoch = self.best_valid_epoch
        if self.data_parallel:
            model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device).eval()
        else:
            model = self.custom_net(**self.custom_net_args).to(self.device).eval()
        params_path = os.path.join(self.params_dir, self.descriptor)
        print('For test set predictions, loading model params from params_path=',params_path)
        check_point = torch.load(params_path, map_location=self.device) #map_location so the params load onto the current device
        model.load_state_dict(check_point['params'])
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.eval_results_test = evaluate.evaluate_all(self.eval_results_test, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.plot_roc_and_pr_curves('test', epoch, pred_epoch, gr_truth_epoch)
        self.save_all_pred_probs('test', epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Test Loss', epoch_loss))

    def iterate_through_batches(self, model, dataloader, epoch, training):
        epoch_loss = 0

        #Initialize numpy arrays for storing the results, shaped examples x labels.
        #Do NOT build these up by concatenation, or else you will cause memory
        #fragmentation.
        num_examples = len(dataloader.dataset)
        num_labels = len(self.label_meanings)
        pred_epoch = np.zeros([num_examples,num_labels])
        gr_truth_epoch = np.zeros([num_examples,num_labels])
        volume_accs_epoch = np.empty(num_examples,dtype='U32') #dtype U32 allows strings of up to 32 characters
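        #For contrast, the fragmentation-prone pattern would be e.g.
        #pred_epoch = np.concatenate([pred_epoch, pred]) inside the batch
        #loop below, which reallocates and copies the whole array on every
        #batch; writing each batch into a preallocated slice avoids that.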

        for batch_idx, batch in enumerate(dataloader):
            data, gr_truth = self.move_data_to_device(batch)
            self.optimizer.zero_grad()
            if training:
                out = model(data)
            else:
                with torch.set_grad_enabled(False):
                    out = model(data)
            loss = self.loss_func(out, gr_truth)
            if training:
                loss.backward()
                self.optimizer.step()

            epoch_loss += loss.item()
            torch.cuda.empty_cache()

            #Save the predictions and ground truth across batches
            pred = self.sigmoid(out).detach().cpu().numpy()
            gr_truth = gr_truth.detach().cpu().numpy()

            start_row = batch_idx*self.batch_size
            stop_row = min(start_row + self.batch_size, num_examples)
            pred_epoch[start_row:stop_row,:] = pred #pred_epoch is e.g. [25355,80] and pred is e.g. [1,80] for a batch size of 1
            gr_truth_epoch[start_row:stop_row,:] = gr_truth #gr_truth_epoch has the same shape as pred_epoch
            volume_accs_epoch[start_row:stop_row] = batch['volume_acc'] #volume_accs_epoch stores the volume accessions in the order they were used
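            #Worked example of the slice bookkeeping above, with assumed
            #numbers: batch_size=2 and num_examples=5 fill rows [0:2], [2:4],
            #and [4:5] on successive batches; the min() clamp handles the
            #final, smaller batch.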

            #the following line to empty the cache is necessary in order to
            #reduce memory usage and avoid OOM error:
            torch.cuda.empty_cache()
        return epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch

    def move_data_to_device(self, batch):
        """Move data and ground truth to device."""
        assert self.dataset_args['crop_type'] == 'single'
        if self.dataset_args['crop_type'] == 'single':
            data = batch['data'].to(self.device)

        #Ground truth to device
        gr_truth = batch['gr_truth'].to(self.device)
        return data, gr_truth

    def plot_roc_and_pr_curves(self, setname, epoch, pred_epoch, gr_truth_epoch):
        outdir = os.path.join(self.results_dir, 'curves')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        evaluate.plot_roc_curve_multi_class(label_meanings=self.label_meanings,
                    y_test=gr_truth_epoch, y_score=pred_epoch,
                    outdir=outdir, setname=setname, epoch=epoch)
        evaluate.plot_pr_curve_multi_class(label_meanings=self.label_meanings,
                    y_test=gr_truth_epoch, y_score=pred_epoch,
                    outdir=outdir, setname=setname, epoch=epoch)

    def save_all_pred_probs(self, setname, epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch):
        outdir = os.path.join(self.results_dir, 'pred_probs')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        (pd.DataFrame(pred_epoch, columns=self.label_meanings,
                      index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_predprob_ep'+str(epoch)+'.csv'))
        (pd.DataFrame(gr_truth_epoch, columns=self.label_meanings,
                      index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_grtruth_ep'+str(epoch)+'.csv'))

    def save_evals(self, epoch):
        evaluate.save(self.eval_results_valid, self.results_dir, self.descriptor+'_valid')
        if self.use_test_set: evaluate.save(self.eval_results_test, self.results_dir, self.descriptor+'_test')
        evaluate.plot_learning_curves(self.train_loss, self.valid_loss, self.results_dir, self.descriptor)

    def save_final_summary(self):
        evaluate.save_final_summary(self.eval_results_valid, self.best_valid_epoch, 'valid', self.results_dir)
        if self.use_test_set: evaluate.save_final_summary(self.eval_results_test, self.best_valid_epoch, 'test', self.results_dir)
        evaluate.clean_up_output_files(self.best_valid_epoch, self.results_dir)
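
#Illustrative usage sketch (hypothetical, for orientation only): 'MyCTNet'
#and 'MyCTDataset' are placeholder names, not classes defined in this
#repository, and the dataset_args values are assumptions; the keys shown are
#the ones this file actually reads ('label_meanings', 'label_type_ld',
#'crop_type').
#
#if __name__ == '__main__':
#    DukeCTModel(descriptor='testrun', custom_net=MyCTNet, custom_net_args={},
#                loss='bce', loss_args={}, num_epochs=100, patience=15,
#                batch_size=2, device=0, data_parallel=False,
#                use_test_set=False, task='train_eval', old_params_dir='',
#                dataset_class=MyCTDataset,
#                dataset_args={'label_meanings':'all',
#                              'label_type_ld':'disease',
#                              'crop_type':'single'})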
330