Diff of /main.py [000000] .. [4cd6c8]

from __future__ import print_function

import argparse
import pdb
import os
import math

# internal imports
from utils.file_utils import save_pkl, load_pkl
from utils.utils import *
from utils.core_utils import train
from utils.core_utils_mtl import train as train_mtl
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset
from datasets.dataset_mtl import Generic_WSI_MTL_Dataset, Generic_MIL_MTL_Dataset

# pytorch imports
import torch
from torch.utils.data import DataLoader, sampler
import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import numpy as np


# Rejection grade:
# binary classifier:
# class 0 - low grade
# class 1 - high grade
#-------------------------------
def main_grade(args):
    print("------------------------------------------")
    print(" Grade Net (single-task binary classifier)")
    print("------------------------------------------")

    # create results directory if necessary
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    # resolve the fold range to run: -1 means use the full [0, k) range
    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    all_test_auc = []
    all_val_auc = []
    all_test_acc = []
    all_val_acc = []
    folds = np.arange(start, end)
    for i in folds:
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))

        datasets = (train_dataset, val_dataset, test_dataset)
        results, test_auc, val_auc, test_acc, val_acc = train(datasets, i, args)
        all_test_auc.append(test_auc)
        all_val_auc.append(val_auc)
        all_test_acc.append(test_acc)
        all_val_acc.append(val_acc)
        # write results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    final_df = pd.DataFrame({'folds': folds, 'test_auc': all_test_auc,
        'val_auc': all_val_auc, 'test_acc': all_test_acc, 'val_acc': all_val_acc})

    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))

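# For illustration, the resulting summary.csv holds one row per fold,
# e.g. (values hypothetical):
#   folds, test_auc, val_auc, test_acc, val_acc
#   0,     0.91,     0.88,    0.85,     0.83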

# Multi-task classifier for EMB evaluation:
# consists of 3 simultaneous tasks:
# task 1: cellular vs non-cellular
# task 2: antibody vs non-antibody
# task 3: quilty lesion vs no quilty lesion
#-------------------------------------------
def main_mtl(args):

    print("----------------------------------------")
    print(" EMB assessment - multi-task classifier ")
    print("----------------------------------------")

    # create results directory if necessary
    if not os.path.isdir(args.results_dir):
        os.mkdir(args.results_dir)

    if args.k_start == -1:
        start = 0
    else:
        start = args.k_start
    if args.k_end == -1:
        end = args.k
    else:
        end = args.k_end

    # arrays to collect scores -- replace by a generic container when refactoring
    all_task1_test_auc = []
    all_task1_val_auc  = []
    all_task1_test_acc = []
    all_task1_val_acc  = []

    all_task2_test_auc = []
    all_task2_val_auc  = []
    all_task2_test_acc = []
    all_task2_val_acc  = []

    all_task3_test_auc = []
    all_task3_val_auc  = []
    all_task3_test_acc = []
    all_task3_val_acc  = []

    folds = np.arange(start, end)
    for i in folds:
        seed_torch(args.seed)
        train_dataset, val_dataset, test_dataset = dataset.return_splits(from_id=False,
                csv_path='{}/splits_{}.csv'.format(args.split_dir, i))

        print('training: {}, validation: {}, testing: {}'.format(len(train_dataset), len(val_dataset), len(test_dataset)))
        datasets = (train_dataset, val_dataset, test_dataset)

        # train_mtl returns the per-split results dict plus
        # (test_auc, val_auc, test_acc, val_acc) for each of the three tasks
        results, \
        task1_test_auc, task1_val_auc, task1_test_acc, task1_val_acc, \
        task2_test_auc, task2_val_auc, task2_test_acc, task2_val_acc, \
        task3_test_auc, task3_val_auc, task3_test_acc, task3_val_acc = train_mtl(datasets, i, args)

        all_task1_test_auc.append(task1_test_auc)
        all_task1_val_auc.append(task1_val_auc)
        all_task1_test_acc.append(task1_test_acc)
        all_task1_val_acc.append(task1_val_acc)

        all_task2_test_auc.append(task2_test_auc)
        all_task2_val_auc.append(task2_val_auc)
        all_task2_test_acc.append(task2_test_acc)
        all_task2_val_acc.append(task2_val_acc)

        all_task3_test_auc.append(task3_test_auc)
        all_task3_val_auc.append(task3_val_auc)
        all_task3_test_acc.append(task3_test_acc)
        all_task3_val_acc.append(task3_val_acc)

        # write results to pkl
        filename = os.path.join(args.results_dir, 'split_{}_results.pkl'.format(i))
        save_pkl(filename, results)

    final_df = pd.DataFrame({'folds': folds,
        'task1_test_auc': all_task1_test_auc, 'task1_val_auc': all_task1_val_auc,
        'task1_test_acc': all_task1_test_acc, 'task1_val_acc': all_task1_val_acc,
        'task2_test_auc': all_task2_test_auc, 'task2_val_auc': all_task2_val_auc,
        'task2_test_acc': all_task2_test_acc, 'task2_val_acc': all_task2_val_acc,
        'task3_test_auc': all_task3_test_auc, 'task3_val_auc': all_task3_val_auc,
        'task3_test_acc': all_task3_test_acc, 'task3_val_acc': all_task3_val_acc})

    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(start, end)
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.results_dir, save_name))


# Training settings
parser = argparse.ArgumentParser(description='Configurations for WSI Training')
parser.add_argument('--data_root_dir', type=str, default='/media/fedshyvana/ssd1',
                    help='data directory')
parser.add_argument('--max_epochs', type=int, default=200,
                    help='maximum number of epochs to train (default: 200)')
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate (default: 0.0001)')
parser.add_argument('--label_frac', type=float, default=1.0,
                    help='fraction of training labels (default: 1.0)')
parser.add_argument('--bag_weight', type=float, default=0.7,
                    help='clam: weight coefficient for bag-level loss (default: 0.7)')
parser.add_argument('--reg', type=float, default=1e-5,
                    help='weight decay (default: 1e-5)')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed for reproducible experiment (default: 1)')
parser.add_argument('--k', type=int, default=10, help='number of folds (default: 10)')
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, i.e. the first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, i.e. the last fold)')
parser.add_argument('--results_dir', default='./results', help='results directory (default: ./results)')
parser.add_argument('--split_dir', type=str, default=None,
                    help='manually specify the set of splits to use, '
                    +'instead of inferring from the task and label_frac arguments (default: None)')
parser.add_argument('--log_data', action='store_true', default=False, help='log data using tensorboard')
parser.add_argument('--testing', action='store_true', default=False, help='debugging tool')
parser.add_argument('--subtyping', action='store_true', default=False, help='subtyping problem')
parser.add_argument('--early_stopping', action='store_true', default=False, help='enable early stopping')
parser.add_argument('--opt', type=str, choices=['adam', 'sgd'], default='adam')
parser.add_argument('--drop_out', action='store_true', default=False, help='enable dropout (p=0.25)')
parser.add_argument('--inst_loss', type=str, choices=['svm', 'ce', None], default=None,
                     help='instance-level clustering loss function (default: None)')
parser.add_argument('--bag_loss', type=str, choices=['svm', 'ce'], default='ce',
                     help='slide-level classification loss function (default: ce)')
parser.add_argument('--model_type', type=str, choices=['clam', 'mil', 'clam_simple', 'attention_mil', 'histogram_mil'],
                    default='attention_mil', help='type of model (default: attention_mil)')
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
parser.add_argument('--weighted_sample', action='store_true', default=False, help='enable weighted sampling')
parser.add_argument('--model_size', type=str, choices=['small', 'big'], default='big', help='size of model')
parser.add_argument('--mtl', action='store_true', default=False, help='flag to enable the multi-task problem')
parser.add_argument('--task', type=str, choices=['cardiac-grade', 'cardiac-mtl'])
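
# Example invocations (illustrative only; the experiment codes and data root
# below are placeholder values, not paths shipped with this repo):
#   python main.py --task cardiac-grade --exp_code grade_test \
#       --data_root_dir /path/to/data_root --weighted_sample --early_stopping --drop_out
#   python main.py --task cardiac-mtl --mtl --exp_code emb_mtl_test \
#       --data_root_dir /path/to/data_root --weighted_sample --early_stopping --drop_out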

args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seed_torch(seed=7):
    # seed every source of randomness for reproducible runs
    import random
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(args.seed)

encoding_size = 1024
settings = {'num_splits': args.k,
            'k_start': args.k_start,
            'k_end': args.k_end,
            'task': args.task,
            'max_epochs': args.max_epochs,
            'results_dir': args.results_dir,
            'lr': args.lr,
            'experiment': args.exp_code,
            'reg': args.reg,
            'label_frac': args.label_frac,
            'inst_loss': args.inst_loss,
            'bag_loss': args.bag_loss,
            'bag_weight': args.bag_weight,
            'seed': args.seed,
            'model_type': args.model_type,
            'model_size': args.model_size,
            'use_drop_out': args.drop_out,
            'weighted_sample': args.weighted_sample,
            'opt': args.opt}


print('\nLoad Dataset')
if args.task == 'cardiac-grade':
    args.n_classes = 2
    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/CardiacDummy_Grade.csv',
                            data_dir=os.path.join(args.data_root_dir, 'features'),
                            shuffle=False,
                            seed=args.seed,
                            print_info=True,
                            label_dict={'low': 0, 'high': 1},
                            label_cols=['label_grade'],
                            patient_strat=False,
                            ignore=[])

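# Sketch of the expected grade CSV: each slide is mapped to a label drawn from
# the keys of label_dict (an assumption from the constructor arguments above;
# column names besides label_grade are illustrative):
#   slide_id, label_grade
#   slide_0,  low
#   slide_1,  high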

elif args.task == 'cardiac-mtl':
    args.n_classes = [2, 2, 2]
    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/CardiacDummy_MTL.csv',
                            data_dir=os.path.join(args.data_root_dir, 'features'),
                            shuffle=False,
                            seed=args.seed,
                            print_info=True,
                            label_dicts=[{'no_cell': 0, 'cell': 1},
                                         {'no_amr': 0, 'amr': 1},
                                         {'no_quilty': 0, 'quilty': 1}],
                            label_cols=['label_cell', 'label_amr', 'label_quilty'],
                            patient_strat=False,
                            ignore=[])

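# The MTL CSV is assumed to carry one label column per task instead, named after
# label_cols above, e.g.:
#   slide_id, label_cell, label_amr, label_quilty
#   slide_0,  cell,       no_amr,    quilty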

else:
    raise NotImplementedError

if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

args.results_dir = os.path.join(args.results_dir, str(args.exp_code) + '_s{}'.format(args.seed))
if not os.path.isdir(args.results_dir):
    os.mkdir(args.results_dir)

if args.split_dir is None:
    args.split_dir = os.path.join('splits', args.task + '_{}'.format(int(args.label_frac * 100)))
else:
    args.split_dir = os.path.join('splits', args.split_dir)
assert os.path.isdir(args.split_dir)

settings.update({'split_dir': args.split_dir})


with open(os.path.join(args.results_dir, 'experiment_{}.txt'.format(args.exp_code)), 'w') as f:
    print(settings, file=f)

print("################# Settings ###################")
for key, val in settings.items():
    print("{}:  {}".format(key, val))

if __name__ == "__main__":
    # NB: --mtl should be paired with --task cardiac-mtl, since the dataset
    # built above is selected from --task alone
    if args.mtl:
        results = main_mtl(args)
    else:
        results = main_grade(args)

    print("finished!")
    print("end script")