Diff of /eval_mtl_concat.py [000000] .. [fdd588]

b/eval_mtl_concat.py
from __future__ import print_function

import numpy as np

import argparse
import torch
import torch.nn as nn
import pdb
import os
import pandas as pd
from utils.utils import *
from math import floor
import matplotlib.pyplot as plt
from datasets.dataset_mtl_concat import Generic_MIL_MTL_Dataset, save_splits
import h5py
from utils.eval_utils_mtl_concat import *
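# Note: the wildcard import above is assumed to provide the eval() called in
# the main loop below (shadowing Python's builtin eval of the same name).
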
# Training settings
parser = argparse.ArgumentParser(description='TOAD Evaluation Script')
parser.add_argument('--data_root_dir', type=str, help='data directory')
parser.add_argument('--results_dir', type=str, default='./results',
                    help='relative path to results folder, i.e. ' +
                    'the directory containing models_exp_code relative to project root (default: ./results)')
parser.add_argument('--save_exp_code', type=str, default=None,
                    help='experiment code to save eval results')
parser.add_argument('--models_exp_code', type=str, default=None,
                    help='experiment code to load trained models (directory under results_dir containing model checkpoints)')
parser.add_argument('--splits_dir', type=str, default=None,
                    help='splits directory, if using custom splits other than what matches the task (default: None)')
parser.add_argument('--drop_out', action='store_true', default=False,
                    help='whether model uses dropout')
parser.add_argument('--k', type=int, default=1, help='number of folds (default: 1)')
parser.add_argument('--k_start', type=int, default=-1, help='start fold (default: -1, first fold)')
parser.add_argument('--k_end', type=int, default=-1, help='end fold (default: -1, last fold)')
parser.add_argument('--fold', type=int, default=-1, help='single fold to evaluate')
parser.add_argument('--micro_average', action='store_true', default=False,
                    help='use micro_average instead of macro_average for multiclass AUC')
parser.add_argument('--split', type=str, choices=['train', 'val', 'test', 'all'], default='test')
parser.add_argument('--task', type=str, choices=['dummy_mtl_concat'])

args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

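# encoding_size is not referenced elsewhere in this script; 1024 presumably
# matches the dimensionality of the pre-extracted patch features.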
encoding_size = 1024

args.save_dir = os.path.join('./eval_results', 'EVAL_' + str(args.save_exp_code))
args.models_dir = os.path.join(args.results_dir, str(args.models_exp_code))

os.makedirs(args.save_dir, exist_ok=True)

if args.splits_dir is None:
    args.splits_dir = args.models_dir

assert os.path.isdir(args.models_dir)
assert os.path.isdir(args.splits_dir)

settings = {'task': args.task,
            'split': args.split,
            'save_dir': args.save_dir,
            'models_dir': args.models_dir,
            'drop_out': args.drop_out,
            'micro_avg': args.micro_average}

with open(os.path.join(args.save_dir, 'eval_experiment_{}.txt'.format(args.save_exp_code)), 'w') as f:
    print(settings, file=f)

print(settings)

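# The three label_dicts below map, in the order given by label_cols, the 18
# tumor types ('label'), primary vs. metastatic status ('site'), and patient
# sex ('sex') to integer class indices.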
if args.task == 'dummy_mtl_concat':
    args.n_classes = 18
    dataset = Generic_MIL_MTL_Dataset(csv_path='dataset_csv/dummy_dataset.csv',
                                      data_dir=os.path.join(args.data_root_dir, 'DATASET_DIR'),
                                      shuffle=False,
                                      print_info=True,
                                      label_dicts=[{'Lung': 0, 'Breast': 1, 'Colorectal': 2, 'Ovarian': 3,
                                                    'Pancreatic': 4, 'Adrenal': 5,
                                                    'Skin': 6, 'Prostate': 7, 'Renal': 8, 'Bladder': 9,
                                                    'Esophagogastric': 10, 'Thyroid': 11,
                                                    'Head Neck': 12, 'Glioma': 13,
                                                    'Germ Cell': 14, 'Endometrial': 15, 'Cervix': 16, 'Liver': 17},
                                                   {'Primary': 0, 'Metastatic': 1},
                                                   {'F': 0, 'M': 1}],
                                      label_cols=['label', 'site', 'sex'],
                                      patient_strat=False)

else:
    raise NotImplementedError

if args.k_start == -1:
    start = 0
else:
    start = args.k_start
if args.k_end == -1:
    end = args.k
else:
    end = args.k_end

if args.fold == -1:
    folds = range(start, end)
else:
    folds = range(args.fold, args.fold + 1)
ckpt_paths = [os.path.join(args.models_dir, 's_{}_checkpoint.pt'.format(fold)) for fold in folds]
datasets_id = {'train': 0, 'val': 1, 'test': 2, 'all': -1}
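# Example: with the defaults (k=1, k_start=-1, k_end=-1, fold=-1), folds is
# range(0, 1) and only s_0_checkpoint.pt is evaluated. datasets_id maps
# --split to an index into the (train, val, test) tuple returned by
# return_splits(); 'all' uses the sentinel -1 to evaluate the full dataset.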

if __name__ == "__main__":

    all_cls_auc = []
    all_cls_acc = []
    all_site_auc = []
    all_site_acc = []
    all_cls_top3_acc = []
    all_cls_top5_acc = []

    for ckpt_idx in range(len(ckpt_paths)):
        if datasets_id[args.split] < 0:
            split_dataset = dataset
            csv_path = None
        else:
            csv_path = '{}/splits_{}.csv'.format(args.splits_dir, folds[ckpt_idx])
            datasets = dataset.return_splits(from_id=False, csv_path=csv_path)
            split_dataset = datasets[datasets_id[args.split]]

        model, results_dict = eval(split_dataset, args, ckpt_paths[ckpt_idx])

        for cls_idx in range(len(results_dict['cls_aucs'])):
            print('class {} auc: {}'.format(cls_idx, results_dict['cls_aucs'][cls_idx]))

        all_cls_auc.append(results_dict['cls_auc'])
        all_cls_acc.append(1 - results_dict['cls_test_error'])
        all_site_auc.append(results_dict['site_auc'])
        all_site_acc.append(1 - results_dict['site_test_error'])
        all_cls_top3_acc.append(results_dict['top3_acc'])
        all_cls_top5_acc.append(results_dict['top5_acc'])
        df = results_dict['df']
        df.to_csv(os.path.join(args.save_dir, 'fold_{}.csv'.format(folds[ckpt_idx])), index=False)
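
    # Aggregate the per-fold metrics collected above into one summary table;
    # accuracies are derived as 1 - test_error in the loop.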
    df_dict = {'folds': folds, 'cls_test_auc': all_cls_auc, 'cls_test_acc': all_cls_acc,
               'cls_top3_acc': all_cls_top3_acc, 'cls_top5_acc': all_cls_top5_acc,
               'site_test_auc': all_site_auc, 'site_test_acc': all_site_acc}

    final_df = pd.DataFrame(df_dict)
    if len(folds) != args.k:
        save_name = 'summary_partial_{}_{}.csv'.format(folds[0], folds[-1])
    else:
        save_name = 'summary.csv'
    final_df.to_csv(os.path.join(args.save_dir, save_name))
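
# Example invocation (the experiment codes and data path below are
# illustrative placeholders, not values taken from this script):
#
#   CUDA_VISIBLE_DEVICES=0 python eval_mtl_concat.py \
#       --drop_out --k 1 \
#       --data_root_dir DATA_ROOT_DIR \
#       --results_dir ./results \
#       --models_exp_code my_mtl_experiment_s1 \
#       --save_exp_code my_mtl_experiment_eval \
#       --task dummy_mtl_concat --split test
#
# Per-fold predictions land in ./eval_results/EVAL_<save_exp_code>/fold_<i>.csv,
# with aggregate metrics in summary.csv.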