Diff of /eval_surv.py [000000] .. [405115]

Switch to unified view

a b/eval_surv.py
1
from __future__ import print_function
2
3
import argparse
4
import pdb
5
import os
6
import math
7
import sys
8
9
# internal imports
10
from utils.file_utils import save_pkl, load_pkl
11
from utils.utils import *
12
from utils.core_utils import train, eval_model
13
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
14
from datasets.dataset_survival import Generic_WSI_Survival_Dataset, Generic_MIL_Survival_Dataset
15
16
# pytorch imports
17
import torch
18
from torch.utils.data import DataLoader, sampler
19
import torch.nn as nn
20
import torch.nn.functional as F
21
22
import pandas as pd
23
import numpy as np
24
25
from timeit import default_timer as timer
26
27
28
def main(args):
29
    # create results directory if necessary
30
    if not os.path.isdir(args.results_dir):
31
        os.mkdir(args.results_dir)
32
33
    if args.k_start == -1:
34
        start = 0
35
    else:
36
        start = args.k_start
37
    if args.k_end == -1:
38
        end = args.k
39
    else:
40
        end = args.k_end
41
42
    val_cindex = []
43
    folds = np.arange(start, end)
44
45
    for i in folds:
46
        start = timer()
47
        seed_torch(args.seed)
48
49
        train_dataset, val_dataset = dataset.return_splits(from_id=False, csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
50
        
51
        print('training: {}, validation: {}'.format(len(train_dataset), len(val_dataset)))
52
        datasets = (train_dataset, val_dataset)
53
        
54
        if 'omic' in args.mode:
55
            args.omic_input_dim = train_dataset.genomic_features.shape[1]
56
            print("Genomic Dimension", args.omic_input_dim)
57
58
        val_latest, cindex_latest = eval_model(datasets, i, args)
59
        val_cindex.append(cindex_latest)
60
61
        #write results to pkl
62
        save_pkl(os.path.join(args.results_dir, 'split_val_{}_results.pkl'.format(i)), val_latest)
63
        end = timer()
64
        print('Fold %d Time: %f seconds' % (i, end - start))
65
66
    if len(folds) != args.k: save_name = 'summary_partial_{}_{}.csv'.format(start, end)
67
    else: save_name = 'summary.csv'
68
    results_df = pd.DataFrame({'folds': folds, 'val_cindex': val_cindex})
69
    results_df.to_csv(os.path.join(args.results_dir, 'summary.csv'))
70
71
# Training settings
72
parser = argparse.ArgumentParser(description='Configurations for MMF Training')
73
parser.add_argument('--data_root_dir', type=str, default='/media/ssd1/pan-cancer', help='data directory')
74
parser.add_argument('--which_splits', type=str, default='5foldcv', help='Path to splits directory.')
75
parser.add_argument('--split_dir', type=str, help='Set of splits to use for each cancer type.')
76
parser.add_argument('--mode', type=str, default='omic')
77
parser.add_argument('--model_type', type=str, default='clam', help='type of model (attention_mil | max_net | mm_attention_mil)')
78
79
parser.add_argument('--max_epochs', type=int, default=20, help='maximum number of epochs to train (default: 20)')
80
parser.add_argument('--lr', type=float, default=2e-4, help='learning rate (default: 0.0001)')
81
parser.add_argument('--label_frac', type=float, default=1.0, help='fraction of training labels (default: 1.0)')
82
parser.add_argument('--bag_weight', type=float, default=0.7, help='clam: weight coefficient for bag-level loss (default: 0.7)')
83
parser.add_argument('--reg', type=float, default=1e-5, help='weight decay (default: 1e-5)')
84
parser.add_argument('--seed', type=int, default=1, help='random seed for reproducible experiment (default: 1)')
85
parser.add_argument('--k', type=int, default=5, help='number of folds (default: 10)')
86
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, last fold)')
87
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, first fold)')
88
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
89
parser.add_argument('--log_data', action='store_true', default=True, help='log data using tensorboard')
90
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
91
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
92
parser.add_argument('--opt', type=str, choices = ['adam', 'sgd'], default='adam')
93
parser.add_argument('--drop_out', action='store_true', default=True, help='enabel dropout (p=0.25)')
94
parser.add_argument('--inst_loss', type=str, choices=['svm', 'ce', None], default=None, help='instance-level clustering loss function (default: None)')
95
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce', 'ce_surv', 'nll_surv', 'cox_surv'], default='nll_surv', help='slide-level classification loss function (default: ce)')
96
parser.add_argument('--alpha_surv', type=float, default=0.0, help='How much to weigh uncensored patients')
97
parser.add_argument('--reg_type', type=str, choices=['None', 'omic', 'pathomic'], default='None', help='Reg Type (default: None)')
98
parser.add_argument('--lambda_reg', type=float, default=1e-4, help='Regularization Strength')
99
parser.add_argument('--weighted_sample', action='store_true', default=True, help='enable weighted sampling')
100
parser.add_argument('--model_size_wsi', type=str, default='small', help='Size of AMIL model.')
101
parser.add_argument('--model_size_omic', type=str, default='small', help='Size of SNN Model.')
102
parser.add_argument('--gc', type=int, default=1, help='gradient accumulation step')
103
parser.add_argument('--batch_size', type=int, default=1, help='Batch Size')
104
parser.add_argument('--gate_path', action='store_true', default=False, help='Enable feature gating in MMF layer.')
105
parser.add_argument('--gate_omic', action='store_true', default=False, help='Enable feature gating in MMF layer.')
106
parser.add_argument('--fusion', type=str, default='tensor', help='Which fusion mechanism to use.')
107
parser.add_argument('--overwrite', action='store_true', default=False, help='Current experiment results already exists. Redo?')
108
parser.add_argument('--apply_mad', action='store_true', default=True, help='Use genes with median absolute deviation.')
109
parser.add_argument('--task', type=str, default='survival', help='Which task.')
110
args = parser.parse_args()
111
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
113
114
### Creates Custom Experiment Code
115
exp_code = '_'.join(args.split_dir.split('_')[:2])
116
dataset_path = 'dataset_csv'
117
param_code = ''
118
119
if args.model_type == 'attention_mil':
120
  param_code += 'WSI'
121
elif args.model_type == 'max_net':
122
  param_code += 'SNN'
123
elif args.model_type == 'mm_attention_mil' and args.fusion == 'tensor':
124
  param_code += 'MMF'
125
else:
126
  raise NotImplementedError
127
128
if 'small' in args.model_size_omic:
129
  param_code += 'sm'
130
131
param_code += '_%s' % args.bag_loss
132
133
if 'mm_' in args.model_type:
134
  param_code += '_g'
135
  if args.gate_path:
136
    param_code += '1'
137
  else:
138
    param_code += '0'
139
140
  if args.gate_omic:
141
    param_code += '1'
142
  else:
143
    param_code += '0'
144
145
param_code += '_a%s' % str(args.alpha_surv)
146
147
if args.lr != 2e-4:
148
  param_code += '_lr%s' % format(args.lr, '.0e')
149
150
if args.reg_type != 'None':
151
  param_code += '_reg%s' % format(args.lambda_reg, '.0e')
152
153
param_code += '_%s' % args.which_splits.split("_")[0]
154
155
if args.gc != 1:
156
  param_code += '_gc%s' % str(args.gc)
157
158
if args.apply_mad:
159
  param_code += '_mad'
160
  #dataset_path += '_mad'
161
  
162
args.exp_code = exp_code + "_" + param_code
163
164
### task
165
if args.task == 'survival':
166
  args.task = '_'.join(args.split_dir.split('_')[:2]) + '_survival'
167
print("Experiment Name:", exp_code)
168
169
170
def seed_torch(seed=7):
171
    import random
172
    random.seed(seed)
173
    os.environ['PYTHONHASHSEED'] = str(seed)
174
    np.random.seed(seed)
175
    torch.manual_seed(seed)
176
    if device.type == 'cuda':
177
        torch.cuda.manual_seed(seed)
178
        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
179
    torch.backends.cudnn.benchmark = False
180
    torch.backends.cudnn.deterministic = True
181
182
seed_torch(args.seed)
183
184
encoding_size = 1024
185
settings = {'num_splits': args.k, 
186
            'k_start': args.k_start,
187
            'k_end': args.k_end,
188
            'task': args.task,
189
            'max_epochs': args.max_epochs, 
190
            'results_dir': args.results_dir, 
191
            'lr': args.lr,
192
            'experiment': args.exp_code,
193
            'reg': args.reg,
194
            'label_frac': args.label_frac,
195
            'inst_loss': args.inst_loss,
196
            'bag_loss': args.bag_loss,
197
            'bag_weight': args.bag_weight,
198
            'seed': args.seed,
199
            'model_type': args.model_type,
200
            'model_size_wsi': args.model_size_wsi,
201
            'model_size_omic': args.model_size_omic,
202
            "use_drop_out": args.drop_out,
203
            'weighted_sample': args.weighted_sample,
204
            'gc': args.gc,
205
            'opt': args.opt}
206
207
print('\nLoad Dataset')
208
if args.task == 'tcga_blca_survival':
209
  args.n_classes = 4
210
  proj = '_'.join(args.task.split('_')[:2])
211
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv' % (dataset_path, proj),
212
                                           mode = args.mode,
213
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_bladder_20x_features'),
214
                                           shuffle = False, 
215
                                           seed = args.seed, 
216
                                           print_info = True,
217
                                           patient_strat= False,
218
                                           n_bins=4,
219
                                           label_col = 'survival_months',
220
                                           ignore=[])
221
elif args.task == 'tcga_brca_survival':
222
  args.n_classes = 4
223
  proj = '_'.join(args.task.split('_')[:2])
224
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
225
                                           mode = args.mode,
226
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_breast_20x_features'),
227
                                           shuffle = False, 
228
                                           seed = args.seed, 
229
                                           print_info = True,
230
                                           patient_strat= False,
231
                                           n_bins=4,
232
                                           label_col = 'survival_months',
233
                                           ignore=[])
234
elif args.task == 'tcga_coadread_survival':
235
  args.n_classes = 4
236
  proj = '_'.join(args.task.split('_')[:2])
237
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
238
                                           mode = args.mode,
239
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_coadread_20x_features'),
240
                                           shuffle = False, 
241
                                           seed = args.seed, 
242
                                           print_info = True,
243
                                           patient_strat= False,
244
                                           n_bins=4,
245
                                           label_col = 'survival_months',
246
                                           ignore=[])
247
elif args.task == 'tcga_gbmlgg_survival':
248
    args.n_classes = 4
249
    dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/tcga_gbmlgg_all.csv' % dataset_path,
250
                                           mode = args.mode,
251
                                           data_dir={'ASTR': os.path.join(args.data_root_dir,'tcga_lgg_20x_features'),
252
                                                     'AASTR': os.path.join(args.data_root_dir,'tcga_lgg_20x_features'),
253
                                                     'ODG': os.path.join(args.data_root_dir,'tcga_lgg_20x_features'),
254
                                                     'OAST': os.path.join(args.data_root_dir,'tcga_lgg_20x_features'),
255
                                                     'AOAST': os.path.join(args.data_root_dir,'tcga_lgg_20x_features'),
256
                                                     'GBM': os.path.join(args.data_root_dir,'tcga_gbm_20x_features'),},
257
                                           shuffle = False, 
258
                                           seed = args.seed, 
259
                                           print_info = True,
260
                                           patient_strat= False,
261
                                           n_bins=4,
262
                                           label_col = 'survival_months',
263
                                           ignore=[])
264
elif args.task == 'tcga_hnsc_survival':
265
  args.n_classes = 4
266
  proj = '_'.join(args.task.split('_')[:2])
267
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
268
                                           mode = args.mode,
269
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_hnsc_20x_features'),
270
                                           shuffle = False, 
271
                                           seed = args.seed, 
272
                                           print_info = True,
273
                                           patient_strat= False,
274
                                           n_bins=4,
275
                                           label_col = 'survival_months',
276
                                           ignore=[])
277
elif args.task == 'tcga_kirc_survival':
278
  args.n_classes = 4
279
  proj = '_'.join(args.task.split('_')[:2])
280
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
281
                                           mode = args.mode,
282
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_kidney_20x_features'),
283
                                           shuffle = False, 
284
                                           seed = args.seed, 
285
                                           print_info = True,
286
                                           patient_strat= False,
287
                                           n_bins=4,
288
                                           label_col = 'survival_months',
289
                                           ignore=[])
290
elif args.task == 'tcga_kirp_survival':
291
  args.n_classes = 4
292
  proj = '_'.join(args.task.split('_')[:2])
293
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
294
                                           mode = args.mode,
295
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_kidney_20x_features'),
296
                                           shuffle = False, 
297
                                           seed = args.seed, 
298
                                           print_info = True,
299
                                           patient_strat= False,
300
                                           n_bins=4,
301
                                           label_col = 'survival_months',
302
                                           ignore=[])
303
elif args.task == 'tcga_lihc_survival':
304
  args.n_classes = 4
305
  proj = '_'.join(args.task.split('_')[:2])
306
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
307
                                           mode = args.mode,
308
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_liver_20x_features'),
309
                                           shuffle = False, 
310
                                           seed = args.seed, 
311
                                           print_info = True,
312
                                           patient_strat= False,
313
                                           n_bins=4,
314
                                           label_col = 'survival_months',
315
                                           ignore=[])
316
elif args.task == 'tcga_luad_survival':
317
  args.n_classes = 4
318
  proj = '_'.join(args.task.split('_')[:2])
319
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
320
                                           mode = args.mode,
321
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_lung_20x_features'),
322
                                           shuffle = False, 
323
                                           seed = args.seed, 
324
                                           print_info = True,
325
                                           patient_strat= False,
326
                                           n_bins=4,
327
                                           label_col = 'survival_months',
328
                                           ignore=[])
329
elif args.task == 'tcga_lusc_survival':
330
  args.n_classes = 4
331
  proj = '_'.join(args.task.split('_')[:2])
332
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
333
                                           mode = args.mode,
334
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_lung_20x_features'),
335
                                           shuffle = False, 
336
                                           seed = args.seed, 
337
                                           print_info = True,
338
                                           patient_strat= False,
339
                                           n_bins=4,
340
                                           label_col = 'survival_months',
341
                                           ignore=[])
342
elif args.task == 'tcga_paad_survival':
343
  args.n_classes = 4
344
  proj = '_'.join(args.task.split('_')[:2])
345
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
346
                                           mode = args.mode,
347
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_pancreas_20x_features'),
348
                                           shuffle = False, 
349
                                           seed = args.seed, 
350
                                           print_info = True,
351
                                           patient_strat= False,
352
                                           n_bins=4,
353
                                           label_col = 'survival_months',
354
                                           ignore=[])
355
elif args.task == 'tcga_skcm_survival':
356
  args.n_classes = 4
357
  proj = '_'.join(args.task.split('_')[:2])
358
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
359
                                           mode = args.mode,
360
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_skin_20x_features'),
361
                                           shuffle = False, 
362
                                           seed = args.seed, 
363
                                           print_info = True,
364
                                           patient_strat= False,
365
                                           n_bins=4,
366
                                           label_col = 'survival_months',
367
                                           ignore=[])
368
elif args.task == 'tcga_stad_survival':
369
  args.n_classes = 4
370
  proj = '_'.join(args.task.split('_')[:2])
371
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
372
                                           mode = args.mode,
373
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_stomach_20x_features'),
374
                                           shuffle = False, 
375
                                           seed = args.seed, 
376
                                           print_info = True,
377
                                           patient_strat= False,
378
                                           n_bins=4,
379
                                           label_col = 'survival_months',
380
                                           ignore=[])
381
elif args.task == 'tcga_ucec_survival':
382
  args.n_classes = 4
383
  proj = '_'.join(args.task.split('_')[:2])
384
  dataset = Generic_MIL_Survival_Dataset(csv_path = './%s/%s_all.csv'  % (dataset_path, proj),
385
                                           mode = args.mode,
386
                                           data_dir= os.path.join(args.data_root_dir, 'tcga_endometrial_20x_features'),
387
                                           shuffle = False, 
388
                                           seed = args.seed, 
389
                                           print_info = True,
390
                                           patient_strat= False,
391
                                           n_bins=4,
392
                                           label_col = 'survival_months',
393
                                           ignore=[])
394
else:
395
  raise NotImplementedError
396
397
if isinstance(dataset, Generic_MIL_Survival_Dataset):
398
    args.task_type ='survival'
399
else:
400
    raise NotImplementedError
401
402
if not os.path.isdir(args.results_dir):
403
    os.mkdir(args.results_dir)
404
405
### GET RID OF WHICH_SPLITS IF U WANT TO MAKE THE RESULTS FOLDER LESS CLUTTERRED
406
args.results_dir = os.path.join(args.results_dir, args.which_splits, param_code, str(args.exp_code) + '_s{}'.format(args.seed))
407
if not os.path.isdir(args.results_dir):
408
    os.makedirs(args.results_dir)
409
410
if ('summary.csv' in os.listdir(args.results_dir)) and (not args.overwrite):
411
  print("Exp Code <%s> already exists! Exiting script." % args.exp_code)
412
  sys.exit()
413
414
if args.split_dir is None:
415
    args.split_dir = os.path.join('./splits', args.task+'_{}'.format(int(args.label_frac*100)))
416
else:
417
    args.split_dir = os.path.join('./splits', args.which_splits, args.split_dir)
418
419
print("split_dir", args.split_dir)
420
421
assert os.path.isdir(args.split_dir)
422
423
settings.update({'split_dir': args.split_dir})
424
425
426
with open(args.results_dir + '/experiment_{}.txt'.format(args.exp_code), 'w') as f:
427
    print(settings, file=f)
428
f.close()
429
430
print("################# Settings ###################")
431
for key, val in settings.items():
432
    print("{}:  {}".format(key, val))        
433
434
if __name__ == "__main__":
435
436
    start = timer()
437
    results = main(args)
438
    end = timer()
439
    print("finished!")
440
    print("end script")
441
    print('Script Time: %f seconds' % (end - start))