Diff of /eval.py [000000] .. [4cd6c8]

from __future__ import print_function

import numpy as np

import argparse
import torch
import torch.nn as nn
import pdb
import os
import pandas as pd
from utils.utils import *
from math import floor
import matplotlib.pyplot as plt
from datasets.dataset_generic import Generic_WSI_Classification_Dataset, Generic_MIL_Dataset  #, save_splits
from datasets.dataset_mtl import Generic_MIL_MTL_Dataset  #, save_splits
import h5py
from utils.eval_utils import eval
from utils.eval_utils_mtl import eval as eval_mtl

# Training settings
parser = argparse.ArgumentParser(description='CLAM Evaluation Script')
parser.add_argument('--data_root_dir', type=str, default='/media/fedshyvana/ssd1',
                    help='data directory')
parser.add_argument('--results_dir', type=str, default='./results',
                    help='relative path to results folder, i.e. ' +
                    'the directory containing models_exp_code relative to the project root (default: ./results)')
parser.add_argument('--save_exp_code', type=str, default=None,
                    help='experiment code to save eval results')
parser.add_argument('--models_exp_code', type=str, default=None,
                    help='experiment code to load trained models (directory under results_dir containing model checkpoints)')
parser.add_argument('--splits_dir', type=str, default=None,
                    help='splits directory, if using custom splits other than what matches the task (default: None)')
parser.add_argument('--model_size', type=str, choices=['small', 'big'], default='big',
                    help='size of model (default: big)')
parser.add_argument('--model_type', type=str, choices=['clam', 'mil', 'attention_mil', 'clam_simple', 'histogram_mil'], default='attention_mil',
                    help='type of model (default: attention_mil)')
parser.add_argument('--drop_out', action='store_true', default=False,
                    help='whether model uses dropout')
parser.add_argument('--calc_features', action='store_true', default=False,
                    help='calculate features for pca/tsne')
parser.add_argument('--k', type=int, default=1, help='number of folds (default: 1)')
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, last fold)')
parser.add_argument('--fold', type=int, default=-1, help='single fold to evaluate')
parser.add_argument('--micro_average', action='store_true', default=False,
                    help='use micro_average instead of macro_average for multiclass AUC')
parser.add_argument('--mtl', action='store_true', default=False, help='flag to enable multi-task problem')
parser.add_argument('--patient_level', action='store_true', default=False,
                    help="enable computing scores at the patient level, i.e. all of a patient's slides are treated as a single bag with a single label")
parser.add_argument('--split', type=str, choices=['train', 'val', 'test', 'all'], default='test')
parser.add_argument('--task', type=str, choices=['cardiac-grade', 'cardiac-mtl'])

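# Example invocation (a sketch; the paths and experiment codes below are placeholders, not real data):
#   python eval.py --data_root_dir /path/to/data_root --results_dir ./results \
#       --models_exp_code my_experiment_s1 --save_exp_code my_experiment_eval \
#       --task cardiac-grade --model_type attention_mil --k 10 --split test
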
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoding_size = 1024

args.save_dir = os.path.join('./eval_results', 'EVAL_' + str(args.save_exp_code))
args.models_dir = os.path.join(args.results_dir, str(args.models_exp_code))

os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(os.path.join(args.save_dir, 'attention_scores'), exist_ok=True)

if args.splits_dir is None:
    args.splits_dir = args.models_dir

assert os.path.isdir(args.models_dir)
assert os.path.isdir(args.splits_dir)

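# Record the evaluation configuration alongside the results for reproducibility.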
settings = {'task': args.task,
            'split': args.split,
            'save_dir': args.save_dir,
            'models_dir': args.models_dir,
            'model_type': args.model_type,
            'drop_out': args.drop_out,
            'model_size': args.model_size,
            'micro_average': args.micro_average}

with open(os.path.join(args.save_dir, 'eval_experiment_{}.txt'.format(args.save_exp_code)), 'w') as f:
    print(settings, file=f)

print(settings)

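# Build the dataset for the requested task. Each branch points at a label CSV and a
# feature directory under data_root_dir (the dataset classes are expected to load
# pre-extracted slide features from there).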
if args.task == 'cardiac-grade':
    args.n_classes = 2
    dataset = Generic_MIL_Dataset(csv_path='dataset_csv/CardiacDummy_Grade.csv',
                                  data_dir=os.path.join(args.data_root_dir, 'features'),
                                  shuffle=False,
                                  print_info=True,
                                  label_dict={'low': 0, 'high': 1},
                                  patient_strat=False,
                                  ignore=[],
                                  patient_level=args.patient_level)

elif args.task == 'cardiac-mtl':
    args.n_classes = [2, 2, 2]
    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/CardiacDummy_MTL.csv',
                                      data_dir=os.path.join(args.data_root_dir, 'features'),
                                      shuffle=False,
                                      print_info=True,
                                      label_dicts=[{'no_cell': 0, 'cell': 1},
                                                   {'no_amr': 0, 'amr': 1},
                                                   {'no_quilty': 0, 'quilty': 1}],
                                      label_cols=['label_cell', 'label_amr', 'label_quilty'],
                                      patient_strat=False,
                                      ignore=[],
                                      patient_level=args.patient_level)

elif os.path.isdir(args.task):
    print('reading directory for fast inference')

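# Resolve which cross-validation folds to evaluate: --fold selects a single fold,
# otherwise the [--k_start, --k_end) range (defaulting to all --k folds) is used.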
if args.k_start == -1:
    start = 0
else:
    start = args.k_start
if args.k_end == -1:
    end = args.k
else:
    end = args.k_end

if args.fold == -1:
    folds = range(start, end)
else:
    folds = range(args.fold, args.fold + 1)

# checkpoints are expected to be named s_{fold}_checkpoint.pt inside models_dir
ckpt_paths = [os.path.join(args.models_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]
# index of each split returned by dataset.return_splits(); -1 means evaluate on the full dataset
datasets_id = {'train': 0, 'val': 1, 'test': 2, 'all': -1}


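# Single-task evaluation: runs eval() from utils.eval_utils on each fold's checkpoint
# and aggregates per-fold AUC/accuracy into a summary CSV.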
def main(args):
    all_auc = []
    all_acc = []
    all_aucs = []
    for ckpt_idx in range(len(ckpt_paths)):
        if datasets_id[args.split] < 0:
            split_dataset = dataset
        else:
            csv_path = '{}/splits_{}.csv'.format(args.splits_dir, folds[ckpt_idx])
            datasets = dataset.return_splits(from_id=False, csv_path=csv_path)
            split_dataset = datasets[datasets_id[args.split]]

        model, patient_results, test_error, auc, aucs, df = eval(split_dataset, args, ckpt_paths[ckpt_idx])
        all_auc.append(auc)
        all_acc.append(1 - test_error)
        if len(aucs) > 0:
            all_aucs.append(aucs)
        df.to_csv(os.path.join(args.save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False)

        if args.calc_features:
            compute_features(split_dataset, args, ckpt_paths[ckpt_idx], args.save_dir, model=model)

    df_dict = {'folds': folds, 'test_auc': all_auc, 'test_acc': all_acc}

    if args.n_classes > 2:
        all_aucs = np.vstack(all_aucs)
        for i in range(args.n_classes):
            df_dict.update({'class_{}_ovr_auc'.format(i): all_aucs[:, i]})

    final_df = pd.DataFrame(df_dict)
    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(folds[0], folds[-1])
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.save_dir, save_name))


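# Multi-task evaluation (--mtl): runs eval_mtl() from utils.eval_utils_mtl on each fold's
# checkpoint and collects AUC/accuracy for the three binary tasks (cell, amr, quilty).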
def main_mtl(args):
    all_task1_auc = []
    all_task1_acc = []
    all_task2_auc = []
    all_task2_acc = []
    all_task3_auc = []
    all_task3_acc = []

    for ckpt_idx in range(len(ckpt_paths)):
        if datasets_id[args.split] < 0:
            split_dataset = dataset
        else:
            csv_path = '{}/splits_{}.csv'.format(args.splits_dir, folds[ckpt_idx])
            datasets = dataset.return_splits(from_id=False, csv_path=csv_path)
            split_dataset = datasets[datasets_id[args.split]]

        model, results_dict = eval_mtl(split_dataset, args, ckpt_paths[ckpt_idx])

        all_task1_auc.append(results_dict['auc_task1'])
        all_task1_acc.append(1 - results_dict['test_error_task1'])
        all_task2_auc.append(results_dict['auc_task2'])
        all_task2_acc.append(1 - results_dict['test_error_task2'])
        all_task3_auc.append(results_dict['auc_task3'])
        all_task3_acc.append(1 - results_dict['test_error_task3'])

        df = results_dict['df']
        df.to_csv(os.path.join(args.save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False)

        if args.calc_features:
            compute_features(split_dataset, args, ckpt_paths[ckpt_idx], args.save_dir, model=model)

    df_dict = {'folds': folds,
               'task1_test_auc': all_task1_auc, 'task1_test_acc': all_task1_acc,
               'task2_test_auc': all_task2_auc, 'task2_test_acc': all_task2_acc,
               'task3_test_auc': all_task3_auc, 'task3_test_acc': all_task3_acc}

    final_df = pd.DataFrame(df_dict)
    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(folds[0], folds[-1])
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.save_dir, save_name))


if __name__ == "__main__":
    if args.mtl:
        main_mtl(args)
    else:
        main(args)

    print("finished!")
    print("end script")