|
a |
|
b/examples/viz.py |
|
|
1 |
""" |
|
|
2 |
Data visualization of patient features and targets |
|
|
3 |
|
|
|
4 |
>>> %reload_ext autoreload |
|
|
5 |
>>> %autoreload 2 |
|
|
6 |
""" |
|
|
7 |
import pandas as pd |
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
import seaborn as sns |
|
|
10 |
|
|
|
11 |
from prescreen.vcare import targets, vital_signs |
|
|
12 |
from sklearn.manifold import TSNE |
|
|
13 |
from sklearn.ensemble import RandomForestClassifier |
|
|
14 |
from sklearn.model_selection import train_test_split |
|
|
15 |
from sklearn.metrics import roc_auc_score, roc_curve |
|
|
16 |
|
|
|
17 |
# Apply seaborn's default theme to every figure created below.
sns.set()


# Load the two input tables through the project's data-access helpers.
df: pd.DataFrame = targets.fetch()
signs: pd.DataFrame = vital_signs.fetch()

# Column names of each table, captured before the tables are joined.
cols_target = list(df.columns)
cols_signs = list(signs.columns)

# Number of rows whose `screenfail` value equals 1.
n_SF = df['screenfail'].eq(1).sum()

# Outer join: every row of both tables is kept, matched on
# targets `id` == vital-signs `patient_id`.
df_join = df.merge(signs, how='outer', left_on='id', right_on='patient_id')
# Drop the two suffixed `nip` columns (presumably `nip` exists in both
# tables, so the merge suffixed them) and the right-hand join key.
df_join.drop(columns=['nip_x', 'nip_y', 'patient_id'], inplace=True)
|
|
30 |
|
|
|
31 |
|
|
|
32 |
# pandas plots
# Histogram of the `time` column of the targets table, 128 bins.
df.plot.hist(y='time', bins=128, title='distribution of survival time')
plt.show()

# One histogram figure per column. The title set afterwards lands on the
# current axes, i.e. the figure drawn last (`screenfail`).
for column in ('status', 'screenfail'):
    df.plot.hist(y=column)
plt.title('Distribution des Screen Failures')
plt.show()

# Non-null count of every numeric vital-signs column.
non_null_counts = signs.count(numeric_only=True)
print(non_null_counts)
|
|
43 |
|
|
|
44 |
|
|
|
45 |
# seaborn plots
# `time` against `status`, one point per row; jitter spreads overlapping
# points horizontally so dense regions stay readable.
sns.stripplot(x='status', y='time', data=df, jitter=True)
plt.show()

# Bar chart of the `screenfail` values, with the failure count in the title.
sns.countplot(x='screenfail', data=df)
plt.title('distribution of failures (n={})'.format(n_SF))
plt.show()

# Histogram + KDE of `mean_dbp`. Missing values are plotted as the
# sentinel -1 so they show up as their own bar instead of being dropped.
# `sns.distplot` is deprecated; this is the replacement recommended by
# seaborn (density-normalised histogram with a KDE overlay).
sns.histplot(x=signs['mean_dbp'].fillna(-1), kde=True, stat='density',
             kde_kws={'cut': 3})
plt.title('mean diastolic blood pressure distribution')
plt.show()
|
|
56 |
|
|
|
57 |
# Vital-sign summary columns of the joined table that are plotted and
# later used as features.
# NOTE(review): 'mean_sdp' may be a typo for 'mean_sbp' - confirm against
# the columns returned by vital_signs.fetch() before renaming.
sign_columns = [
    'mean_dbp',
    'mean_sdp',
    'mean_ecog',
    'mean_height',
    'mean_weight',
    'mean_temperature',
    'mean_eva',
    'mean_sato2',
    'mean_pulse',
    'mean_frequency',
]

# Column name -> mean of that column over the joined table. `Series.mean`
# skips missing values by default, so no explicit `dropna()` is needed.
# NOTE: `vars` shadows the builtin of the same name; the name is kept
# because later sections of this script read it as the fill-value mapping.
vars = {column: df_join[column].mean() for column in sign_columns}


# Pairwise scatter matrix of the vital signs, coloured by `screenfail`.
# Only rows without any missing value are plotted. `height` is the current
# name of the facet-size argument (formerly `size`).
plot = sns.pairplot(data=df_join.dropna(), hue='screenfail',
                    vars=sign_columns, height=4)
# plot.savefig('data/pairplot_signs.png')
# `plt.show` takes no positional arguments; it displays all open figures.
plt.show()
|
|
73 |
|
|
|
74 |
|
|
|
75 |
# Replace missing vital signs by the per-column fill values in `vars`,
# then embed the vital-sign columns with t-SNE (default settings).
feature_names = list(vars.keys())
imputed = df_join.fillna(value=vars)
X_embed = TSNE().fit_transform(imputed[feature_names])

# Store the first two embedding coordinates alongside the joined data.
for axis, column in enumerate(('embed1', 'embed2')):
    df_join[column] = X_embed[:, axis]


# Scatter of the embedding, coloured by the `screenfail` value.
plt.scatter(
    x=df_join['embed1'],
    y=df_join['embed2'],
    c=df_join['screenfail'],
    cmap='viridis',
    s=20,
)
plt.colorbar()
plt.show()
|
|
85 |
|
|
|
86 |
|
|
|
87 |
|
|
|
88 |
# df_filled = df_join.fillna(value=vars) |
|
|
89 |
# |
|
|
90 |
# x_train, x_test, y_train, y_test = \ |
|
|
91 |
# train_test_split(df_filled[list(vars.keys())], |
|
|
92 |
# df_join['screenfail'].fillna(0), test_size=0.33, |
|
|
93 |
# stratify=df_join['screenfail'].fillna(0)) |
|
|
94 |
# |
|
|
95 |
# clf = RandomForestClassifier() |
|
|
96 |
# clf.fit(x_train, y_train) |
|
|
97 |
# |
|
|
98 |
# score = clf.score(x_test, y_test) |
|
|
99 |
# auc = roc_auc_score(y_test, clf.predict_proba(x_test)[:, 1]) |
|
|
100 |
# |
|
|
101 |
# print('accuracy: {} \n auc: {}'.format(score, auc)) |
|
|
102 |
# |
|
|
103 |
# fpr, tpr, _ = roc_curve(y_test, clf.predict_proba(x_test)[:, 1]) |
|
|
104 |
# |
|
|
105 |
# plt.title('baseline ROC for vital signs') |
|
|
106 |
# plt.plot(fpr, tpr, 'b', |
|
|
107 |
# label='AUC = %0.2f'% auc) |
|
|
108 |
# plt.legend(loc='lower right') |
|
|
109 |
# plt.plot([0,1],[0,1],'r--') |
|
|
110 |
# plt.xlim([-0.1,1.2]) |
|
|
111 |
# plt.ylim([-0.1,1.2]) |
|
|
112 |
# plt.ylabel('True Positive Rate') |
|
|
113 |
# plt.xlabel('False Positive Rate') |
|
|
114 |
# plt.show() |
|
|
115 |
# |
|
|
116 |
# |
|
|
117 |
# y_pred = clf.predict(x_test) |
|
|
118 |
# y_prob = clf.predict_proba(x_test) |
|
|
119 |
# |
|
|
120 |
# y_pred.sum() |