Diff of /performance.py [000000] .. [0f1df3]

Switch to unified view

a b/performance.py
1
import os
2
from pathlib import Path
3
4
import pandas as pd
5
import torch
6
7
from metrics import get_all_metrics, get_regression_metrics
8
9
def export_performance(
    src_path: str,
    dst_root: str = 'performance',
):
    """Load a pickled logits/labels dump and export a performance-metrics CSV.

    The pickle at *src_path* is expected to be a dict with keys ``'config'``,
    ``'labels'`` and ``'preds'`` (schema assumed from usage here — TODO confirm
    against the code that writes these files).  Metrics are computed twice:
    over all samples, and over the subset whose prediction is not the
    "unknown" sentinel.  The table is written to
    ``dst_root/<dataset>/<task>/<model>/<sub_name>.csv``.

    Args:
        src_path: Path to the pickled results file.
        dst_root: Root directory of the exported CSV tree.

    Raises:
        ValueError: If ``config['time']`` is not one of 0, 1, 2.  (Previously
            an unexpected value left ``time_des`` unbound and crashed with a
            confusing NameError.)
    """
    logits = pd.read_pickle(src_path)
    config = logits['config']
    _labels = logits['labels']
    _preds = logits['preds']

    if config['task'] == 'multitask':
        # Drop samples whose first prediction is the 0.501 "unknown" sentinel.
        labels, preds = _filter_pairs(_labels, _preds, lambda pred: pred[0] != 0.501)
        _labels, _preds, labels, preds = (
            torch.tensor(_labels),
            torch.tensor(_preds),
            torch.tensor(labels),
            torch.tensor(preds),
        )
        # Column 0 is the outcome task, column 1 the readmission task.
        outcome_metrics = get_all_metrics(preds[:, 0], labels[:, 0], 'outcome', None)
        readmission_metrics = get_all_metrics(preds[:, 1], labels[:, 1], 'outcome', None)
        _outcome_metrics = get_all_metrics(_preds[:, 0], _labels[:, 0], 'outcome', None)
        _readmission_metrics = get_all_metrics(_preds[:, 1], _labels[:, 1], 'outcome', None)
        # Four rows: outcome all/filtered, readmission all/filtered.
        data = {'count': [len(_labels), len(labels)] * 2}
        data = dict(data, **{
            k: [v1, v2, v3, v4]
            for k, v1, v2, v3, v4 in zip(
                _outcome_metrics.keys(),
                _outcome_metrics.values(),
                outcome_metrics.values(),
                _readmission_metrics.values(),
                readmission_metrics.values(),
            )
        })
        performance = pd.DataFrame(
            data=data,
            index=['o all', 'o without unknown samples', 'r all', 'r without unknown samples'],
        )
    elif config['task'] == 'los':
        # Length-of-stay regression: a 0 in the first slot marks an unknown prediction.
        labels, preds = _filter_pairs(_labels, _preds, lambda pred: pred[0] != 0)
        data = {'count': [len(_labels), len(labels)]}
        _labels = _stack_sequences(_labels)
        _preds = _stack_sequences(_preds)
        labels = _stack_sequences(labels)
        preds = _stack_sequences(preds)
        _metrics = get_regression_metrics(_preds, _labels)
        metrics = get_regression_metrics(preds, labels)
        data = dict(data, **{
            k: [f'{v1:.2f}', f'{v2:.2f}']
            for k, v1, v2 in zip(_metrics.keys(), _metrics.values(), metrics.values())
        })
        performance = pd.DataFrame(data=data, index=['all', 'w/o'])
    else:
        # Single-task classification; predictions are scalars, so the 0.501
        # sentinel is compared directly rather than via pred[0].
        _metrics = get_all_metrics(_preds, _labels, 'outcome', None)
        labels, preds = _filter_pairs(_labels, _preds, lambda pred: pred != 0.501)
        metrics = get_all_metrics(preds, labels, 'outcome', None)
        data = {'count': [len(_labels), len(labels)]}
        data = dict(data, **{
            k: [f'{v1 * 100:.2f}', f'{v2 * 100:.2f}']
            for k, v1, v2 in zip(_metrics.keys(), _metrics.values(), metrics.values())
        })
        performance = pd.DataFrame(data=data, index=['all', 'without unknown samples'])

    # Map the prediction-time code to a human-readable suffix; fail loudly on
    # unknown codes instead of the former silent NameError.
    time = config.get('time', 0)
    time_descriptions = {0: 'upon-discharge', 1: '1month', 2: '6months'}
    try:
        time_des = time_descriptions[time]
    except KeyError:
        raise ValueError(f'Unsupported time code: {time!r}') from None

    dst_path = os.path.join(dst_root, config['dataset'], config['task'], config['model'])
    sub_dst_name = f'{config["form"]}_{str(config["n_shot"])}shot_{time_des}'
    if config['unit'] is True:
        sub_dst_name += '_unit'
    if config['reference_range'] is True:
        sub_dst_name += '_range'
    if config.get('prompt_engineering') is True:
        sub_dst_name += '_cot'
    # Optional imputation flag: 0/1/2 select distinct suffixes; any other
    # value (including a missing key) adds nothing — same as the old if/elif.
    impute_suffixes = {0: '_no_impute', 1: '_impute', 2: '_impute_info'}
    sub_dst_name += impute_suffixes.get(config.get('impute'), '')
    Path(dst_path).mkdir(parents=True, exist_ok=True)
    performance.to_csv(os.path.join(dst_path, f'{sub_dst_name}.csv'))


def _filter_pairs(labels, preds, keep):
    """Return (labels, preds) restricted to pairs whose prediction passes *keep*."""
    kept = [(label, pred) for label, pred in zip(labels, preds) if keep(pred)]
    if not kept:
        return [], []
    kept_labels, kept_preds = zip(*kept)
    return list(kept_labels), list(kept_preds)


def _stack_sequences(values):
    """Stack a list of per-sample sequences into one flat 1-D tensor."""
    return torch.vstack([torch.tensor(v).unsqueeze(1) for v in values]).squeeze(-1)
90
91
if __name__ == '__main__':
    # Batch-export a performance table for each listed logits dump.
    src_files = (
        'logits/tjh/outcome/gpt-3.5-turbo-1106/string_1shot_upon-discharge_unit_range_impute.pkl',
        'logits/tjh/outcome/gpt-3.5-turbo-1106/string_1shot_upon-discharge_unit_range_no_impute.pkl',
    )
    for src_file in src_files:
        export_performance(src_file)