Diff of /main_survival.py [000000] .. [0fdc30]

Switch to unified view

a b/main_survival.py
1
from __future__ import print_function
2
3
import argparse
4
import pdb
5
import os
6
import math
7
8
# internal imports
9
from utils.file_utils import save_pkl, load_pkl
10
from utils.utils import *
11
from utils.core_utils_survival import train
12
from datasets.dataset_survival import Generic_WSI_Survival_Dataset, Generic_MIL_Survival_Dataset
13
14
# pytorch imports
15
import torch
16
from torch.utils.data import DataLoader, sampler
17
import torch.nn as nn
18
import torch.nn.functional as F
19
20
import pandas as pd
21
import numpy as np
22
23
24
def main(args):
25
    # create results directory if necessary
26
    if not os.path.isdir(args.results_dir):
27
        os.mkdir(args.results_dir)
28
29
    if args.k_start == -1:
30
        start = 0
31
    else:
32
        start = args.k_start
33
    if args.k_end == -1:
34
        end = args.k
35
    else:
36
        end = args.k_end
37
38
    all_test_cindex = []
39
    all_val_cindex = []
40
    folds = np.arange(start, end)
41
    for i in folds:
42
        seed_torch(args.seed)
43
        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False, 
44
                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))
45
        
46
        datasets = (train_dataset, val_dataset, test_dataset)
47
        results, test_cindex, val_cindex  = train(datasets, i, args)
48
        all_test_cindex.append(test_cindex)
49
        all_val_cindex.append(val_cindex)
50
        #write results to pkl
51
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
52
        save_pkl(filename, results)
53
54
    final_df = pd.DataFrame({'folds': folds, 'test_cindex': all_test_cindex, 'val_cindex' : all_val_cindex})
55
56
    if len(folds) != args.k:
57
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
58
    else:
59
        save_name = 'summary.csv'
60
    final_df.to_csv(os.path.join(args.results_dir, save_name))
61
62
# Generic training settings
63
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
64
parser.add_argument('--data_root_dir', type=str, default=None, 
65
                    help='data directory')
66
parser.add_argument('--max_epochs', type=int, default=200,
67
                    help='maximum number of epochs to train (default: 200)')
68
parser.add_argument('--lr', type=float, default=1e-4,
69
                    help='learning rate (default: 0.0001)')
70
parser.add_argument('--label_frac', type=float, default=1.0,
71
                    help='fraction of training labels (default: 1.0)')
72
parser.add_argument('--reg', type=float, default=1e-5,
73
                    help='weight decay (default: 1e-5)')
74
parser.add_argument('--seed', type=int, default=1, 
75
                    help='random seed for reproducible experiment (default: 1)')
76
parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
77
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, last fold)')
78
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, first fold)')
79
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
80
parser.add_argument('--split_dir', type=str, default=None, 
81
                    help='manually specify the set of splits to use, ' 
82
                    +'instead of infering from the task and label_frac argument (default: None)')
83
parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
84
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
85
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
86
parser.add_argument('--opt', type=str, choices = ['adam', 'sgd'], default='adam')
87
parser.add_argument('--drop_out', action='store_true', default=False, help='enabel dropout (p=0.25)')
88
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce'], default='ce',
89
                     help='slide-level classification loss function (default: ce)')
90
parser.add_argument('--model_type', type=str, choices=['amil', 'mil'], default='amil', 
91
                    help='type of model (default: amil)')
92
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
93
parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
94
parser.add_argument('--model_size', type=str, choices=['small', 'big','tiny'], default='small', help='size of model, does not affect mil')
95
parser.add_argument('--task', type=str, choices=['task_3_survival_prediction'])
96
parser.add_argument('--csv_path', type=str, default=None, help='Path to csv dataset.')
97
parser.add_argument('--feature_dir', type=str, default=None, help='feature directory')
98
parser.add_argument('--n_iters', type=int, default=16, help='Number of iterations until cox loss is calculated')
99
args = parser.parse_args()
100
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
102
def seed_torch(seed=7):
103
    import random
104
    random.seed(seed)
105
    os.environ['PYTHONHASHSEED'] = str(seed)
106
    np.random.seed(seed)
107
    torch.manual_seed(seed)
108
    if device.type == 'cuda':
109
        torch.cuda.manual_seed(seed)
110
        torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
111
    torch.backends.cudnn.benchmark = False
112
    torch.backends.cudnn.deterministic = True
113
114
seed_torch(args.seed)
115
116
encoding_size = 1024
117
settings = {'num_splits': args.k, 
118
            'k_start': args.k_start,
119
            'k_end': args.k_end,
120
            'task': args.task,
121
            'max_epochs': args.max_epochs, 
122
            'results_dir': args.results_dir, 
123
            'lr': args.lr,
124
            'experiment': args.exp_code,
125
            'reg': args.reg,
126
            'label_frac': args.label_frac,
127
            'bag_loss': args.bag_loss,
128
            'seed': args.seed,
129
            'model_type': args.model_type,
130
            'model_size': args.model_size,
131
            "use_drop_out": args.drop_out,
132
            'weighted_sample': args.weighted_sample,
133
            'opt': args.opt}
134
135
print('\nLoad Dataset')
136
137
138
if args.task == 'task_3_survival_prediction':
139
140
    if args.csv_path == None:
141
        raise ValueError('Must provide a csv dataset file.')
142
    else:
143
        csv_path = args.csv_path
144
145
    if args.feature_dir is not None:
146
        feature_dir = args.feature_dir
147
    else:
148
        raise ValueError('Must provide feature directory.')
149
150
    dataset = Generic_MIL_Survival_Dataset(csv_path = csv_path,
151
                        data_dir= os.path.join(args.data_root_dir, feature_dir),
152
                        shuffle = False, 
153
                        seed = args.seed, 
154
                        print_info = True,
155
                        label_dict = {'lebt':0, 'tod':1},
156
                        event_col = 'event',
157
                        time_col = 'time',
158
                        patient_strat=True,
159
                        ignore=[])
160
else:
161
    raise NotImplementedError
162
    
163
if not os.path.isdir(args.results_dir):
164
    os.mkdir(args.results_dir)
165
166
args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
167
if not os.path.isdir(args.results_dir):
168
    os.mkdir(args.results_dir)
169
170
171
if args.split_dir is None:
172
    raise ValueError('Must provide split_dir folder name.')
173
else:
174
    args.split_dir = os.path.join('splits', '{}'.format(args.split_dir))
175
176
print('split_dir: ', args.split_dir)
177
assert os.path.isdir(args.split_dir)
178
179
settings.update({'split_dir': args.split_dir})
180
181
182
with open(args.results_dir + '/experiment_{}.txt'.format(args.exp_code), 'w') as f:
183
    print(settings, file=f)
184
f.close()
185
186
print("################# Settings ###################")
187
for key, val in settings.items():
188
    print("{}:  {}".format(key, val))        
189
190
if __name__ == "__main__":
191
    results = main(args)
192
    print("finished!")
193
    print("end script")
194
195