Diff of /examples/viz.py [000000] .. [ae9c43]

Switch to side-by-side view

--- a
+++ b/examples/viz.py
@@ -0,0 +1,120 @@
+"""
+Data visualization of patient features and targets
+
+>>> %reload_ext autoreload
+>>> %autoreload 2
+"""
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from prescreen.vcare import targets, vital_signs
+from sklearn.manifold import TSNE
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import roc_auc_score, roc_curve
+
+sns.set()
+
+
+df: pd.DataFrame = targets.fetch()
+signs: pd.DataFrame = vital_signs.fetch()
+
+cols_target, cols_signs = list(df.columns), list(signs.columns)
+
+n_SF = (df['screenfail'] == 1).sum()
+
+df_join = df.merge(signs, 'outer', on=None, left_on='id',
+                        right_on='patient_id')
+df_join.drop(['nip_x', 'nip_y', 'patient_id'], axis=1, inplace=True)
+
+
+# pandas plot
+df.plot(y='time', kind='hist', title='distribution of survival time',
+        bins=128)
+plt.show()
+
+df.plot(y='status', kind='hist')
+df.plot(y='screenfail', kind='hist')
+plt.title('Distribution des Screen Failures')
+plt.show()
+
+print(signs.count(numeric_only=True))
+
+
+# seaborn plots
+sns.stripplot(x='status', y='time', data=df, jitter=True)
+plt.show()
+
+sns.countplot(x='screenfail', data=df)
+plt.title('distribution of failures (n={})'.format(n_SF))
+plt.show()
+
+sns.distplot(a=signs['mean_dbp'].fillna(-1))
+plt.title('mean diastelic blood pressure distribution')
+plt.show()
+
+vars= {'mean_dbp': df_join['mean_dbp'].dropna().mean(),
+       'mean_sdp': df_join['mean_sdp'].dropna().mean(),
+       'mean_ecog': df_join['mean_ecog'].dropna().mean(),
+       'mean_height': df_join['mean_height'].dropna().mean(),
+       'mean_weight': df_join['mean_weight'].dropna().mean(),
+       'mean_temperature': df_join['mean_temperature'].dropna().mean(),
+       'mean_eva': df_join['mean_eva'].dropna().mean(),
+       'mean_sato2': df_join['mean_sato2'].dropna().mean(),
+       'mean_pulse': df_join['mean_pulse'].dropna().mean(),
+       'mean_frequency': df_join['mean_frequency'].dropna().mean()}
+
+
+plot = sns.pairplot(data=df_join.dropna(), hue='screenfail',
+             vars=list(vars.keys()), size=4)
+# plot.savefig('data/pairplot_signs.png')
+plt.show(plot)
+
+
+X_embed = TSNE().fit_transform(df_join.fillna(value=vars)[list(vars.keys())])
+df_join['embed1'] = X_embed[:, 0]
+df_join['embed2'] = X_embed[:, 1]
+
+
+plt.scatter(x= df_join['embed1'], y=df_join['embed2'],
+            c= df_join['screenfail'],
+            cmap='viridis', s=20)
+plt.colorbar()
+plt.show()
+
+
+
+# df_filled = df_join.fillna(value=vars)
+#
+# x_train, x_test, y_train, y_test = \
+#     train_test_split(df_filled[list(vars.keys())],
+#                      df_join['screenfail'].fillna(0), test_size=0.33,
+#                      stratify=df_join['screenfail'].fillna(0))
+#
+# clf = RandomForestClassifier()
+# clf.fit(x_train, y_train)
+#
+# score = clf.score(x_test, y_test)
+# auc = roc_auc_score(y_test, clf.predict_proba(x_test)[:, 1])
+#
+# print('accuracy: {} \n auc: {}'.format(score, auc))
+#
+# fpr, tpr, _ = roc_curve(y_test, clf.predict_proba(x_test)[:, 1])
+#
+# plt.title('baseline ROC for vital signs')
+# plt.plot(fpr, tpr, 'b',
+# label='AUC = %0.2f'% auc)
+# plt.legend(loc='lower right')
+# plt.plot([0,1],[0,1],'r--')
+# plt.xlim([-0.1,1.2])
+# plt.ylim([-0.1,1.2])
+# plt.ylabel('True Positive Rate')
+# plt.xlabel('False Positive Rate')
+# plt.show()
+#
+#
+# y_pred = clf.predict(x_test)
+# y_prob = clf.predict_proba(x_test)
+#
+# y_pred.sum()