### main.py

from __future__ import print_function

import torch_geometric

import argparse
import pdb
import os
import math
import sys
from timeit import default_timer as timer

import numpy as np
import pandas as pd

### Internal Imports
from datasets.dataset_survival import Generic_WSI_Survival_Dataset, Generic_MIL_Survival_Dataset
from utils.file_utils import save_pkl, load_pkl
from utils.core_utils import train
from utils.utils import get_custom_exp_code

### PyTorch Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, sampler


def main(args):
    ### Create Results Directory
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    latest_val_cindex = []
    completed_folds = []  # folds actually trained this run; skipped folds are excluded
    folds = np.arange(start, end)

    ### Start 5-Fold CV Evaluation.
    for i in folds:
        fold_start = timer()  # renamed from 'start' so the fold range above is not shadowed
        seed_torch(args.seed)
        results_pkl_path = os.path.join(args.results_dir, 'split_latest_val_{}_results.pkl'.format(i))
        if os.path.isfile(results_pkl_path) and (not args.overwrite):
            print("Skipping Split %d" % i)
            continue

        ### Gets the Train + Val Dataset Loader.
        train_dataset, val_dataset = dataset.return_splits(from_id=False,
                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
        train_dataset.set_split_id(split_id=i)
        val_dataset.set_split_id(split_id=i)

        #pdb.set_trace()
        print('training: {}, validation: {}'.format(len(train_dataset), len(val_dataset)))
        datasets = (train_dataset, val_dataset)

        ### Specify the input dimension size if using genomic features.
        if 'omic' in args.mode or args.mode == 'cluster' or args.mode == 'graph' or args.mode == 'pyramid':
            args.omic_input_dim = train_dataset.genomic_features.shape[1]
            print("Genomic Dimension", args.omic_input_dim)
        elif 'coattn' in args.mode:
            args.omic_sizes = train_dataset.omic_sizes
            print('Genomic Dimensions', args.omic_sizes)
        else:
            args.omic_input_dim = 0

        ### Run Train-Val on Survival Task.
        if args.task_type == 'survival':
            val_latest, cindex_latest = train(datasets, i, args)
            latest_val_cindex.append(cindex_latest)
            completed_folds.append(i)

            ### Write Results for Each Split to PKL
            save_pkl(results_pkl_path, val_latest)

        fold_end = timer()
        print('Fold %d Time: %f seconds' % (i, fold_end - fold_start))

    ### Finish 5-Fold CV Evaluation.
    if args.task_type == 'survival':
        results_latest_df = pd.DataFrame({'folds': completed_folds, 'val_cindex': latest_val_cindex})

    ### Name the summary by fold range so a partial run does not pose as a full one.
    if len(folds) != args.k:
        save_name = 'summary_latest_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary_latest.csv'

    results_latest_df.to_csv(os.path.join(args.results_dir, save_name))

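# A quick way to inspect the per-fold validation c-indices afterwards (illustrative snippet):
#   df = pd.read_csv('<results_dir>/summary_latest.csv', index_col=0)
#   print(df['val_cindex'].mean())
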
### Training settings
parser = argparse.ArgumentParser(description='Configurations for Survival Analysis on TCGA Data.')
### Checkpoint + Misc. Pathing Parameters
parser.add_argument('--data_root_dir',   type=str, default='path/to/data_root_dir', help='Data directory to WSI features (extracted via CLAM)')
# --dataset_path is referenced below when locating the <study>_all_clean.csv.zip label files;
# it was missing from the argument list, so it is added here with an assumed default of 'dataset_csv'.
parser.add_argument('--dataset_path',    type=str, default='dataset_csv', help='Directory containing the <study>_all_clean.csv.zip label files')
parser.add_argument('--seed',            type=int, default=1, help='Random seed for reproducible experiment (default: 1)')
parser.add_argument('--k',               type=int, default=5, help='Number of folds (default: 5)')
parser.add_argument('--k_start',         type=int, default=-1, help='Start fold (default: -1, first fold)')
parser.add_argument('--k_end',           type=int, default=-1, help='End fold (default: -1, last fold)')
parser.add_argument('--results_dir',     type=str, default='./results_new', help='Results directory (default: ./results_new)')
parser.add_argument('--which_splits',    type=str, default='5foldcv', help='Which splits folder to use in ./splits/ (default: ./splits/5foldcv)')
parser.add_argument('--split_dir',       type=str, default='tcga_blca', help='Which cancer type within ./splits/<which_splits> to use for training. Used synonymously with "task" (default: tcga_blca)')
parser.add_argument('--log_data',        action='store_true', default=True, help='Log data using tensorboard')
parser.add_argument('--overwrite',       action='store_true', default=False, help='Whether or not to overwrite experiments (if already run)')

### Model Parameters.
parser.add_argument('--model_type',      type=str, default='mcat', help='Type of model (default: mcat)')
parser.add_argument('--mode',            type=str, choices=['omic', 'path', 'pathomic', 'pathomic_fast', 'cluster', 'coattn'], default='coattn', help='Specifies which modalities to use / collate function in dataloader.')
parser.add_argument('--fusion',          type=str, choices=['None', 'concat', 'bilinear'], default='None', help='Type of fusion (default: None).')
parser.add_argument('--apply_sig',       action='store_true', default=False, help='Use genomic features as signature embeddings.')
parser.add_argument('--apply_sigfeats',  action='store_true', default=False, help='Use genomic features as tabular features.')
parser.add_argument('--drop_out',        action='store_true', default=True, help='Enable dropout (p=0.25)')
parser.add_argument('--model_size_wsi',  type=str, default='small', help='Network size of AMIL model')
parser.add_argument('--model_size_omic', type=str, default='small', help='Network size of SNN model')

parser.add_argument('--n_classes', type=int, default=4)

### PORPOISE
parser.add_argument('--apply_mutsig', action='store_true', default=False)
parser.add_argument('--gate_path', action='store_true', default=False)
parser.add_argument('--gate_omic', action='store_true', default=False)
parser.add_argument('--scale_dim1', type=int, default=8)
parser.add_argument('--scale_dim2', type=int, default=8)
parser.add_argument('--skip', action='store_true', default=False)
parser.add_argument('--dropinput', type=float, default=0.0)
parser.add_argument('--path_input_dim', type=int, default=1024)
parser.add_argument('--use_mlp', action='store_true', default=False)

### Optimizer Parameters + Survival Loss Function
parser.add_argument('--opt',             type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--batch_size',      type=int, default=1, help='Batch size (default: 1, due to varying bag sizes)')
parser.add_argument('--gc',              type=int, default=32, help='Gradient accumulation step.')
parser.add_argument('--max_epochs',      type=int, default=20, help='Maximum number of epochs to train (default: 20)')
parser.add_argument('--lr',              type=float, default=2e-4, help='Learning rate (default: 2e-4)')
parser.add_argument('--bag_loss',        type=str, choices=['svm', 'ce', 'ce_surv', 'nll_surv'], default='nll_surv', help='Slide-level classification loss function (default: nll_surv)')
parser.add_argument('--label_frac',      type=float, default=1.0, help='Fraction of training labels (default: 1.0)')
parser.add_argument('--reg',             type=float, default=1e-5, help='L2-regularization weight decay (default: 1e-5)')
parser.add_argument('--alpha_surv',      type=float, default=0.0, help='How much to weigh uncensored patients')
parser.add_argument('--reg_type',        type=str, choices=['None', 'omic', 'pathomic'], default='None', help='Which network submodules to apply L1-regularization to (default: None)')
parser.add_argument('--lambda_reg',      type=float, default=1e-5, help='L1-regularization strength (default: 1e-5)')
parser.add_argument('--weighted_sample', action='store_true', default=True, help='Enable weighted sampling')
parser.add_argument('--early_stopping',  action='store_true', default=False, help='Enable early stopping')

### CLAM-Specific Parameters
parser.add_argument('--bag_weight',      type=float, default=0.7, help='clam: weight coefficient for bag-level loss (default: 0.7)')
parser.add_argument('--testing',         action='store_true', default=False, help='Debugging tool')

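# Note on --gc: with --batch_size 1 (one WSI bag per iteration), --gc 32 accumulates
# gradients over 32 bags before each optimizer step, i.e. an effective batch size of 32.
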
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Creates Experiment Code from argparse + Folder Name to Save Results
args = get_custom_exp_code(args)
args.task = '_'.join(args.split_dir.split('_')[:2]) + '_survival'
print("Experiment Name:", args.exp_code)

### Sets seed for reproducible experiments.
def seed_torch(seed=7):
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(args.seed)

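# Note: cudnn.benchmark=False and cudnn.deterministic=True above make cuDNN ops repeatable;
# some other CUDA kernels can remain nondeterministic unless torch.use_deterministic_algorithms(True)
# is also set (optional, and may raise on unsupported ops).
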
encoding_size = 1024
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'label_frac': args.label_frac,
            'bag_loss': args.bag_loss,
            #'bag_weight': args.bag_weight,
            'seed': args.seed,
            'model_type': args.model_type,
            'model_size_wsi': args.model_size_wsi,
            'model_size_omic': args.model_size_omic,
            'use_drop_out': args.drop_out,
            'weighted_sample': args.weighted_sample,
            'gc': args.gc,
            'opt': args.opt}
print('\nLoad Dataset')

if 'survival' in args.task:
    study = '_'.join(args.task.split('_')[:2])
    if study == 'tcga_kirc' or study == 'tcga_kirp':
        combined_study = 'tcga_kidney'
    elif study == 'tcga_luad' or study == 'tcga_lusc':
        combined_study = 'tcga_lung'
    else:
        combined_study = study

    study_dir = '%s_20x_features' % combined_study

    dataset = Generic_MIL_Survival_Dataset(csv_path='./%s/%s_all_clean.csv.zip' % (args.dataset_path, study),
                                           mode=args.mode,
                                           apply_sig=args.apply_sig,
                                           data_dir=os.path.join(args.data_root_dir, study_dir),
                                           shuffle=False,
                                           seed=args.seed,
                                           print_info=True,
                                           patient_strat=False,
                                           n_bins=4,
                                           label_col='survival_months',
                                           ignore=[])
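
# Illustrative resolved paths for --split_dir tcga_blca with the assumed default --dataset_path:
#   csv_path: ./dataset_csv/tcga_blca_all_clean.csv.zip
#   data_dir: <data_root_dir>/tcga_blca_20x_features
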
if isinstance(dataset, Generic_MIL_Survival_Dataset):
    args.task_type = 'survival'
else:
    raise NotImplementedError

### Creates results_dir Directory.
if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

### Appends to the results_dir path: 1) which splits were used for training (e.g., 5foldcv), 2) the parameter code, and 3) the experiment code
args.results_dir = os.path.join(args.results_dir, args.which_splits, args.param_code, str(args.exp_code) + '_s{}'.format(args.seed))
if not os.path.isdir(args.results_dir):
    os.makedirs(args.results_dir)

if ('summary_latest.csv' in os.listdir(args.results_dir)) and (not args.overwrite):
    print("Exp Code <%s> already exists! Exiting script." % args.exp_code)
    sys.exit()

### Sets the path of split_dir
args.split_dir = os.path.join('./splits', args.which_splits, args.split_dir)
print("split_dir", args.split_dir)
assert os.path.isdir(args.split_dir)
settings.update({'split_dir': args.split_dir})

with open(os.path.join(args.results_dir, 'experiment_{}.txt'.format(args.exp_code)), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}:  {}".format(key, val))

if __name__ == "__main__":
    start = timer()
    results = main(args)
    end = timer()
    print("finished!")
    print("end script")
    print('Script Time: %f seconds' % (end - start))
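
# Example invocation (paths are illustrative):
#   python main.py --data_root_dir /path/to/wsi_features \
#       --which_splits 5foldcv --split_dir tcga_blca --mode coattn --model_type mcat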