Diff of /examples/viz.py [000000] .. [ae9c43]

Switch to unified view

a b/examples/viz.py
1
"""
Data visualization of patient features and targets.

To pick up edited modules automatically in an IPython session, run:

>>> %reload_ext autoreload
>>> %autoreload 2
"""
7
import pandas as pd
8
import matplotlib.pyplot as plt
9
import seaborn as sns
10
11
from prescreen.vcare import targets, vital_signs
12
from sklearn.manifold import TSNE
13
from sklearn.ensemble import RandomForestClassifier
14
from sklearn.model_selection import train_test_split
15
from sklearn.metrics import roc_auc_score, roc_curve
16
17
# Activate seaborn's default plot styling for the figures drawn below.
sns.set()

# Fetch the two tables this script visualizes.
df: pd.DataFrame = targets.fetch()
signs: pd.DataFrame = vital_signs.fetch()

# Column names of each table, kept as plain lists.
cols_target = list(df.columns)
cols_signs = list(signs.columns)

# Number of rows whose 'screenfail' value equals 1.
n_SF = df['screenfail'].eq(1).sum()
26
27
# Outer-join the targets (key 'id') with the vital signs (key 'patient_id'),
# then discard the suffixed 'nip' columns and the right-hand join key.
df_join = pd.merge(
    df,
    signs,
    how='outer',
    left_on='id',
    right_on='patient_id',
)
df_join = df_join.drop(columns=['nip_x', 'nip_y', 'patient_id'])
30
31
32
# pandas plot
# Histogram of the 'time' column using 128 bins.
df.plot.hist(y='time', bins=128, title='distribution of survival time')
plt.show()

# Two further histograms, each drawn on its own figure. The title set below
# applies to the most recently created axes, i.e. the 'screenfail' one.
df.plot.hist(y='status')
df.plot.hist(y='screenfail')
plt.title('Distribution des Screen Failures')
plt.show()

# Per-column count of non-missing vital-sign values (numeric columns only).
print(signs.count(numeric_only=True))
43
44
45
# seaborn plots
# Jittered strip plot of 'time' for each 'status' value.
sns.stripplot(data=df, x='status', y='time', jitter=True)
plt.show()

# Bar chart of 'screenfail' value counts; the title reports n_SF.
sns.countplot(data=df, x='screenfail')
plt.title(f'distribution of failures (n={n_SF})')
plt.show()
52
53
# Distribution of the 'mean_dbp' vital sign as a density histogram with a KDE
# overlay. `sns.distplot` is deprecated and scheduled for removal, so the
# equivalent `histplot` call is used instead.
# NaN is replaced by -1 before plotting (as in the original), so missing
# readings form their own bar at -1.
sns.histplot(x=signs['mean_dbp'].fillna(-1), kde=True, stat='density')
plt.title('mean diastolic blood pressure distribution')
plt.show()
56
57
# Mean of each vital-sign column of the joined table, keyed by column name;
# missing values are dropped before averaging.
# NOTE(review): the name shadows the builtin `vars`; it is kept because later
# statements in this script refer to it.
vars = {
    column: df_join[column].dropna().mean()
    for column in (
        'mean_dbp',
        'mean_sdp',
        'mean_ecog',
        'mean_height',
        'mean_weight',
        'mean_temperature',
        'mean_eva',
        'mean_sato2',
        'mean_pulse',
        'mean_frequency',
    )
}
67
68
69
# Pair plot of the vital-sign columns listed in `vars`, colored by
# 'screenfail'. Rows with any missing value are dropped first.
# `height` is the current name of the per-facet size argument (formerly
# `size`, which newer seaborn releases no longer accept).
plot = sns.pairplot(data=df_join.dropna(), hue='screenfail',
                    vars=list(vars), height=4)
# plot.savefig('data/pairplot_signs.png')
# plt.show() only takes the keyword `block`; the grid must not be passed to it.
plt.show()
73
74
75
# Embed the vital-sign columns with t-SNE. Missing entries are first replaced
# by the per-column means stored in `vars`.
X_embed = TSNE().fit_transform(
    df_join[list(vars)].fillna(value=vars)
)
df_join['embed1'], df_join['embed2'] = X_embed[:, 0], X_embed[:, 1]

# Scatter of the two embedding coordinates, colored by 'screenfail'.
plt.scatter(
    df_join['embed1'],
    df_join['embed2'],
    c=df_join['screenfail'],
    s=20,
    cmap='viridis',
)
plt.colorbar()
plt.show()
85
86
87
88
# df_filled = df_join.fillna(value=vars)
89
#
90
# x_train, x_test, y_train, y_test = \
91
#     train_test_split(df_filled[list(vars.keys())],
92
#                      df_join['screenfail'].fillna(0), test_size=0.33,
93
#                      stratify=df_join['screenfail'].fillna(0))
94
#
95
# clf = RandomForestClassifier()
96
# clf.fit(x_train, y_train)
97
#
98
# score = clf.score(x_test, y_test)
99
# auc = roc_auc_score(y_test, clf.predict_proba(x_test)[:, 1])
100
#
101
# print('accuracy: {} \n auc: {}'.format(score, auc))
102
#
103
# fpr, tpr, _ = roc_curve(y_test, clf.predict_proba(x_test)[:, 1])
104
#
105
# plt.title('baseline ROC for vital signs')
106
# plt.plot(fpr, tpr, 'b',
107
# label='AUC = %0.2f'% auc)
108
# plt.legend(loc='lower right')
109
# plt.plot([0,1],[0,1],'r--')
110
# plt.xlim([-0.1,1.2])
111
# plt.ylim([-0.1,1.2])
112
# plt.ylabel('True Positive Rate')
113
# plt.xlabel('False Positive Rate')
114
# plt.show()
115
#
116
#
117
# y_pred = clf.predict(x_test)
118
# y_prob = clf.predict_proba(x_test)
119
#
120
# y_pred.sum()