Switch to unified view

a b/preprocessing/4-combine.py
1
import json
2
import sparse
3
import pandas as pd
4
import numpy as np
5
import scipy.sparse
6
import joblib
7
8
def load_IDs(fname):
9
    IDs = pd.read_csv(fname, header=0, names=['ID'])
10
    IDs.index.name = 'i'
11
    IDs = IDs.reset_index()
12
    return IDs
13
14
def _get_feature_set(df, X_ALL, IDs_ALL):
15
    IDs = df.set_index('ID')[[]]
16
    idx = IDs.join(IDs_ALL.set_index('ID')).astype(float)
17
    X = [X_ALL[int(i),:] if not np.isnan(i) else sparse.zeros(X_ALL.shape[1]) for i in idx.values]
18
    return sparse.stack(X)
19
20
def get_features(df, feature_sets):
21
    features = []
22
    feature_names = []
23
    if 'demog' in feature_sets:
24
        X_d = df.set_index('hosp_id')[['ID']].join(df_demog).reset_index(drop=True).set_index('ID').loc[df['ID']]
25
        X_d = sparse.as_coo(X_d.values)
26
        features.append(X_d)
27
        feature_names.append(names_demog)
28
        print('demog - Done')
29
    if 'vitals' in feature_sets:
30
        X_v = _get_feature_set(df, X_vitals, IDs_vitals)
31
        features.append(X_v)
32
        feature_names.append(names_vitals)
33
        print('vitals - Done')
34
    if 'meds' in feature_sets:
35
        X_m = _get_feature_set(df, X_meds, IDs_meds)
36
        features.append(X_m)
37
        feature_names.append(names_meds)
38
        print('meds - Done')
39
    if 'labs' in feature_sets:
40
        X_l = _get_feature_set(df, X_labs, IDs_labs)
41
        features.append(X_l)
42
        feature_names.append(names_labs)
43
        print('labs - Done')
44
    if 'flow' in feature_sets:
45
        print('flow', end='')
46
        X_f = _get_feature_set(df, X_flow, IDs_flow)
47
        features.append(X_f)
48
        feature_names.append(names_flow)
49
        print(' - Done')
50
    X = sparse.concatenate(features, axis=1).tocsr()
51
    feature_names = sum(feature_names, [])
52
    return X, np.array(feature_names)
53
54
55
if __name__ == '__main__':
56
57
    df_demog = pd.read_csv('sample_output/out_demog/static-features.csv').set_index('hosp_id')
58
    names_demog = list(df_demog.columns)
59
    print('demog - Loaded')
60
61
    X_vitals = sparse.load_npz('sample_output/out_vitals/X_all.npz')
62
    IDs_vitals = load_IDs('sample_output/out_vitals/X_all.IDs.csv')
63
    names_vitals = json.load(open('metadata/vitals/X_all.feature_names.json', 'r'))
64
    print('vitals - Loaded')
65
66
    X_meds = sparse.load_npz('sample_output/out_meds/X_all.npz')
67
    IDs_meds = load_IDs('sample_output/out_meds/X_all.IDs.csv')
68
    names_meds = json.load(open('metadata/meds/X_all.feature_names.json', 'r'))
69
    print('meds - Loaded')
70
71
    X_labs = sparse.load_npz('sample_output/out_labs/X_all.npz')
72
    IDs_labs = load_IDs('sample_output/out_labs/X_all.IDs.csv')
73
    names_labs = json.load(open('metadata/labs/X_all.feature_names.json', 'r'))
74
    print('labs - Loaded')
75
76
    X_flow = sparse.load_npz('sample_output/out_flow/X_all.npz')
77
    IDs_flow = load_IDs('sample_output/out_flow/X_all.IDs.csv')
78
    names_flow = json.load(open('metadata/flow/X_all.feature_names.json', 'r'))
79
    print('flow - Loaded')
80
81
    df_cohort = pd.read_csv('sample_input/windows_map.csv')
82
    X, names = get_features(df_cohort, ['demog', 'vitals', 'meds', 'labs', 'flow'])
83
    df_features = pd.DataFrame(X.todense(), columns=names, index=df_cohort['ID'])
84
    pd.Series(names).rename('feature_name').to_csv('./sample_output/feature_names.csv', index=False)
85
86
    ## Full feature matrix
87
    joblib.dump(df_features, 'sample_output/full.joblib')
88
89
    ## Baseline features
90
    baseline_cols = pd.read_csv('metadata/Baseline_Feature_Names.txt', sep='\t', header=None)[0].values
91
    df_baseline = df_features[baseline_cols]
92
    df_baseline.to_csv('sample_output/baseline.csv')
93
94
    ## M-CURES (lite)
95
    mcures_cols = pd.read_csv('metadata/MCURES_Feature_Names.txt', sep='\t', header=None)[0].values
96
    df_mcures = df_features[mcures_cols]
97
    df_mcures.to_csv('sample_output/mcures.csv')