# main_mtl_concat.py
from __future__ import print_function

import argparse
import math
import os
import pdb

import numpy as np
import pandas as pd

# pytorch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, sampler

# internal imports
from datasets.dataset_mtl_concat import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset
from utils.core_utils_mtl_concat import train
from utils.file_utils import save_pkl, load_pkl
from utils.utils import *
def main(args):
    """Run k-fold cross-validation training and aggregate per-fold metrics.

    Relies on the module-level ``dataset`` object for split loading and on
    ``train`` (core_utils_mtl_concat) for the per-fold optimization. Writes
    one ``split_{i}_results.pkl`` per fold plus a summary CSV into
    ``args.results_dir``.

    Args:
        args: parsed argparse namespace (see the module-level parser).
    """
    # create results directory if necessary; makedirs(exist_ok=True) also
    # creates missing parents and is safe on reruns (old os.mkdir was not)
    os.makedirs(args.results_dir, exist_ok=True)

    # -1 means "use the natural boundary": fold 0 / fold args.k
    start = 0 if args.k_start == -1 else args.k_start
    end = args.k if args.k_end == -1 else args.k_end

    all_cls_test_auc = []
    all_cls_val_auc = []
    all_cls_test_acc = []
    all_cls_val_acc = []

    all_site_test_auc = []
    all_site_val_auc = []
    all_site_test_acc = []
    all_site_val_acc = []
    folds = np.arange(start, end)
    for i in folds:
        # re-seed per fold so each fold is reproducible in isolation
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(
            from_id=False,
            csv_path='{}/splits_{}.csv'.format(args.split_dir, i))

        print('training: {}, validation: {}, testing: {}'.format(
            len(train_dataset), len(val_dataset), len(test_dataset)))
        datasets = (train_dataset, val_dataset, test_dataset)
        (results, cls_test_auc, cls_val_auc, cls_test_acc, cls_val_acc,
         site_test_auc, site_val_auc, site_test_acc, site_val_acc) = train(datasets, i, args)
        all_cls_test_auc.append(cls_test_auc)
        all_cls_val_auc.append(cls_val_auc)
        all_cls_test_acc.append(cls_test_acc)
        all_cls_val_acc.append(cls_val_acc)

        all_site_test_auc.append(site_test_auc)
        all_site_val_auc.append(site_val_auc)
        all_site_test_acc.append(site_test_acc)
        all_site_val_acc.append(site_val_acc)
        # write per-fold results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    final_df = pd.DataFrame({'folds': folds, 'cls_test_auc': all_cls_test_auc,
        'cls_val_auc': all_cls_val_auc, 'cls_test_acc': all_cls_test_acc, 'cls_val_acc': all_cls_val_acc,
        'site_test_auc': all_site_test_auc,
        'site_val_auc': all_site_val_auc, 'site_test_acc': all_site_test_acc, 'site_val_acc': all_site_val_acc})

    # mark partial runs so they don't silently overwrite a full summary
    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))
# Training settings
# NOTE(review): parse_args() runs at import time; kept for backward
# compatibility with code that imports `args`/`device` from this module.
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
parser.add_argument('--data_root_dir', type=str, help='data directory')
parser.add_argument('--max_epochs', type=int, default=200,
                    help='maximum number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate (default: 0.0001)')
parser.add_argument('--reg', type=float, default=1e-5,
                    help='weight decay (default: 1e-5)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
# help texts fixed: -1 for k_start means start at fold 0 (the FIRST fold),
# -1 for k_end means run through fold k (the LAST fold) — the old texts were swapped
parser.add_argument('--k_start', type=int, default=-1,
                    help='start fold (default: -1, start from the first fold)')
parser.add_argument('--k_end', type=int, default=-1,
                    help='end fold (default: -1, run through the last fold)')
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use, '
                    + 'instead of inferring from the task and label_frac argument (default: None)')
parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--drop_out', action='store_true', default=False, help='enable dropout (p=0.25)')
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
parser.add_argument('--task', type=str, choices=['dummy_mtl_concat'])
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def seed_torch(seed: int = 7) -> None:
    """Seed every RNG used in this pipeline for reproducibility.

    Covers Python's ``random``, the hash seed, NumPy, and PyTorch (CPU and
    all CUDA devices when available), and forces deterministic cuDNN.

    Args:
        seed: value applied to every generator (default: 7).
    """
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # query CUDA availability directly rather than the module-level `device`
    # global (equivalent, since `device` is derived from is_available()) so
    # this helper works even before/without that global being defined
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_torch(args.seed)

encoding_size = 1024
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'seed': args.seed,
            'use_drop_out': args.drop_out,
            'weighted_sample': args.weighted_sample,
            'opt': args.opt}

print('\nLoad Dataset')

if args.task == 'dummy_mtl_concat':
    args.n_classes = 18
    # three parallel labels per slide: tumor type, primary/metastatic site, sex
    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/dummy_dataset.csv',
                            data_dir=os.path.join(args.data_root_dir, 'DUMMY_DATA_DIR'),
                            shuffle=False,
                            seed=args.seed,
                            print_info=True,
                            label_dicts=[{'Lung':0, 'Breast':1, 'Colorectal':2, 'Ovarian':3,
                                          'Pancreatobiliary':4, 'Adrenal':5,
                                          'Skin':6, 'Prostate':7, 'Renal':8, 'Bladder':9,
                                          'Esophagogastric':10,  'Thyroid':11,
                                          'Head Neck':12,  'Glioma':13,
                                          'Germ Cell':14, 'Endometrial': 15,
                                          'Cervix': 16, 'Liver': 17},
                                         {'Primary':0,  'Metastatic':1},
                                         {'F':0, 'M':1}],
                            label_cols=['label', 'site', 'sex'],
                            patient_strat=False)
else:
    raise NotImplementedError

# makedirs(exist_ok=True) replaces the isdir+mkdir two-step: it also creates
# missing parents and cannot race between the check and the mkdir
os.makedirs(args.results_dir, exist_ok=True)

# results go into a per-experiment, per-seed subdirectory
args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
os.makedirs(args.results_dir, exist_ok=True)

if args.split_dir is None:
    args.split_dir = os.path.join('splits', args.task + '_{}'.format(int(100)))
else:
    args.split_dir = os.path.join('splits', args.split_dir)
assert os.path.isdir(args.split_dir)

settings.update({'split_dir': args.split_dir})

# dump the run configuration next to the results; `with` closes the file,
# so the old trailing f.close() was redundant and has been dropped
with open(os.path.join(args.results_dir, 'experiment_{}.txt'.format(args.exp_code)), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}:  {}".format(key, val))

if __name__ == "__main__":
    results = main(args)
    print("finished!")
    print("end script")