Switch to unified view

a b/scripts/plcom2012/evaluate.py
1
import pickle
2
import os
3
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"
4
from os.path import dirname, realpath
5
import sys
6
sys.path.append(dirname(dirname(dirname(realpath(__file__)))))
7
8
from sybil.parsing import parse_args
9
from scripts.plcom2012.plcom2012 import PLCOm2012
10
from sybil.utils.helpers import get_dataset
11
import sybil.utils.loading as loaders
12
13
def main(args):
14
    # Load dataset and add dataset specific information to args
15
    print("\nLoading data...")
16
    test_data = loaders.get_eval_dataset_loader(
17
            args,
18
            get_dataset(args.dataset, 'test', args),
19
            False
20
            )
21
    
22
    model = PLCOm2012(args)
23
24
    print("\nParameters:")
25
    for attr, value in sorted(args.__dict__.items()):
26
        if attr not in ['optimizer_state', 'patient_to_partition_dict', 'path_to_hidden_dict', 'exam_to_year_dict', 'exam_to_device_dict', 'treatment_to_index','drug_to_y']:
27
            print("\t{}={}".format(attr.upper(), value))
28
29
    print("-------------\nTesting on PLCOm2012")
30
    model.save_prefix = 'test_'
31
    model.test(test_data)
32
    
33
    print("Saving args to {}".format(args.results_path))
34
    pickle.dump(vars(args), open(args.results_path,'wb'))
35
36
if __name__ == '__main__':
37
    __spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)"
38
    args = parse_args()
39
    main(args)