a b/utils/eval_utils_mtl.py
1
import numpy as np
2
3
import torch
4
import torch.nn as nn
5
import torch.nn.functional as F
6
from models.model_attention_mil import MIL_Attention_fc_mtl
7
import pdb
8
import os
9
import pandas as pd
10
from utils.utils import *
11
from utils.core_utils import EarlyStopping,  Accuracy_Logger
12
from utils.file_utils import save_pkl, load_pkl
13
from sklearn.metrics import roc_auc_score, roc_curve, auc
14
import h5py
15
from models.resnet_custom import resnet50_baseline
16
import math
17
from sklearn.preprocessing import label_binarize
18
19
def initiate_model(args, ckpt_path=None):
    """Build the multi-task attention-MIL model and optionally load a checkpoint.

    Args:
        args: namespace exposing ``drop_out``, ``n_classes``, ``model_size``
            and ``model_type``.
        ckpt_path: optional path to a saved state dict. When given, the
            weights are loaded strictly into the freshly built model.

    Returns:
        The model after ``relocate()`` and ``eval()`` have been called on it.

    Raises:
        NotImplementedError: for every ``model_type`` other than
            ``'attention_mil'``.
    """
    print('Init Model')
    model_dict = {"dropout": args.drop_out, 'n_classes': args.n_classes}

    if args.model_size is not None and args.model_type in ['clam', 'attention_mil', 'clam_simple']:
        model_dict.update({"size_arg": args.model_size})

    # Only the attention-MIL multi-task model is supported by this module.
    if args.model_type != 'attention_mil':
        raise NotImplementedError

    model = MIL_Attention_fc_mtl(**model_dict)
    print_network(model)

    if ckpt_path is not None:
        # Deserialize onto the CPU so that checkpoints saved from a GPU can be
        # loaded on CPU-only hosts; relocate() below moves the model afterwards.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Drop instance-loss entries and strip the '.module' fragment from the
        # remaining keys before the strict load.
        ckpt_clean = {
            key.replace('.module', ''): value
            for key, value in ckpt.items()
            if 'instance_loss_fn' not in key
        }
        model.load_state_dict(ckpt_clean, strict=True)

    model.relocate()
    model.eval()
    return model
49
50
51
def eval(dataset, args, ckpt_path):
    """Evaluate a checkpoint on ``dataset`` and print per-task error and AUC.

    Args:
        dataset: dataset handed to ``get_simple_loader``.
        args: run configuration forwarded to ``initiate_model`` and ``summary``.
        ckpt_path: checkpoint path forwarded to ``initiate_model``.

    Returns:
        Tuple ``(model, results_dict)`` where ``results_dict`` is the value
        returned by ``summary``.
    """
    model = initiate_model(args, ckpt_path)

    print('Init Loaders')
    loader = get_simple_loader(dataset, collate_fn='MIL_mtl')
    results_dict = summary(model, loader, args)

    # Report error first, then AUC, for each task in order.
    for task in ('task1', 'task2', 'task3'):
        for metric in ('test_error', 'auc'):
            key = '{}_{}'.format(metric, task)
            print(key + ': ', results_dict[key])

    return model, results_dict
67
    # patient_results, test_error, auc, aucs, df
68
69
def infer(dataset, args, ckpt_path, class_labels, site_labels):
    """Load a checkpointed model and run ``infer_dataset`` on ``dataset``.

    Returns:
        Tuple ``(model, df)`` where ``df`` is the value returned by
        ``infer_dataset``.
    """
    model = initiate_model(args, ckpt_path)
    predictions = infer_dataset(model, dataset, args, class_labels, site_labels)
    return model, predictions
73
74
# Code taken from pytorch/examples for evaluating top-k classification on ImageNet
75
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy for each requested value of k.

    Args:
        output: 2-D tensor of scores; ``topk`` is taken along dimension 1.
        target: 1-D tensor of class indices whose length is the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry in ``topk``, each
        holding the fraction (0..1) of samples whose target is among the k
        highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # The transposed comparison result is not contiguous, so view(-1)
            # raises for k > 1; reshape copies only when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(1.0 / batch_size))
        return res
90
91
def summary(model, loader, args):
    """Run the model over ``loader`` and collect per-task evaluation results.

    Each batch is treated as a single sample (``.item()`` is called on every
    label), so ``len(loader)`` is used as the number of samples.

    Args:
        model: callable returning a dict with ``Y_prob_task{1,2,3}`` and
            ``Y_hat_task{1,2,3}`` entries for a batch.
        loader: iterable yielding ``(data, label_task1, label_task2,
            label_task3)``; ``loader.dataset.slide_data`` must provide a
            ``'slide_id'`` column (or ``'case_id'`` when
            ``args.patient_level`` is set).
        args: namespace exposing ``n_classes`` (one entry per task) and
            ``patient_level``.

    Returns:
        Dict with ``patient_results`` (per-id probabilities and labels),
        ``test_error_task*``, ``auc_task*``, ``loggers`` (tuple of the three
        accuracy loggers) and ``df`` (DataFrame of labels, predictions and the
        first two class probabilities per task).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tasks = ('task1', 'task2', 'task3')
    n_tasks = len(tasks)

    loggers = tuple(Accuracy_Logger(n_classes=args.n_classes[i]) for i in range(n_tasks))
    model.eval()

    n_samples = len(loader)
    test_errors = [0. for _ in tasks]
    all_probs = [np.zeros((n_samples, args.n_classes[i])) for i in range(n_tasks)]
    all_labels = [np.zeros(n_samples) for _ in tasks]

    # Slide-level and patient-level evaluation differ only in the id column.
    id_key = 'case_id' if args.patient_level else 'slide_id'
    sample_ids = loader.dataset.slide_data[id_key]
    patient_results = {}

    for batch_idx, (data, label_task1, label_task2, label_task3) in enumerate(loader):
        data = data.to(device)
        labels = [label.to(device) for label in (label_task1, label_task2, label_task3)]

        sample_id = sample_ids.iloc[batch_idx]
        with torch.no_grad():
            model_results_dict = model(data)

        Y_probs = [model_results_dict['Y_prob_' + task] for task in tasks]
        Y_hats = [model_results_dict['Y_hat_' + task] for task in tasks]
        del model_results_dict

        for logger, Y_hat, label in zip(loggers, Y_hats, labels):
            logger.log(Y_hat, label)

        # NOTE: when several rows share an id (patient level), a later row
        # overwrites the earlier entry in patient_results.
        sample_result = {id_key: np.array(sample_id)}
        for i, task in enumerate(tasks):
            probs = Y_probs[i].cpu().numpy()
            all_probs[i][batch_idx] = probs
            all_labels[i][batch_idx] = labels[i].item()
            sample_result['prob_' + task] = probs
            sample_result['label_' + task] = labels[i].item()
        patient_results.update({sample_id: sample_result})

        for i in range(n_tasks):
            test_errors[i] += calculate_error(Y_hats[i], labels[i])

    test_errors = [error / n_samples for error in test_errors]
    all_preds = [np.argmax(probs, axis=1) for probs in all_probs]

    # Binary classification is assumed: column 1 is used as the positive-class
    # score. AUC is reported as -1 when only one label value is present.
    aucs = []
    for task_labels, task_probs in zip(all_labels, all_probs):
        if len(np.unique(task_labels)) == 1:
            aucs.append(-1)
        else:
            aucs.append(roc_auc_score(task_labels, task_probs[:, 1]))

    results_dict = {id_key: sample_ids}
    for i, task in enumerate(tasks):
        results_dict['Y_' + task] = all_labels[i]
        results_dict['Y_hat_' + task] = all_preds[i]
    for i, task in enumerate(tasks):
        results_dict['p0_' + task] = all_probs[i][:, 0]
        results_dict['p1_' + task] = all_probs[i][:, 1]

    df = pd.DataFrame(results_dict)

    if args.patient_level:
        # Keep the first row per case.
        df = df.drop_duplicates(subset=['case_id'])

    inference_results = {'patient_results': patient_results}
    for i, task in enumerate(tasks):
        inference_results['test_error_' + task] = test_errors[i]
        inference_results['auc_' + task] = aucs[i]
    inference_results['loggers'] = loggers
    inference_results['df'] = df

    return inference_results
291
292
293
def infer_dataset(model, dataset, args, class_labels, site_labels, k=3):
    """Run inference over ``dataset`` and tabulate top-k class and site predictions.

    Args:
        model: callable returning a dict with ``Y_prob``, ``site_prob`` and
            ``site_hat`` entries for one sample.
        dataset: sized iterable of input tensors exposing ``slide_data``.
        args: unused here; kept for interface compatibility.
        class_labels: sequence mapping class index to class name.
        site_labels: sequence mapping site index to site name.
        k: number of top class predictions to keep per sample.

    Returns:
        DataFrame with ``slide_id``, ``Pred_i``/``p_i`` for ``i < k``,
        ``Site_Pred`` and ``Site_p`` (column 1 of the site probabilities).
    """
    # Defined locally (as summary does) instead of relying on a `device` name
    # leaking in through a wildcard import.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    n_samples = len(dataset)
    all_probs_cls = np.zeros((n_samples, k))
    all_probs_site = np.zeros((n_samples, 2))
    all_preds_cls_str = np.full((n_samples, k), ' ', dtype=object)
    all_preds_site = np.full((n_samples), ' ', dtype=object)

    class_names = np.array(class_labels)
    site_names = np.array(site_labels)

    # NOTE(review): the whole `slide_data` object is used as the 'slide_id'
    # column, whereas summary() selects slide_data['slide_id'] -- confirm
    # what the inference dataset stores here.
    slide_ids = dataset.slide_data
    for batch_idx, data in enumerate(dataset):
        data = data.to(device)
        with torch.no_grad():
            results_dict = model(data)

        Y_prob = results_dict['Y_prob']
        site_prob, site_hat = results_dict['site_prob'], results_dict['site_hat']
        # Release per-sample references inside the loop; deleting `data` after
        # the loop would fail when the dataset is empty.
        del results_dict, data

        probs, ids = torch.topk(Y_prob, k)
        ids = ids.cpu().numpy()
        all_probs_cls[batch_idx] = probs.cpu().numpy()
        all_preds_cls_str[batch_idx] = class_names[ids]

        all_probs_site[batch_idx] = site_prob.cpu().numpy()
        all_preds_site[batch_idx] = site_names[site_hat.item()]

    results_dict = {'slide_id': slide_ids}
    for c in range(k):
        results_dict.update({'Pred_{}'.format(c): all_preds_cls_str[:, c]})
        results_dict.update({'p_{}'.format(c): all_probs_cls[:, c]})
    results_dict.update({'Site_Pred': all_preds_site, 'Site_p': all_probs_site[:, 1]})
    df = pd.DataFrame(results_dict)
    return df
330
331
# def infer_dataset(model, dataset, args, class_labels, k=3):
332
#     model.eval()
333
334
#     all_probs = np.zeros((len(dataset), args.n_classes))
335
#     all_preds = np.zeros(len(dataset))
336
#     all_str_preds = np.full(len(dataset), ' ', dtype=object)
337
338
#     slide_ids = dataset.slide_data
339
#     for batch_idx, data in enumerate(dataset):
340
#         data = data.to(device)
341
#         with torch.no_grad():
342
#             logits, Y_prob, Y_hat, _, results_dict = model(data)
343
344
#         probs = Y_prob.cpu().numpy()
345
#         all_probs[batch_idx] = probs
346
#         all_preds[batch_idx] = Y_hat.item()
347
#         all_str_preds[batch_idx] = class_labels[Y_hat.item()]
348
#     del data
349
350
#     results_dict = {'slide_id': slide_ids, 'Prediction': all_str_preds, 'Y_hat': all_preds}
351
#     for c in range(args.n_classes):
352
#         results_dict.update({'p_{}_{}'.format(c, class_labels[c]): all_probs[:,c]})
353
#     df = pd.DataFrame(results_dict)
354
#     return df
355
356
def compute_features(dataset, args, ckpt_path, save_dir, model=None, feature_dim=512):
    """Write per-sample features and predictions for ``dataset`` to ``features.h5``.

    Args:
        dataset: sized dataset exposing ``get_list``.
        args: run configuration forwarded to ``initiate_model`` and
            ``save_features``.
        ckpt_path: checkpoint used to build the model when ``model`` is None.
        save_dir: directory in which ``features.h5`` is created.
        model: optional pre-built model; built from ``ckpt_path`` if omitted.
        feature_dim: width of each stored feature vector.
    """
    if model is None:
        model = initiate_model(args, ckpt_path)

    n_items = len(dataset)
    names = dataset.get_list(np.arange(n_items)).values
    file_path = os.path.join(save_dir, 'features.h5')

    # Create the (empty) datasets up front, then fill them one sample at a time.
    initialize_features_hdf5_file(file_path, n_items, feature_dim=feature_dim, names=names)
    for i in range(n_items):
        print("Progress: {}/{}".format(i, n_items))
        save_features(dataset, i, model, args, file_path)
367
368
def save_features(dataset, idx, model, args, save_file_path):
    """Run the model on sample ``idx`` and store its results in the HDF5 file.

    For each of the three tasks the feature vector, the ground-truth label,
    the predicted class and the probability at index 1 are written at row
    ``idx`` of the corresponding dataset in ``save_file_path``.

    Args:
        dataset: indexable dataset returning ``(features, label_task1,
            label_task2, label_task3)`` and exposing ``get_list``.
        idx: index of the sample to process.
        model: callable accepting ``return_features=True`` and returning a
            dict with ``Y_prob_task*``, ``Y_hat_task*`` and ``features``.
        args: unused here; kept for interface compatibility.
        save_file_path: path of an existing HDF5 file holding the target
            datasets (opened in ``'r+'`` mode).
    """
    # Defined locally (as summary does) instead of relying on a `device` name
    # leaking in through a wildcard import.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tasks = ('task1', 'task2', 'task3')

    name = dataset.get_list(idx)
    print(name)
    features, label_task1, label_task2, label_task3 = dataset[idx]
    labels = (label_task1, label_task2, label_task3)
    features = features.to(device)

    with torch.no_grad():
        results_dict = model(features, return_features=True)
        Y_probs = [results_dict['Y_prob_' + task] for task in tasks]
        Y_hats = [results_dict['Y_hat_' + task] for task in tasks]
        # results_dict['features'] is indexed by task position.
        feats = [results_dict['features'][i] for i in range(len(tasks))]

    del results_dict
    del features

    Y_hats = [Y_hat.item() for Y_hat in Y_hats]
    Y_probs = [Y_prob.view(-1).cpu().numpy() for Y_prob in Y_probs]
    feats = [feat.view(1, -1).cpu().numpy() for feat in feats]

    with h5py.File(save_file_path, 'r+') as file:
        print('label_task1', label_task1)
        for task, feat in zip(tasks, feats):
            file['features_' + task][idx, :] = feat
        for i, task in enumerate(tasks):
            file['label_' + task][idx] = labels[i]
            file['Y_hat_' + task][idx] = Y_hats[i]
            # Only the probability at index 1 is stored (binary assumption).
            file['Y_prob_' + task][idx] = Y_probs[i][1]
411
412
413
414
def initialize_features_hdf5_file(file_path, length, feature_dim=512, names = None):
    """Create (or overwrite) the HDF5 file that ``save_features`` fills in.

    For each of the three tasks the file receives a ``features_*`` dataset of
    shape ``(length, feature_dim)`` plus 1-D ``label_*``, ``Y_hat_*`` and
    ``Y_prob_*`` datasets of size ``length``. When ``names`` is given, a
    string dataset ``names`` is created and populated with it.

    Args:
        file_path: destination path; opened in ``"w"`` mode.
        length: number of rows in every dataset.
        feature_dim: width of each feature vector.
        names: optional sequence of ``length`` names to store.

    Returns:
        ``file_path``, unchanged.
    """
    tasks = ('task1', 'task2', 'task3')

    # The context manager guarantees the file is closed even if creating one
    # of the datasets raises.
    with h5py.File(file_path, "w") as h5_file:
        for task in tasks:
            h5_file.create_dataset('features_' + task,
                                   shape=(length, feature_dim), chunks=(1, feature_dim), dtype=np.float32)

        if names is not None:
            dt = h5py.string_dtype()
            h5_file.create_dataset('names', shape=(length, ), chunks=(1, ), dtype=dt)
            h5_file['names'][:] = names

        for task in tasks:
            h5_file.create_dataset('label_' + task, shape=(length, ), chunks=(1, ), dtype=np.int32)
            h5_file.create_dataset('Y_hat_' + task, shape=(length, ), chunks=(1, ), dtype=np.int32)
            h5_file.create_dataset('Y_prob_' + task, shape=(length, ), chunks=(1, ), dtype=np.float32)

    return file_path