|
a |
|
b/scripts/plcom2012/evaluate.py |
|
|
1 |
import pickle |
|
|
2 |
import os |
|
|
3 |
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3" |
|
|
4 |
from os.path import dirname, realpath |
|
|
5 |
import sys |
|
|
6 |
sys.path.append(dirname(dirname(dirname(realpath(__file__))))) |
|
|
7 |
|
|
|
8 |
from sybil.parsing import parse_args |
|
|
9 |
from scripts.plcom2012.plcom2012 import PLCOm2012 |
|
|
10 |
from sybil.utils.helpers import get_dataset |
|
|
11 |
import sybil.utils.loading as loaders |
|
|
12 |
|
|
|
13 |
def main(args): |
|
|
14 |
# Load dataset and add dataset specific information to args |
|
|
15 |
print("\nLoading data...") |
|
|
16 |
test_data = loaders.get_eval_dataset_loader( |
|
|
17 |
args, |
|
|
18 |
get_dataset(args.dataset, 'test', args), |
|
|
19 |
False |
|
|
20 |
) |
|
|
21 |
|
|
|
22 |
model = PLCOm2012(args) |
|
|
23 |
|
|
|
24 |
print("\nParameters:") |
|
|
25 |
for attr, value in sorted(args.__dict__.items()): |
|
|
26 |
if attr not in ['optimizer_state', 'patient_to_partition_dict', 'path_to_hidden_dict', 'exam_to_year_dict', 'exam_to_device_dict', 'treatment_to_index','drug_to_y']: |
|
|
27 |
print("\t{}={}".format(attr.upper(), value)) |
|
|
28 |
|
|
|
29 |
print("-------------\nTesting on PLCOm2012") |
|
|
30 |
model.save_prefix = 'test_' |
|
|
31 |
model.test(test_data) |
|
|
32 |
|
|
|
33 |
print("Saving args to {}".format(args.results_path)) |
|
|
34 |
pickle.dump(vars(args), open(args.results_path,'wb')) |
|
|
35 |
|
|
|
36 |
if __name__ == '__main__': |
|
|
37 |
__spec__ = "ModuleSpec(name='builtins', loader=<class '_frozen_importlib.BuiltinImporter'>)" |
|
|
38 |
args = parse_args() |
|
|
39 |
main(args) |