scripts/evaluate.py

"""Ensemble test predictions from several runs and report their metrics."""
from os.path import dirname, realpath
import sys
sys.path.append(dirname(dirname(realpath(__file__))))

import argparse
from argparse import Namespace
import os
import pickle

import numpy as np

from sandstone.learn.metrics.factory import get_metric

# Keys copied verbatim from the first run's test results into the logging dict.
LOGGER_KEYS = ['censors', 'golds']

def make_logging_dict(results):
    """Build one logging dict: golds/censors from the first run, probs averaged over runs."""
    logging_dict = {k: np.array(results[0].get('test_{}'.format(k), 0)) for k in LOGGER_KEYS}
    # Recover the prefix of the '*_probs' key (assumed identical across runs).
    prob_key = [k.split('_probs')[0] for k in results[0].keys() if 'probs' in k][0]
    # Ensemble by averaging per-sample probabilities across all runs.
    logging_dict['probs'] = np.mean([np.array(r['{}_probs'.format(prob_key)]) for r in results], axis=0)
    return logging_dict

parser = argparse.ArgumentParser()
parser.add_argument('--parent_dir', type=str, default='/Mounts/rbg-storage1/logs/lung_ct/')
parser.add_argument('--result_file_names', type=str, nargs='+',
                    default=["7a07ee56c93e2abd100a47542e394bed",
                             "1e94034923b44462203cfcf29ae29061",
                             "9ff7a8ba3b1f9eb7216f35222b3b8524",
                             "18490328d1790f7b6b8c86d97f25103c"])
parser.add_argument('--test_suffix', type=str, default='test')
# Note: the default must be a list, since metric_name is iterated over below.
parser.add_argument('--metric_name', nargs='*', type=str, default=['survival'])

if __name__ == '__main__':
    args = parser.parse_args()

    # Load each run's saved experiment args and its test-set predictions.
    result_args = [Namespace(**pickle.load(open(os.path.join(args.parent_dir, '{}.results'.format(f)), 'rb')))
                   for f in args.result_file_names]
    test_full_paths = [os.path.join(args.parent_dir, '{}.results.{}_.predictions'.format(f, args.test_suffix))
                       for f in args.result_file_names]
    test_results = [pickle.load(open(f, 'rb')) for f in test_full_paths]

    logging_dict = make_logging_dict(test_results)

    # Evaluate every requested metric on the ensembled predictions,
    # using the first run's args for any metric configuration.
    performance_dict = {}
    metrics = [get_metric(m) for m in args.metric_name]
    for m in metrics:
        performance_dict.update(m(logging_dict, result_args[0]))

    print(performance_dict)
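
# Example invocation (a sketch: the run IDs below are just the script's own defaults,
# and the matching .results / .predictions files must already exist under --parent_dir):
#   python scripts/evaluate.py \
#       --parent_dir /Mounts/rbg-storage1/logs/lung_ct/ \
#       --result_file_names 7a07ee56c93e2abd100a47542e394bed 1e94034923b44462203cfcf29ae29061 \
#       --metric_name survival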