Diff of /scripts/evaluate.py [000000] .. [d9566e]

Switch to unified view

a b/scripts/evaluate.py
1
from os.path import dirname, realpath
2
import sys
3
sys.path.append(dirname(dirname(realpath(__file__))))
4
import argparse
5
from argparse import Namespace
6
import pickle
7
from sandstone.learn.metrics.factory import get_metric
8
import numpy as np
9
import os
10
11
LOGGER_KEYS = ['censors', 'golds'] 
12
13
def make_logging_dict(results):
14
    logging_dict = {}
15
    logging_dict = {k: np.array( results[0].get('test_{}'.format(k), 0) ) for k in LOGGER_KEYS}
16
    logging_dict['probs'] = []
17
    prob_key = [k.split('_probs')[0] for k in results[0].keys() if 'probs' in k][0]
18
    
19
    logging_dict['probs'] = np.mean([np.array(r['{}_probs'.format(prob_key)])  for r in results ], axis = 0)
20
    
21
    return logging_dict
22
23
parser = argparse.ArgumentParser()
24
parser.add_argument('--parent_dir', type = str, default = '/Mounts/rbg-storage1/logs/lung_ct/')
25
parser.add_argument('--result_file_names', type = str, nargs = '+', default =  ["7a07ee56c93e2abd100a47542e394bed","1e94034923b44462203cfcf29ae29061","9ff7a8ba3b1f9eb7216f35222b3b8524","18490328d1790f7b6b8c86d97f25103c"])
26
parser.add_argument('--test_suffix', type = str, default = 'test')
27
parser.add_argument('--metric_name', nargs= '*', type = str, default = 'survival')
28
29
30
if __name__ == '__main__':
31
    args = parser.parse_args()
32
    result_args = [Namespace(**pickle.load(open(os.path.join(args.parent_dir, '{}.results'.format(f)), 'rb'))) for f in args.result_file_names]
33
    test_full_paths = [os.path.join(args.parent_dir, '{}.results.{}_.predictions'.format(f, args.test_suffix)) for f in args.result_file_names]
34
    test_results = [pickle.load(open(f, 'rb')) for f in  test_full_paths]
35
36
    logging_dict = make_logging_dict(test_results)
37
    performance_dict = {}
38
    metrics = [get_metric(m) for m in args.metric_name]
39
    for m in metrics:
40
        performance_dict.update( m(logging_dict, result_args[0]) )
41
42
    print(performance_dict)
43