|
a |
|
b/performance.py |
|
|
1 |
import os |
|
|
2 |
from pathlib import Path |
|
|
3 |
|
|
|
4 |
import pandas as pd |
|
|
5 |
import torch |
|
|
6 |
|
|
|
7 |
from metrics import get_all_metrics, get_regression_metrics |
|
|
8 |
|
|
|
9 |
def _drop_unknown(labels, preds, sentinel, first_element=False):
    """Filter out (label, pred) pairs whose prediction equals *sentinel*.

    The sentinel value marks "unknown" samples the model declined to score
    (0.501 for classification-style outputs, 0 for length-of-stay).

    Args:
        labels: sequence of ground-truth labels.
        preds: sequence of predictions, aligned with ``labels``.
        sentinel: prediction value that flags an unknown sample.
        first_element: if True, compare ``pred[0]`` against the sentinel
            (multi-valued predictions); otherwise compare ``pred`` directly.

    Returns:
        (kept_labels, kept_preds) — two lists with unknown samples removed.
    """
    kept_labels, kept_preds = [], []
    for label, pred in zip(labels, preds):
        value = pred[0] if first_element else pred
        if value != sentinel:
            kept_labels.append(label)
            kept_preds.append(pred)
    return kept_labels, kept_preds


def export_performance(
    src_path: str,
    dst_root: str = 'performance',
):
    """Compute evaluation metrics from a saved logits pickle and export a CSV.

    The pickle at *src_path* must hold a dict with keys ``'config'``,
    ``'labels'`` and ``'preds'``.  Metrics are computed both on all samples
    and on the subset without "unknown" predictions, then written to
    ``<dst_root>/<dataset>/<task>/<model>/<derived name>.csv``.

    Args:
        src_path: path to the pickled logits produced by an evaluation run.
        dst_root: root directory of the exported performance tree.

    Raises:
        ValueError: if ``config['time']`` is not one of 0, 1, 2.
    """
    logits = pd.read_pickle(src_path)
    config = logits['config']
    # Raw (unfiltered) labels/predictions are shared by every branch.
    _labels = logits['labels']
    _preds = logits['preds']

    if config['task'] == 'multitask':
        # pred[0] == 0.501 marks an unknown sample — TODO confirm sentinel origin.
        labels, preds = _drop_unknown(_labels, _preds, 0.501, first_element=True)
        _labels, _preds = torch.tensor(_labels), torch.tensor(_preds)
        labels, preds = torch.tensor(labels), torch.tensor(preds)
        # Column 0 is the outcome task, column 1 the readmission task.
        outcome_metrics = get_all_metrics(preds[:, 0], labels[:, 0], 'outcome', None)
        readmission_metrics = get_all_metrics(preds[:, 1], labels[:, 1], 'outcome', None)
        _outcome_metrics = get_all_metrics(_preds[:, 0], _labels[:, 0], 'outcome', None)
        _readmission_metrics = get_all_metrics(_preds[:, 1], _labels[:, 1], 'outcome', None)
        # Counts repeat once per task (outcome rows, then readmission rows).
        data = {'count': [len(_labels), len(labels)] * 2}
        data = dict(data, **{k: [v1, v2, v3, v4] for k, v1, v2, v3, v4 in zip(_outcome_metrics.keys(), _outcome_metrics.values(), outcome_metrics.values(), _readmission_metrics.values(), readmission_metrics.values())})
        performance = pd.DataFrame(data=data, index=['o all', 'o without unknown samples', 'r all', 'r without unknown samples'])
    elif config['task'] == 'los':
        # pred[0] == 0 marks an unknown length-of-stay prediction.
        labels, preds = _drop_unknown(_labels, _preds, 0, first_element=True)
        data = {'count': [len(_labels), len(labels)]}
        # Stack per-visit sequences into flat 1-D tensors for regression metrics.
        _labels = torch.vstack([torch.tensor(label).unsqueeze(1) for label in _labels]).squeeze(-1)
        _preds = torch.vstack([torch.tensor(pred).unsqueeze(1) for pred in _preds]).squeeze(-1)
        labels = torch.vstack([torch.tensor(label).unsqueeze(1) for label in labels]).squeeze(-1)
        preds = torch.vstack([torch.tensor(pred).unsqueeze(1) for pred in preds]).squeeze(-1)
        _metrics = get_regression_metrics(_preds, _labels)
        metrics = get_regression_metrics(preds, labels)
        data = dict(data, **{k: [f'{v1:.2f}', f'{v2:.2f}'] for k, v1, v2 in zip(_metrics.keys(), _metrics.values(), metrics.values())})
        performance = pd.DataFrame(data=data, index=['all', 'w/o'])
    else:
        # Single-task classification (e.g. outcome or readmission alone).
        _metrics = get_all_metrics(_preds, _labels, 'outcome', None)
        labels, preds = _drop_unknown(_labels, _preds, 0.501)
        metrics = get_all_metrics(preds, labels, 'outcome', None)
        data = {'count': [len(_labels), len(labels)]}
        # Metrics are reported as percentages here.
        data = dict(data, **{k: [f'{v1 * 100:.2f}', f'{v2 * 100:.2f}'] for k, v1, v2 in zip(_metrics.keys(), _metrics.values(), metrics.values())})
        performance = pd.DataFrame(data=data, index=['all', 'without unknown samples'])

    time = config.get('time', 0)
    if time == 0:
        time_des = 'upon-discharge'
    elif time == 1:
        time_des = '1month'
    elif time == 2:
        time_des = '6months'
    else:
        # Previously an unexpected value crashed later with a NameError on
        # time_des; fail fast with a clear message instead.
        raise ValueError(f'unexpected time setting in config: {time!r}')

    dst_path = Path(dst_root) / config['dataset'] / config['task'] / config['model']
    sub_dst_name = f'{config["form"]}_{str(config["n_shot"])}shot_{time_des}'
    # Flags are appended only when explicitly set to True (not merely truthy).
    if config['unit'] is True:
        sub_dst_name += '_unit'
    if config['reference_range'] is True:
        sub_dst_name += '_range'
    if config.get('prompt_engineering') is True:
        sub_dst_name += '_cot'
    impute = config.get('impute')
    if impute == 0:
        sub_dst_name += '_no_impute'
    elif impute == 1:
        sub_dst_name += '_impute'
    elif impute == 2:
        sub_dst_name += '_impute_info'
    dst_path.mkdir(parents=True, exist_ok=True)
    performance.to_csv(dst_path / f'{sub_dst_name}.csv')
|
|
90 |
|
|
|
91 |
if __name__ == '__main__':
    # Logit pickles to convert into performance CSVs.
    logit_files = (
        'logits/tjh/outcome/gpt-3.5-turbo-1106/string_1shot_upon-discharge_unit_range_impute.pkl',
        'logits/tjh/outcome/gpt-3.5-turbo-1106/string_1shot_upon-discharge_unit_range_no_impute.pkl',
    )
    for logit_file in logit_files:
        export_performance(logit_file)