#run_experiment.py
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos

#MIT License

#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:

#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.

#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE.

import os
import timeit
import datetime
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models, utils

import evaluate
from load_dataset import custom_datasets

#Set seeds
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)

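#Optionally, for fully deterministic GPU runs, cuDNN can also be configured
#(left disabled here and not part of the original setup; it can slow training):
#  torch.backends.cudnn.deterministic = True
#  torch.backends.cudnn.benchmark = False
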
class DukeCTModel(object):
    def __init__(self, descriptor, custom_net, custom_net_args,
                 loss, loss_args, num_epochs, patience, batch_size, device,
                 data_parallel, use_test_set, task, old_params_dir,
                 dataset_class, dataset_args):
        """Variables:
        <descriptor>: string describing the experiment
        <custom_net>: class defining a model
        <custom_net_args>: dictionary where the keys correspond to the custom
            net's input arguments, and the values are the desired values
        <loss>: 'bce' for binary cross entropy
        <loss_args>: dictionary of arguments to pass to the loss function,
            if any
        <num_epochs>: int for the maximum number of epochs to train
        <patience>: number of epochs for which the validation loss must fail
            to improve in order to trigger early stopping
        <batch_size>: int for the number of examples per batch
        <device>: int specifying which GPU to use, or 'all' for all devices
        <data_parallel>: if True, then parallelize across available GPUs
        <use_test_set>: if True, then run the model on the test set. If False,
            use only the training and validation sets.
        <task>: one of the following strings:
            'train_eval': train and evaluate a new model. Evaluation always
                implies use of the validation set. If <use_test_set> is True,
                then 'train_eval' also includes calculation of test set
                performance for the best validation epoch.
            'predict_on_test': load a trained model and make predictions on
                the test set using that model.
        <old_params_dir>: only needed if <task>=='predict_on_test'. This is
            the path to the parameters that will be loaded into the model.
        <dataset_class>: CT Dataset class for preprocessing the data
        <dataset_args>: dictionary of arguments for the dataset class
            specifying how the data should be prepared."""
        self.descriptor = descriptor
        self.set_up_results_dirs()
        self.custom_net = custom_net
        self.custom_net_args = custom_net_args
        self.loss = loss
        self.loss_args = loss_args
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        print('self.batch_size=',self.batch_size)
        #num_workers is the number of threads to use for data loading; scale
        #it with the batch size: batch_size 1 -> num_workers 4, batch_size 2
        #-> num_workers 8, batch_size 4 -> num_workers 16, and so on.
        self.num_workers = int(batch_size*4)
        print('self.num_workers=',self.num_workers)
        if self.num_workers == 1:
            print('Warning: Using only one worker will slow down data loading')

        #Set device and data parallelism
        if device in [0,1,2,3]: #i.e. if a specific GPU number was specified
            self.device = torch.device('cuda:'+str(device))
            print('using device:',str(self.device),'\ndescriptor: ',self.descriptor)
        elif device == 'all':
            self.device = torch.device('cuda')
        else:
            raise ValueError('Invalid device specification: '+str(device))
        self.data_parallel = data_parallel
        if self.data_parallel:
            assert device == 'all' #use all devices when running data parallel

        #Set task
        self.use_test_set = use_test_set
        self.task = task
        assert self.task in ['train_eval','predict_on_test']
        if self.task == 'predict_on_test':
            #overwrite the params dir that was created in the call to
            #set_up_results_dirs() with the dir you want to load from
            self.params_dir = old_params_dir

        #Data and labels
        self.CTDatasetClass = dataset_class
        self.dataset_args = dataset_args
        #Get label meanings, a list of descriptive strings (list elements must
        #be strings found in the column headers of the labels file)
        self.set_up_label_meanings(self.dataset_args['label_meanings'])
        if self.task == 'train_eval':
            self.dataset_train = self.CTDatasetClass(setname='train', **self.dataset_args)
            self.dataset_valid = self.CTDatasetClass(setname='valid', **self.dataset_args)
        if self.use_test_set:
            self.dataset_test = self.CTDatasetClass(setname='test', **self.dataset_args)

        #Tracking losses and evaluation results
        self.train_loss = np.zeros((self.num_epochs))
        self.valid_loss = np.zeros((self.num_epochs))
        self.eval_results_valid, self.eval_results_test = evaluate.initialize_evaluation_dfs(self.label_meanings, self.num_epochs)

        #For early stopping
        self.initial_patience = patience
        self.patience_remaining = patience
        self.best_valid_epoch = 0
        self.min_val_loss = np.inf

        #Run everything
        self.run_model()

    ### Methods ###
    def set_up_label_meanings(self, label_meanings):
        if label_meanings == 'all': #get the full list of all available labels
            temp = custom_datasets.read_in_labels(self.dataset_args['label_type_ld'], 'valid')
            self.label_meanings = temp.columns.values.tolist()
        else: #use the label meanings that were passed in
            self.label_meanings = label_meanings
        print('label meanings ('+str(len(self.label_meanings))+' labels total):',self.label_meanings)

    def set_up_results_dirs(self):
        if not os.path.isdir('results'):
            os.mkdir('results')
        self.results_dir = os.path.join('results', datetime.datetime.today().strftime('%Y-%m-%d')+'_'+self.descriptor)
        if not os.path.isdir(self.results_dir):
            os.mkdir(self.results_dir)
        self.params_dir = os.path.join(self.results_dir, 'params')
        if not os.path.isdir(self.params_dir):
            os.mkdir(self.params_dir)
        self.backup_dir = os.path.join(self.results_dir, 'backup')
        if not os.path.isdir(self.backup_dir):
            os.mkdir(self.backup_dir)

    def run_model(self):
        if self.data_parallel:
            self.model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device)
        else:
            self.model = self.custom_net(**self.custom_net_args).to(self.device)
        self.sigmoid = torch.nn.Sigmoid()
        self.set_up_loss_function()

        momentum = 0.99
        print('Running with optimizer lr=1e-3, momentum='+str(round(momentum,2))+' and weight_decay=1e-7')
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3, momentum=momentum, weight_decay=1e-7)

        if self.task == 'train_eval':
            #the train and valid datasets are only defined for 'train_eval',
            #so only build their dataloaders on this branch
            train_dataloader = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            valid_dataloader = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
            for epoch in range(self.num_epochs):
                t0 = timeit.default_timer()
                self.train(train_dataloader, epoch)
                self.valid(valid_dataloader, epoch)
                self.save_evals(epoch)
                if self.patience_remaining <= 0:
                    print('No more patience (',self.initial_patience,') left at epoch',epoch)
                    print('--> Implementing early stopping. Best epoch was:',self.best_valid_epoch)
                    break
                t1 = timeit.default_timer()
                self.back_up_model_every_ten(epoch)
                print('Epoch',epoch,'time:',round((t1 - t0)/60.0,2),'minutes')
        if self.use_test_set:
            self.test(DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers))
        self.save_final_summary()

    def set_up_loss_function(self):
        if self.loss == 'bce':
            #BCEWithLogitsLoss applies the sigmoid internally for numerical
            #stability; pass along any user-specified loss arguments
            self.loss_func = nn.BCEWithLogitsLoss(**self.loss_args)

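    #For reference, the fused 'bce' loss above is numerically-stable shorthand
    #for applying a sigmoid followed by binary cross entropy, i.e. roughly
    #(illustrative sketch only, not used in training):
    #  loss = F.binary_cross_entropy(torch.sigmoid(out), gr_truth)
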
    def train(self, dataloader, epoch):
        model = self.model.train()
        epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=True)
        self.train_loss[epoch] = epoch_loss
        self.plot_roc_and_pr_curves('train', epoch, pred_epoch, gr_truth_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Train Loss', epoch_loss))

    def valid(self, dataloader, epoch):
        model = self.model.eval()
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.valid_loss[epoch] = epoch_loss
        self.eval_results_valid = evaluate.evaluate_all(self.eval_results_valid, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.early_stopping_check(epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Valid Loss', epoch_loss))

    def early_stopping_check(self, epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch):
        """Check whether the criteria for early stopping are met, and update
        the counters accordingly"""
        val_loss = self.valid_loss[epoch]
        if (val_loss < self.min_val_loss) or epoch==0: #then save parameters
            self.min_val_loss = val_loss
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.params_dir, self.descriptor))
            self.best_valid_epoch = epoch
            self.patience_remaining = self.initial_patience
            print('model saved, val loss',val_loss)
            self.plot_roc_and_pr_curves('valid', epoch, val_pred_epoch, val_gr_truth_epoch)
            self.save_all_pred_probs('valid', epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch)
        else:
            self.patience_remaining -= 1

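    #For reference, a checkpoint saved by early_stopping_check() can later be
    #restored with the standard state_dict pattern; test() below does this for
    #the model parameters, and the optimizer state could be restored the same
    #way (sketch only):
    #  check_point = torch.load(os.path.join(self.params_dir, self.descriptor))
    #  self.model.load_state_dict(check_point['params'])
    #  self.optimizer.load_state_dict(check_point['optimizer'])
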
    def back_up_model_every_ten(self, epoch):
        """Back up the model parameters every 10 epochs"""
        if epoch % 10 == 0:
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.backup_dir, self.descriptor+'_ep_'+str(epoch)))

    def test(self, dataloader):
        epoch = self.best_valid_epoch
        if self.data_parallel:
            model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device).eval()
        else:
            model = self.custom_net(**self.custom_net_args).to(self.device).eval()
        params_path = os.path.join(self.params_dir, self.descriptor)
        print('For test set predictions, loading model params from params_path=',params_path)
        check_point = torch.load(params_path)
        model.load_state_dict(check_point['params'])
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.eval_results_test = evaluate.evaluate_all(self.eval_results_test, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.plot_roc_and_pr_curves('test', epoch, pred_epoch, gr_truth_epoch)
        self.save_all_pred_probs('test', epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Test Loss', epoch_loss))

    def iterate_through_batches(self, model, dataloader, epoch, training):
        epoch_loss = 0

        #Initialize numpy arrays for storing results, shape examples x labels.
        #Do NOT use concatenation, or else you will have memory fragmentation.
        num_examples = len(dataloader.dataset)
        num_labels = len(self.label_meanings)
        pred_epoch = np.zeros([num_examples, num_labels])
        gr_truth_epoch = np.zeros([num_examples, num_labels])
        volume_accs_epoch = np.empty(num_examples, dtype='U32') #need to use U32 to allow strings of up to 32 characters

        for batch_idx, batch in enumerate(dataloader):
            data, gr_truth = self.move_data_to_device(batch)
            self.optimizer.zero_grad()
            if training:
                out = model(data)
            else:
                with torch.set_grad_enabled(False):
                    out = model(data)
            loss = self.loss_func(out, gr_truth)
            if training:
                loss.backward()
                self.optimizer.step()

            epoch_loss += loss.item()
            torch.cuda.empty_cache()

            #Save predictions and ground truth across batches
            pred = self.sigmoid(out).detach().cpu().numpy()
            gr_truth = gr_truth.detach().cpu().numpy()

            start_row = batch_idx*self.batch_size
            stop_row = min(start_row + self.batch_size, num_examples)
            pred_epoch[start_row:stop_row,:] = pred #pred_epoch is e.g. [25355,80] and pred is e.g. [1,80] for a batch size of 1
            gr_truth_epoch[start_row:stop_row,:] = gr_truth #gr_truth_epoch has the same shape as pred_epoch
            volume_accs_epoch[start_row:stop_row] = batch['volume_acc'] #stores the volume accessions in the order they were used

            #The following line to empty the cache is necessary in order to
            #reduce memory usage and avoid an OOM error:
            torch.cuda.empty_cache()
        return epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch

    def move_data_to_device(self, batch):
        """Move the data and the ground truth to the device."""
        #this implementation assumes single-crop volumes
        assert self.dataset_args['crop_type'] == 'single'
        data = batch['data'].to(self.device)

        #Ground truth to device
        gr_truth = batch['gr_truth'].to(self.device)
        return data, gr_truth

    def plot_roc_and_pr_curves(self, setname, epoch, pred_epoch, gr_truth_epoch):
        outdir = os.path.join(self.results_dir, 'curves')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        evaluate.plot_roc_curve_multi_class(label_meanings=self.label_meanings,
            y_test=gr_truth_epoch, y_score=pred_epoch,
            outdir=outdir, setname=setname, epoch=epoch)
        evaluate.plot_pr_curve_multi_class(label_meanings=self.label_meanings,
            y_test=gr_truth_epoch, y_score=pred_epoch,
            outdir=outdir, setname=setname, epoch=epoch)

    def save_all_pred_probs(self, setname, epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch):
        outdir = os.path.join(self.results_dir, 'pred_probs')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        (pd.DataFrame(pred_epoch, columns=self.label_meanings,
                      index=volume_accs_epoch.tolist())).to_csv(
            os.path.join(outdir, setname+'_predprob_ep'+str(epoch)+'.csv'))
        (pd.DataFrame(gr_truth_epoch, columns=self.label_meanings,
                      index=volume_accs_epoch.tolist())).to_csv(
            os.path.join(outdir, setname+'_grtruth_ep'+str(epoch)+'.csv'))

    def save_evals(self, epoch):
        evaluate.save(self.eval_results_valid, self.results_dir, self.descriptor+'_valid')
        if self.use_test_set:
            evaluate.save(self.eval_results_test, self.results_dir, self.descriptor+'_test')
        evaluate.plot_learning_curves(self.train_loss, self.valid_loss, self.results_dir, self.descriptor)

    def save_final_summary(self):
        evaluate.save_final_summary(self.eval_results_valid, self.best_valid_epoch, 'valid', self.results_dir)
        if self.use_test_set:
            evaluate.save_final_summary(self.eval_results_test, self.best_valid_epoch, 'test', self.results_dir)
        evaluate.clean_up_output_files(self.best_valid_epoch, self.results_dir)
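
#Example usage (a minimal, hypothetical sketch: the model module/class and the
#specific argument values below are illustrative placeholders, not part of
#this repository's verified API):
#
#if __name__ == '__main__':
#    from models import custom_models #hypothetical module defining MyCTNet
#    DukeCTModel(descriptor='myctnet_baseline',
#                custom_net=custom_models.MyCTNet, #hypothetical model class
#                custom_net_args={'n_outputs':80}, #hypothetical arguments
#                loss='bce', loss_args={},
#                num_epochs=100, patience=15, batch_size=1,
#                device=0, data_parallel=False,
#                use_test_set=False, task='train_eval', old_params_dir='',
#                dataset_class=custom_datasets.CTDataset, #hypothetical dataset class
#                dataset_args={'label_type_ld':'disease', #placeholder values
#                              'label_meanings':'all',
#                              'crop_type':'single'})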