Diff of /draw.ipynb [000000] .. [d6904d]

Switch to unified view

a b/draw.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {},
7
   "outputs": [],
8
   "source": [
9
    "import numpy as np\n",
10
    "import pandas as pd\n",
11
    "from sklearn.calibration import calibration_curve\n",
12
    "from sklearn.metrics import roc_curve, precision_recall_curve\n",
13
    "import torch\n",
14
    "from sklearn.decomposition import PCA\n",
15
    "from sklearn.manifold import TSNE\n",
16
    "# matplotlib\n",
17
    "import matplotlib.pyplot as plt\n",
18
    "import matplotlib.lines as mlines\n",
19
    "import matplotlib.transforms as mtransfor\n",
20
    "from matplotlib.ticker import FuncFormatter\n",
21
    "import seaborn as sns\n",
22
    "\n",
23
    "plt.style.use('default')\n",
24
    "plt.rcParams['axes.facecolor']='white'\n",
25
    "plt.rcParams.update({\"axes.grid\" : True, \"grid.color\": \"gainsboro\"})\n",
26
    "plt.rcParams['legend.frameon']=True\n",
27
    "plt.rcParams['legend.facecolor']='white'\n",
28
    "plt.rcParams['legend.edgecolor']='grey'\n",
29
    "plt.rcParams[\"axes.edgecolor\"] = \"black\"\n",
30
    "plt.rcParams[\"axes.linewidth\"]  = 1"
31
   ]
32
  },
33
  {
34
   "cell_type": "markdown",
35
   "metadata": {},
36
   "source": [
37
    "## Read models' outcome prediction result"
38
   ]
39
  },
40
  {
41
   "cell_type": "code",
42
   "execution_count": null,
43
   "metadata": {},
44
   "outputs": [],
45
   "source": [
46
    "tj_adacare = pd.read_pickle('./saved_pkl/tongji_adacare_outcome.pkl')\n",
47
    "tj_retain = pd.read_pickle('./saved_pkl/tongji_retain_outcome.pkl')\n",
48
    "tj_tcn = pd.read_pickle('./saved_pkl/tongji_tcn_outcome.pkl')\n",
49
    "\n",
50
    "hm_concare = pd.read_pickle('./saved_pkl/hm_concare_outcome.pkl')\n",
51
    "hm_tcn = pd.read_pickle('./saved_pkl/hm_tcn_outcome.pkl')\n",
52
    "hm_rnn = pd.read_pickle('./saved_pkl/hm_rnn_outcome.pkl')\n",
53
    "\n",
54
    "tj_adacare_outcome_true, tj_adacare_outcome_pred = tj_adacare['outcome_true'], tj_adacare['outcome_pred']\n",
55
    "tj_retain_outcome_true, tj_retain_outcome_pred = tj_retain['outcome_true'], tj_retain['outcome_pred']\n",
56
    "tj_tcn_outcome_true, tj_tcn_outcome_pred = tj_tcn['outcome_true'], tj_tcn['outcome_pred']\n",
57
    "\n",
58
    "hm_concare_outcome_true, hm_concare_outcome_pred = hm_concare['outcome_true'], hm_concare['outcome_pred']\n",
59
    "hm_tcn_outcome_true, hm_tcn_outcome_pred = hm_tcn['outcome_true'], hm_tcn['outcome_pred']\n",
60
    "hm_rnn_outcome_true, hm_rnn_outcome_pred = hm_rnn['outcome_true'], hm_rnn['outcome_pred']"
61
   ]
62
  },
63
  {
64
   "cell_type": "markdown",
65
   "metadata": {},
66
   "source": [
67
    "## ROC Plot"
68
   ]
69
  },
70
  {
71
   "cell_type": "code",
72
   "execution_count": null,
73
   "metadata": {},
74
   "outputs": [],
75
   "source": [
76
    "tj_random_probs = [0 for _ in range(len(tj_adacare_outcome_true))]\n",
77
    "tj_p_fpr, tj_p_tpr, _ = roc_curve(tj_adacare_outcome_true, tj_random_probs, pos_label=1)\n",
78
    "\n",
79
    "hm_random_probs = [0 for _ in range(len(hm_tcn_outcome_true))]\n",
80
    "hm_p_fpr, hm_p_tpr, _ = roc_curve(hm_tcn_outcome_true, hm_random_probs, pos_label=1)"
81
   ]
82
  },
83
  {
84
   "cell_type": "code",
85
   "execution_count": null,
86
   "metadata": {},
87
   "outputs": [],
88
   "source": [
89
    "# [TJH] plot roc curves\n",
90
    "\n",
91
    "tj_adacare_fpr, tj_adacare_tpr, thresh1 = roc_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, pos_label=1)\n",
92
    "tj_retain_fpr, tj_retain_tpr, thresh2 = roc_curve(tj_retain_outcome_true, tj_retain_outcome_pred, pos_label=1)\n",
93
    "tj_tcn_fpr, tj_tcn_tpr, thresh3 = roc_curve(tj_tcn_outcome_true, tj_tcn_outcome_pred, pos_label=1)\n",
94
    "\n",
95
    "plt.plot(tj_p_fpr, tj_p_tpr, linestyle='-.', color='grey', label='Random')\n",
96
    "plt.plot(tj_adacare_fpr, tj_adacare_tpr, linestyle='dashed',color='orange', label='AdaCare')\n",
97
    "plt.plot(tj_retain_fpr, tj_retain_tpr, linestyle='solid',color='dodgerblue', label='RETAIN')\n",
98
    "plt.plot(tj_tcn_fpr, tj_tcn_tpr, linestyle='dotted',color='violet', label='TCN')\n",
99
    "\n",
100
    "# # title\n",
101
    "# plt.title('ROC curve')\n",
102
    "# x label\n",
103
    "plt.xlabel('False Positive Rate')\n",
104
    "# y label\n",
105
    "plt.ylabel('True Positive Rate')\n",
106
    "\n",
107
    "plt.legend(loc='lower right')\n",
108
    "\n",
109
    "plt.savefig('tjh_roc.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
110
    "plt.show();"
111
   ]
112
  },
113
  {
114
   "cell_type": "code",
115
   "execution_count": null,
116
   "metadata": {},
117
   "outputs": [],
118
   "source": [
119
    "# [CDSL] plot roc curves\n",
120
    "\n",
121
    "hm_concare_fpr, hm_concare_tpr, thresh1 = roc_curve(hm_concare_outcome_true, hm_concare_outcome_pred, pos_label=1)\n",
122
    "hm_tcn_fpr, hm_tcn_tpr, thresh2 = roc_curve(hm_tcn_outcome_true, hm_tcn_outcome_pred, pos_label=1)\n",
123
    "hm_rnn_fpr, hm_rnn_tpr, thresh3 = roc_curve(hm_rnn_outcome_true, hm_rnn_outcome_pred, pos_label=1)\n",
124
    "\n",
125
    "plt.plot(hm_p_fpr, hm_p_tpr, linestyle='-.', color='grey', label='Random')\n",
126
    "plt.plot(hm_concare_fpr, hm_concare_tpr, linestyle='solid',color='dodgerblue', label='ConCare')\n",
127
    "plt.plot(hm_tcn_fpr, hm_tcn_tpr, linestyle='dotted',color='violet', label='TCN')\n",
128
    "plt.plot(hm_rnn_fpr, hm_rnn_tpr, linestyle='dashed',color='orange', label='RNN')\n",
129
    "\n",
130
    "# # title\n",
131
    "# plt.title('ROC curve')\n",
132
    "# x label\n",
133
    "plt.xlabel('False positive rate')\n",
134
    "# y label\n",
135
    "plt.ylabel('True positive rate')\n",
136
    "\n",
137
    "plt.legend(loc='lower right')\n",
138
    "plt.savefig('cdsl_roc.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
139
    "plt.show();"
140
   ]
141
  },
142
  {
143
   "cell_type": "markdown",
144
   "metadata": {},
145
   "source": [
146
    "## PRC Plot"
147
   ]
148
  },
149
  {
150
   "cell_type": "code",
151
   "execution_count": null,
152
   "metadata": {},
153
   "outputs": [],
154
   "source": [
155
    "# [TJH] plot precision-recall curves\n",
156
    "\n",
157
    "tj_adacare_precision, tj_adacare_recall, thresh1 = precision_recall_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, pos_label=1)\n",
158
    "tj_retain_precision, tj_retain_recall, thresh2 = precision_recall_curve(tj_retain_outcome_true, tj_retain_outcome_pred, pos_label=1)\n",
159
    "tj_tcn_precision, tj_tcn_recall, thresh3 = precision_recall_curve(tj_tcn_outcome_true, tj_tcn_outcome_pred, pos_label=1)\n",
160
    "\n",
161
    "plt.plot(tj_adacare_precision, tj_adacare_recall, linestyle='dashed',color='orange', label='AdaCare')\n",
162
    "plt.plot(tj_retain_precision, tj_retain_recall, linestyle='solid',color='dodgerblue', label='RETAIN')\n",
163
    "plt.plot(tj_tcn_precision, tj_tcn_recall, linestyle='dotted',color='violet', label='TCN')\n",
164
    "\n",
165
    "# # title\n",
166
    "# plt.title('PRC curve')\n",
167
    "# x label\n",
168
    "plt.xlabel('Recall')\n",
169
    "# y label\n",
170
    "plt.ylabel('Precision')\n",
171
    "\n",
172
    "plt.legend(loc='lower left')\n",
173
    "plt.savefig('tjh_prc.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
174
    "plt.show();"
175
   ]
176
  },
177
  {
178
   "cell_type": "code",
179
   "execution_count": null,
180
   "metadata": {},
181
   "outputs": [],
182
   "source": [
183
    "# [CDSL] plot precision-recall curves\n",
184
    "\n",
185
    "hm_concare_precision, hm_concare_recall, thresh1 = precision_recall_curve(hm_concare_outcome_true, hm_concare_outcome_pred, pos_label=1)\n",
186
    "hm_tcn_precision, hm_tcn_recall, thresh2 = precision_recall_curve(hm_tcn_outcome_true, hm_tcn_outcome_pred, pos_label=1)\n",
187
    "hm_rnn_precision, hm_rnn_recall, thresh3 = precision_recall_curve(hm_rnn_outcome_true, hm_rnn_outcome_pred, pos_label=1)\n",
188
    "\n",
189
    "plt.plot(hm_concare_precision, hm_concare_recall, linestyle='solid',color='dodgerblue', label='ConCare')\n",
190
    "plt.plot(hm_tcn_precision, hm_tcn_recall, linestyle='dotted',color='violet', label='TCN')\n",
191
    "plt.plot(hm_rnn_precision, hm_rnn_recall, linestyle='dashed',color='orange', label='RNN')\n",
192
    "\n",
193
    "# # title\n",
194
    "# plt.title('PRC curve')\n",
195
    "# x label\n",
196
    "plt.xlabel('Recall')\n",
197
    "# y label\n",
198
    "plt.ylabel('Precision')\n",
199
    "\n",
200
    "plt.legend(loc='lower left')\n",
201
    "plt.savefig('cdsl_prc.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
202
    "plt.show();"
203
   ]
204
  },
205
  {
206
   "cell_type": "markdown",
207
   "metadata": {},
208
   "source": [
209
    "## Calibration Plot"
210
   ]
211
  },
212
  {
213
   "cell_type": "code",
214
   "execution_count": null,
215
   "metadata": {},
216
   "outputs": [],
217
   "source": [
218
    "tj_adacare_prob_true, tj_adacare_prob_pred = calibration_curve(tj_adacare_outcome_true, tj_adacare_outcome_pred, n_bins=10)\n",
219
    "tj_retain_prob_true, tj_retain_prob_pred = calibration_curve(tj_retain_outcome_true, tj_retain_outcome_pred, n_bins=10)\n",
220
    "tj_tcn_prob_true, tj_tcn_prob_pred = calibration_curve(tj_tcn_outcome_true, tj_tcn_outcome_pred, n_bins=10)\n",
221
    "\n",
222
    "fig, ax = plt.subplots()\n",
223
    "# only these two lines are calibration curves\n",
224
    "plt.plot(tj_adacare_prob_pred, tj_adacare_prob_true, marker='o', linewidth=1, label='AdaCare')\n",
225
    "plt.plot(tj_retain_prob_pred, tj_retain_prob_true, marker='v', linewidth=1, label='RETAIN')\n",
226
    "plt.plot(tj_tcn_prob_pred, tj_tcn_prob_true, marker='s', linewidth=1, label='TCN')\n",
227
    "\n",
228
    "# reference line, legends, and axis labels\n",
229
    "line = mlines.Line2D([0, 1], [0, 1], linestyle='-.', color='grey')\n",
230
    "transform = ax.transAxes\n",
231
    "line.set_transform(transform)\n",
232
    "ax.add_line(line)\n",
233
    "ax.set_xlabel('Predicted probability')\n",
234
    "ax.set_ylabel('True probability in each bin')\n",
235
    "plt.legend(loc='lower right')\n",
236
    "plt.savefig('tjh_calibration.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
237
    "plt.show()"
238
   ]
239
  },
240
  {
241
   "cell_type": "code",
242
   "execution_count": null,
243
   "metadata": {},
244
   "outputs": [],
245
   "source": [
246
    "hm_concare_prob_true, hm_concare_prob_pred = calibration_curve(hm_concare_outcome_true, hm_concare_outcome_pred, n_bins=10)\n",
247
    "hm_tcn_prob_true, hm_tcn_prob_pred = calibration_curve(hm_tcn_outcome_true, hm_tcn_outcome_pred, n_bins=10)\n",
248
    "hm_rnn_prob_true, hm_rnn_prob_pred = calibration_curve(hm_rnn_outcome_true, hm_rnn_outcome_pred, n_bins=10)\n",
249
    "\n",
250
    "fig, ax = plt.subplots()\n",
251
    "# only these two lines are calibration curves\n",
252
    "plt.plot(hm_concare_prob_pred, hm_concare_prob_true, marker='o', linewidth=1, label='ConCare')\n",
253
    "plt.plot(hm_tcn_prob_pred, hm_tcn_prob_true, marker='s', linewidth=1, label='TCN')\n",
254
    "plt.plot(hm_rnn_prob_pred, hm_rnn_prob_true, marker='v', linewidth=1, label='RNN')\n",
255
    "\n",
256
    "# reference line, legends, and axis labels\n",
257
    "line = mlines.Line2D([0, 1], [0, 1], linestyle='-.', color='grey')\n",
258
    "transform = ax.transAxes\n",
259
    "line.set_transform(transform)\n",
260
    "ax.add_line(line)\n",
261
    "ax.set_xlabel('Predicted probability')\n",
262
    "ax.set_ylabel('True probability in each bin')\n",
263
    "plt.legend(loc='lower right')\n",
264
    "plt.savefig('cdsl_calibration.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
265
    "plt.show()"
266
   ]
267
  },
268
  {
269
   "cell_type": "markdown",
270
   "metadata": {},
271
   "source": [
272
    "## Draw OSMAE/EMP scores on different threshold"
273
   ]
274
  },
275
  {
276
   "cell_type": "code",
277
   "execution_count": null,
278
   "metadata": {},
279
   "outputs": [],
280
   "source": [
281
    "covid_scores = pd.read_pickle('./saved_pkl/covid_evaluation_scores.pkl')\n",
282
    "emp, osmae, thresholds = covid_scores[\"emp\"][1::4], covid_scores[\"osmae\"][1::4], covid_scores[\"threshold\"][1::4]"
283
   ]
284
  },
285
  {
286
   "cell_type": "code",
287
   "execution_count": null,
288
   "metadata": {},
289
   "outputs": [],
290
   "source": [
291
    "## EMP Score\n",
292
    "ax = sns.regplot(x=thresholds, y=emp, marker=\"o\", color=\"g\", line_kws={\"color\": \"grey\", \"linestyle\": \"-\", \"linewidth\": \"1\"}, ci=99.9999)\n",
293
    "plt.xlabel('Threshold γ')\n",
294
    "plt.ylabel('ES score')\n",
295
    "\n",
296
    "plt.savefig('emp_trend.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
297
    "plt.show();"
298
   ]
299
  },
300
  {
301
   "cell_type": "code",
302
   "execution_count": null,
303
   "metadata": {},
304
   "outputs": [],
305
   "source": [
306
    "## OSMAE Score\n",
307
    "ax = sns.regplot(x=thresholds, y=osmae, marker=\"o\", color=\"dodgerblue\", line_kws={\"color\": \"grey\", \"linestyle\": \"-\", \"linewidth\": \"1\"}, ci=99.9999)\n",
308
    "plt.xlabel('Threshold γ')\n",
309
    "plt.ylabel('OSMAE score')\n",
310
    "\n",
311
    "plt.savefig('osmae_trend.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
312
    "plt.show();"
313
   ]
314
  },
315
  {
316
   "cell_type": "markdown",
317
   "metadata": {},
318
   "source": [
319
    "## Draw hidden state PCA result on validation set (CDSL dataset)"
320
   ]
321
  },
322
  {
323
   "cell_type": "code",
324
   "execution_count": null,
325
   "metadata": {},
326
   "outputs": [],
327
   "source": [
328
    "val_idx=[1010, 1915, 1656, 2952, 246, 2914, 3146, 2910, 914, 1335, 3046, 404, 2592, 2951, 309, 266, 471, 1112, 490, 3195, 2621, 2143, 1485, 893, 2803, 319, 3231, 2185, 771, 2811, 1950, 3615, 2537, 2546, 3750, 2284, 2122, 1817, 767, 2698, 1564, 3519, 1285, 2808, 1092, 3782, 1115, 1174, 1996, 2603, 3337, 2806, 1105, 2180, 1006, 1900, 3563, 808, 3613, 907, 2069, 1893, 1877, 2362, 2403, 693, 3425, 3501, 626, 244, 3101, 2255, 1661, 1723, 3688, 2571, 2222, 382, 1091, 1, 1884, 3559, 3450, 2648, 2246, 2757, 820, 3375, 1650, 2509, 3760, 2519, 27, 1751, 1964, 1571, 2471, 541, 815, 1094, 2749, 153, 2686, 544, 752, 3085, 1371, 2407, 3675, 1162, 1938, 1197, 3571, 2023, 2847, 1807, 1307, 793, 2610, 1469, 22, 1883, 396, 1098, 1704, 1450, 250, 1258, 3453, 2866, 1995, 2336, 1917, 1724, 3805, 41, 2461, 1241, 2376, 467, 3730, 3090, 3234, 3104, 183, 2827, 274, 1488, 1608, 2495, 3633, 3554, 2723, 3358, 2214, 2963, 3648, 3698, 1569, 3270, 1646, 2675, 2014, 2165, 3106, 2209, 2352, 3580, 3597, 659, 2349, 2074, 988, 1952, 3821, 2640, 3727, 3380, 2646, 3741, 1753, 679, 1707, 633, 1224, 1261, 1501, 1942, 935, 729, 3293, 3638, 2759, 1214, 3028, 3703, 2260, 1406, 2531, 737, 3462, 1495, 1728, 3366, 3510, 42, 3659, 2953, 2378, 1330, 3474, 1372, 89, 1153, 1825, 3218, 3068, 1888, 3287, 2071, 2082, 1460, 3761, 3480, 3424, 651, 1618, 1859, 960, 3344, 3725, 2942, 2176, 1651, 2936, 1187, 2060, 2021, 3317, 861, 1259, 241, 3528, 200, 2392, 1316, 2486, 2923, 923, 98, 3549, 1431, 534, 1840, 3208, 201, 3605, 1337, 877, 571, 3147, 2678, 460, 2970, 3161, 1409, 2304, 1955, 1338, 2364, 754, 3635, 3264, 2620, 889, 566, 744, 1848, 954, 812, 549, 1190, 745, 2371, 2590, 1759, 1710, 1203, 447, 2068, 2691, 245, 880, 122, 3182, 2997, 1934, 1139, 1491, 1166, 1368, 2030, 162, 1829, 766, 3447, 3451, 3386, 2667, 2162, 2096, 3215, 2133, 3620, 743, 1342, 3385, 446, 3557, 3305, 883, 3616, 2130, 1182, 1346, 530, 1732, 1233, 292, 3787, 2467, 3514, 230, 652, 908, 3318, 1213, 3548, 2957, 1552, 2290, 2036, 1102, 3276, 1897, 1886, 2384, 1625, 3143, 1711, 2457, 2832, 2056, 1746, 3361, 2271, 2532, 1981, 1097, 1008, 705, 3671, 2875, 2870, 1760, 582, 3014, 1524, 2682, 712, 916, 461, 474, 3245, 2337, 2077, 2715, 1765, 3409, 2826, 3734, 998, 3813, 233, 1336, 1880, 1703, 2607, 1663, 1476, 380, 3015, 1595, 132, 3737, 125, 2421, 981, 2966, 1961, 991, 3216, 723, 526, 660, 1309, 2529, 3004, 724, 2595, 2573, 354, 3352, 1436, 15, 1184, 3190, 1167, 1758, 81, 3407, 3669, 3141, 3801, 1953, 2898]"
329
   ]
330
  },
331
  {
332
   "cell_type": "code",
333
   "execution_count": null,
334
   "metadata": {},
335
   "outputs": [],
336
   "source": [
337
    "from app import models\n",
338
    "\n",
339
    "# model = models.RETAIN(input_dim=99, hidden_dim=128)"
340
   ]
341
  },
342
  {
343
   "cell_type": "code",
344
   "execution_count": null,
345
   "metadata": {},
346
   "outputs": [],
347
   "source": [
348
    "def extract_backbone_param(ckpt):\n",
349
    "    backbone = {}\n",
350
    "    for k,v in ckpt.items():\n",
351
    "        if \"backbone\" in k:\n",
352
    "            new_k = k.replace(\"backbone.\", \"\")\n",
353
    "            backbone[new_k] = v\n",
354
    "    return backbone"
355
   ]
356
  },
357
  {
358
   "cell_type": "code",
359
   "execution_count": null,
360
   "metadata": {},
361
   "outputs": [],
362
   "source": [
363
    "x = pd.read_pickle(\"datasets/hm/processed_data/x.pkl\")\n",
364
    "y = pd.read_pickle(\"datasets/hm/processed_data/y.pkl\")\n",
365
    "visits_length = pd.read_pickle(\"datasets/hm/processed_data/visits_length.pkl\")\n",
366
    "x = x[val_idx]\n",
367
    "y = y[val_idx]\n",
368
    "visits_length = visits_length[val_idx]\n",
369
    "device = torch.device(\"cpu\")"
370
   ]
371
  },
372
  {
373
   "cell_type": "code",
374
   "execution_count": null,
375
   "metadata": {},
376
   "outputs": [],
377
   "source": [
378
    "outcome_status=[]\n",
379
    "patient=[]\n",
380
    "for i in range(len(visits_length)):\n",
381
    "    outcome_status.append(y[i][visits_length[i]-1][0])\n",
382
    "    patient.append(x[i][visits_length[i]-1].detach().numpy())\n",
383
    "\n",
384
    "outcome_status = torch.tensor(outcome_status)\n",
385
    "patient = torch.tensor(patient)\n",
386
    "# outcome_status = y[:, 0, 0]\n",
387
    "outcome_status = outcome_status.unsqueeze(-1)\n",
388
    "# patient = x[:, 0, :]\n",
389
    "patient = torch.unsqueeze(patient, dim=1)\n",
390
    "patient = patient.float()"
391
   ]
392
  },
393
  {
394
   "cell_type": "code",
395
   "execution_count": null,
396
   "metadata": {},
397
   "outputs": [],
398
   "source": [
399
    "outcome_status.shape, patient.shape"
400
   ]
401
  },
402
  {
403
   "cell_type": "code",
404
   "execution_count": null,
405
   "metadata": {},
406
   "outputs": [],
407
   "source": [
408
    "def remove_outliers(df,columns,n_std):\n",
409
    "    for col in columns:\n",
410
    "        mean = df[col].mean()\n",
411
    "        sd = df[col].std()\n",
412
    "        df = df[abs(df[col]-mean) <= sd*n_std]\n",
413
    "    return df"
414
   ]
415
  },
416
  {
417
   "cell_type": "code",
418
   "execution_count": null,
419
   "metadata": {},
420
   "outputs": [],
421
   "source": [
422
    "n_std = 3\n",
423
    "approach = 'pca' # 'pca' or 'tsne'"
424
   ]
425
  },
426
  {
427
   "cell_type": "markdown",
428
   "metadata": {},
429
   "source": [
430
    "### Multitask Model"
431
   ]
432
  },
433
  {
434
   "cell_type": "code",
435
   "execution_count": null,
436
   "metadata": {},
437
   "outputs": [],
438
   "source": [
439
    "# model = models.RETAIN(input_dim=99, hidden_dim=128)\n",
440
    "hidden_dim=128\n",
441
    "model = models.ConCare(\n",
442
    "    lab_dim=97,\n",
443
    "    demo_dim=2,\n",
444
    "    hidden_dim=hidden_dim,\n",
445
    "    d_model=hidden_dim,\n",
446
    "    MHD_num_head=4,\n",
447
    "    d_ff=4 * hidden_dim,\n",
448
    "    drop=0.0,\n",
449
    ")\n",
450
    "\n",
451
    "multitask_ckpt = torch.load(\"./checkpoints/hm_multitask_concare_ep100_kf10_bs64_hid128_1_seed0.pth\", map_location=torch.device('cpu'))\n",
452
    "multitask_backbone = extract_backbone_param(multitask_ckpt)\n",
453
    "model.load_state_dict(multitask_backbone)\n",
454
    "out = model(patient, device)\n",
455
    "out = torch.squeeze(out)\n",
456
    "out = out.detach().numpy()"
457
   ]
458
  },
459
  {
460
   "cell_type": "code",
461
   "execution_count": null,
462
   "metadata": {},
463
   "outputs": [],
464
   "source": [
465
    "if approach == 'pca':\n",
466
    "    projected = PCA(2).fit_transform(out)\n",
467
    "else:\n",
468
    "    projected = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(out)\n",
469
    "\n",
470
    "concatenated = np.concatenate([projected, outcome_status], axis=1)\n",
471
    "df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])\n",
472
    "df = remove_outliers(df, ['Component 1', 'Component 2'], n_std)\n",
473
    "df['Outcome'].replace({1: 'Dead', 0: 'Alive'}, inplace=True)\n",
474
    "\n",
475
    "sns.scatterplot(data=df, x=\"Component 1\", y=\"Component 2\", hue=\"Outcome\", style=\"Outcome\", palette=[\"C2\", \"C3\"], alpha=0.5)\n",
476
    "plt.savefig(f'multitask_{approach}.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")"
477
   ]
478
  },
479
  {
480
   "cell_type": "markdown",
481
   "metadata": {},
482
   "source": [
483
    "### Outcome Prediction Model"
484
   ]
485
  },
486
  {
487
   "cell_type": "code",
488
   "execution_count": null,
489
   "metadata": {},
490
   "outputs": [],
491
   "source": [
492
    "hidden_dim=128\n",
493
    "model = models.ConCare(\n",
494
    "    lab_dim=97,\n",
495
    "    demo_dim=2,\n",
496
    "    hidden_dim=hidden_dim,\n",
497
    "    d_model=hidden_dim,\n",
498
    "    MHD_num_head=4,\n",
499
    "    d_ff=4 * hidden_dim,\n",
500
    "    drop=0.0,\n",
501
    ")\n",
502
    "\n",
503
    "outcome_ckpt = torch.load(\"./checkpoints/hm_outcome_concare_ep100_kf10_bs64_hid128_1_seed0.pth\", map_location=torch.device('cpu'))\n",
504
    "outcome_backbone = extract_backbone_param(outcome_ckpt)\n",
505
    "model.load_state_dict(outcome_backbone)\n",
506
    "out = model(patient, device)\n",
507
    "out = torch.squeeze(out)\n",
508
    "out = out.detach().numpy()"
509
   ]
510
  },
511
  {
512
   "cell_type": "code",
513
   "execution_count": null,
514
   "metadata": {},
515
   "outputs": [],
516
   "source": [
517
    "if approach == 'pca':\n",
518
    "    projected = PCA(2).fit_transform(out)\n",
519
    "else:\n",
520
    "    projected = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(out)\n",
521
    "\n",
522
    "concatenated = np.concatenate([projected, outcome_status], axis=1)\n",
523
    "df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])\n",
524
    "df = remove_outliers(df, ['Component 1', 'Component 2'], n_std)\n",
525
    "df['Outcome'].replace({1: 'Dead', 0: 'Alive'}, inplace=True)\n",
526
    "\n",
527
    "sns.scatterplot(data=df, x=\"Component 1\", y=\"Component 2\", hue=\"Outcome\", style=\"Outcome\", palette=[\"C2\", \"C3\"], alpha=0.5)\n",
528
    "plt.savefig(f'outcome_{approach}.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")"
529
   ]
530
  },
531
  {
532
   "cell_type": "markdown",
533
   "metadata": {},
534
   "source": [
535
    "### LOS Prediction model"
536
   ]
537
  },
538
  {
539
   "cell_type": "code",
540
   "execution_count": null,
541
   "metadata": {},
542
   "outputs": [],
543
   "source": [
544
    "hidden_dim=64\n",
545
    "model = models.ConCare(\n",
546
    "    lab_dim=97,\n",
547
    "    demo_dim=2,\n",
548
    "    hidden_dim=hidden_dim,\n",
549
    "    d_model=hidden_dim,\n",
550
    "    MHD_num_head=4,\n",
551
    "    d_ff=4 * hidden_dim,\n",
552
    "    drop=0.0,\n",
553
    ")\n",
554
    "\n",
555
    "los_ckpt = torch.load(\"./checkpoints/hm_los_concare_ep100_kf10_bs64_hid64_1_seed0.pth\", map_location=torch.device('cpu'))\n",
556
    "los_backbone = extract_backbone_param(los_ckpt)\n",
557
    "model.load_state_dict(los_backbone)\n",
558
    "out = model(patient, device)\n",
559
    "out = torch.squeeze(out)\n",
560
    "out = out.detach().numpy()"
561
   ]
562
  },
563
  {
564
   "cell_type": "code",
565
   "execution_count": null,
566
   "metadata": {},
567
   "outputs": [],
568
   "source": [
569
    "if approach == 'pca':\n",
570
    "    projected = PCA(2).fit_transform(out)\n",
571
    "else:\n",
572
    "    projected = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(out)\n",
573
    "\n",
574
    "concatenated = np.concatenate([projected, outcome_status], axis=1)\n",
575
    "df = pd.DataFrame(concatenated, columns = ['Component 1', 'Component 2', 'Outcome'])\n",
576
    "df = remove_outliers(df, ['Component 1', 'Component 2'], n_std)\n",
577
    "df['Outcome'].replace({1: 'Dead', 0: 'Alive'}, inplace=True)\n",
578
    "\n",
579
    "sns.scatterplot(data=df, x=\"Component 1\", y=\"Component 2\", hue=\"Outcome\", style=\"Outcome\", palette=[\"C2\", \"C3\"], alpha=0.5)\n",
580
    "plt.savefig(f'los_{approach}.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")"
581
   ]
582
  },
583
  {
584
   "cell_type": "markdown",
585
   "metadata": {},
586
   "source": [
587
    "## Case Study"
588
   ]
589
  },
590
  {
591
   "cell_type": "markdown",
592
   "metadata": {},
593
   "source": [
594
    "### CDSL dataset"
595
   ]
596
  },
597
  {
598
   "cell_type": "code",
599
   "execution_count": null,
600
   "metadata": {},
601
   "outputs": [],
602
   "source": [
603
    "x = pd.read_pickle(\"datasets/hm/processed_data/x.pkl\")\n",
604
    "y = pd.read_pickle(\"datasets/hm/processed_data/y.pkl\")\n",
605
    "visits_length = pd.read_pickle(\"datasets/hm/processed_data/visits_length.pkl\")\n",
606
    "x = x[val_idx]\n",
607
    "y = y[val_idx]\n",
608
    "visits_length = visits_length[val_idx]\n",
609
    "device = torch.device(\"cpu\")"
610
   ]
611
  },
612
  {
613
   "cell_type": "code",
614
   "execution_count": null,
615
   "metadata": {},
616
   "outputs": [],
617
   "source": [
618
    "long_visits_id_list = []\n",
619
    "\n",
620
    "for i in range(len(visits_length)):\n",
621
    "    if visits_length[i] > 20:\n",
622
    "        long_visits_id_list.append(i)\n",
623
    "        print(f\"[{i}: {y[i][0][0].item()}]\", end=\"    \")"
624
   ]
625
  },
626
  {
627
   "cell_type": "markdown",
628
   "metadata": {},
629
   "source": [
630
    "#### ConCare Multitask"
631
   ]
632
  },
633
  {
634
   "cell_type": "code",
635
   "execution_count": null,
636
   "metadata": {},
637
   "outputs": [],
638
   "source": [
639
    "idx=20 # 60, 243 | 24, 114   129 20\n",
640
    "outcome_status=y[idx][0][0]\n",
641
    "los_status=y[idx][:visits_length[idx]][:,1]\n",
642
    "patient=x[idx][:visits_length[idx]]\n",
643
    "\n",
644
    "outcome_status = outcome_status.unsqueeze(-1)\n",
645
    "patient = patient.float()\n",
646
    "\n",
647
    "hidden_dim=128\n",
648
    "backbone = models.ConCare(\n",
649
    "    lab_dim=97,\n",
650
    "    demo_dim=2,\n",
651
    "    hidden_dim=hidden_dim,\n",
652
    "    d_model=hidden_dim,\n",
653
    "    MHD_num_head=4,\n",
654
    "    d_ff=4 * hidden_dim,\n",
655
    "    drop=0.0,\n",
656
    ")\n",
657
    "head = models.MultitaskHead(\n",
658
    "    hidden_dim=hidden_dim,\n",
659
    "    output_dim=1,\n",
660
    ")\n",
661
    "model = models.Model(backbone, head)\n",
662
    "los_ckpt = torch.load(\"./checkpoints/hm_multitask_concare_ep100_kf10_bs64_hid128_1_seed0.pth\", map_location=torch.device('cpu'))\n",
663
    "model.load_state_dict(los_ckpt)\n",
664
    "\n",
665
    "# ConCare model does not accept single patient input, so we need to create a batch of size 2\n",
666
    "patient = torch.stack((patient, patient), dim=0)\n",
667
    "risk, out = model(patient, device, None)\n",
668
    "out = out[0]\n",
669
    "risk = risk[0]\n",
670
    "los_statistics = {'los_mean': 6.1315513, 'los_std': 5.6816683}\n",
671
    "out = torch.squeeze(out)\n",
672
    "risk = torch.squeeze(risk)\n",
673
    "out = out * los_statistics['los_std'] + los_statistics['los_mean']\n",
674
    "\n",
675
    "los_status, out, outcome_status, risk"
676
   ]
677
  },
678
  {
679
   "cell_type": "code",
680
   "execution_count": null,
681
   "metadata": {},
682
   "outputs": [],
683
   "source": [
684
    "hidden_dim=128\n",
685
    "backbone = models.ConCare(\n",
686
    "    lab_dim=97,\n",
687
    "    demo_dim=2,\n",
688
    "    hidden_dim=hidden_dim,\n",
689
    "    d_model=hidden_dim,\n",
690
    "    MHD_num_head=4,\n",
691
    "    d_ff=4 * hidden_dim,\n",
692
    "    drop=0.0,\n",
693
    ")\n",
694
    "head = models.MultitaskHead(\n",
695
    "    hidden_dim=hidden_dim,\n",
696
    "    output_dim=1,\n",
697
    ")\n",
698
    "model = models.Model(backbone, head)\n",
699
    "los_ckpt = torch.load(\"./checkpoints/hm_multitask_concare_ep100_kf10_bs64_hid128_1_seed0.pth\", map_location=torch.device('cpu'))\n",
700
    "model.load_state_dict(los_ckpt)\n",
701
    "\n",
702
    "def inference_case(idx):\n",
703
    "    outcome_status=y[idx][0][0]\n",
704
    "    los_status=y[idx][:visits_length[idx]][:,1]\n",
705
    "    patient=x[idx][:visits_length[idx]]\n",
706
    "\n",
707
    "    outcome_status = outcome_status.unsqueeze(-1)\n",
708
    "    patient = patient.float()\n",
709
    "\n",
710
    "    # ConCare model does not accept single patient input, so we need to create a batch of size 2\n",
711
    "    patient = torch.stack((patient, patient), dim=0)\n",
712
    "    risk, out = model(patient, device, None)\n",
713
    "    out = out[0]\n",
714
    "    risk = risk[0]\n",
715
    "    los_statistics = {'los_mean': 6.1315513, 'los_std': 5.6816683}\n",
716
    "    out = torch.squeeze(out)\n",
717
    "    risk = torch.squeeze(risk)\n",
718
    "    out = out * los_statistics['los_std'] + los_statistics['los_mean']\n",
719
    "    # print(\"los gt:\", los_status)\n",
720
    "    # print(\"los pred:\", out)\n",
721
    "    # print(\"outcome:\", outcome_status)\n",
722
    "    # print(\"risk pred:\", risk)\n",
723
    "    # print(\"--------------------\")\n",
724
    "    return los_status.cpu().detach().numpy(), out.cpu().detach().numpy(), outcome_status.cpu().detach().numpy(), risk.cpu().detach().numpy()\n",
725
    "\n",
726
    "\n",
727
    "los_status, out, outcome_status, risk = inference_case(60)\n",
728
    "los_status, out, outcome_status, risk"
729
   ]
730
  },
731
  {
732
   "cell_type": "code",
733
   "execution_count": null,
734
   "metadata": {},
735
   "outputs": [],
736
   "source": [
737
    "def plot_case(los_status, out, outcome_status, risk, idx):\n",
738
    "    color = 'green' if outcome_status == 0 else 'red'\n",
739
    "    label_info = f'Alive Case #{idx}' if outcome_status == 0 else f'Dead Case #{idx}'\n",
740
    "    filename = f'case_study_los_alive_{idx}' if outcome_status == 0 else f'case_study_los_dead_{idx}'\n",
741
    "    los_status = np.negative(los_status)\n",
742
    "    fig = plt.figure()\n",
743
    "    ax = fig.add_subplot(111)\n",
744
    "    ax2 = ax.twinx()\n",
745
    "    ax.plot(los_status, out, marker='o', linewidth=1, label=label_info, color=color)\n",
746
    "    ax2.plot(los_status, risk, marker=',', linewidth=2, label='Risk', color='skyblue')\n",
747
    "    ax.set_ylim([0, 30])\n",
748
    "    ax2.set_ylim([0, 1])\n",
749
    "    ax2.grid(False)\n",
750
    "    ax.plot([0, -30], [0, 30], linestyle='-.', color='grey')\n",
751
    "    ax.set_xlabel('True Length of Stay (days)')\n",
752
    "    ax.set_ylabel('Predicted Length of Stay (days)')\n",
753
    "    ax.legend(loc='upper left')\n",
754
    "    ax2.legend(loc='lower left')\n",
755
    "\n",
756
    "    # plt.savefig(f'case_study_los_{idx}.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
757
    "    plt.savefig(f'cases/cdsl/{filename}.png', format=\"png\", bbox_inches=\"tight\")\n",
758
    "    plt.show()\n"
759
   ]
760
  },
761
  {
762
   "cell_type": "code",
763
   "execution_count": null,
764
   "metadata": {},
765
   "outputs": [],
766
   "source": [
767
    "def abs_var(x, position):\n",
768
    "    return f'{abs(int(x))}'\n",
769
    "\n",
770
    "# case 1\n",
771
    "los_status, out, outcome_status, risk = inference_case(60)\n",
772
    "los_status = np.negative(los_status)\n",
773
    "fig, (ax, bx) = plt.subplots(1, 2, figsize=(12, 5))\n",
774
    "ax2 = ax.twinx()\n",
775
    "ax.plot(los_status, out, marker='o', linewidth=1, label=\"Alive case\", color='green')\n",
776
    "ax2.plot(los_status, risk, marker=',', linewidth=2, label='Risk', color='skyblue')\n",
777
    "ax.set_ylim([0, 22])\n",
778
    "ax2.set_ylim([0, 1])\n",
779
    "ax2.grid(False)\n",
780
    "ax.plot([0, -22], [0, 22], linestyle='-.', color='grey')\n",
781
    "ax.set_xlabel('True Length of Stay (days)')\n",
782
    "ax.set_ylabel('Predicted Length of Stay (days)')\n",
783
    "ax2.set_ylabel('Risk')\n",
784
    "\n",
785
    "fig.gca().xaxis.set_major_formatter(FuncFormatter(abs_var))\n",
786
    "\n",
787
    "# case 2\n",
788
    "los_status, out, outcome_status, risk = inference_case(169)\n",
789
    "los_status = np.negative(los_status)\n",
790
    "\n",
791
    "bx2 = bx.twinx()\n",
792
    "bx.plot(los_status, out, marker='o', linewidth=1, label='Dead case', color='red')\n",
793
    "bx2.plot(los_status, risk, marker=',', linewidth=2, color='skyblue')\n",
794
    "bx.set_ylim([0, 22])\n",
795
    "bx2.set_ylim([0, 1])\n",
796
    "bx2.grid(False)\n",
797
    "bx.plot([0, -22], [0, 22], linestyle='-.', color='grey')\n",
798
    "bx.set_xlabel('True Length of Stay (days)')\n",
799
    "bx.set_ylabel('Predicted Length of Stay (days)')\n",
800
    "bx2.set_ylabel('Risk')\n",
801
    "\n",
802
    "fig.gca().xaxis.set_major_formatter(FuncFormatter(abs_var))\n",
803
    "\n",
804
    "fig.legend(bbox_to_anchor=(0.5, 1.05), loc=\"upper center\", ncol=3)\n",
805
    "fig.tight_layout()\n",
806
    "plt.savefig('cases/cdsl_case_study.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
807
    "plt.show()"
808
   ]
809
  },
810
  {
811
   "cell_type": "code",
812
   "execution_count": null,
813
   "metadata": {},
814
   "outputs": [],
815
   "source": [
816
    "# for idx in long_visits_id_list:\n",
817
    "#     los_status, out, outcome_status, risk = inference_case(idx)\n",
818
    "#     plot_case(los_status, out, outcome_status, risk, idx)"
819
   ]
820
  },
821
  {
822
   "cell_type": "markdown",
823
   "metadata": {},
824
   "source": [
825
    "### TJH dataset"
826
   ]
827
  },
828
  {
829
   "cell_type": "code",
830
   "execution_count": null,
831
   "metadata": {},
832
   "outputs": [],
833
   "source": [
834
    "val_idx = [287, 116, 186, 292, 225, 290, 277, 311, 20, 71, 52, 304, 87, 74, 318, 92, 121, 236, 226, 149, 295, 103, 14, 305, 213, 165, 174, 106, 99, 102, 151, 177, 233, 130, 1, 270]"
835
   ]
836
  },
837
  {
838
   "cell_type": "code",
839
   "execution_count": null,
840
   "metadata": {},
841
   "outputs": [],
842
   "source": [
843
    "x = pd.read_pickle(\"datasets/tongji/processed_data/x.pkl\")\n",
844
    "y = pd.read_pickle(\"datasets/tongji/processed_data/y.pkl\")\n",
845
    "visits_length = pd.read_pickle(\"datasets/tongji/processed_data/visits_length.pkl\")\n",
846
    "x = x[val_idx]\n",
847
    "y = y[val_idx]\n",
848
    "visits_length = visits_length[val_idx]\n",
849
    "device = torch.device(\"cpu\")"
850
   ]
851
  },
852
  {
853
   "cell_type": "code",
854
   "execution_count": null,
855
   "metadata": {},
856
   "outputs": [],
857
   "source": [
858
    "long_visits_id_list = []\n",
859
    "\n",
860
    "for i in range(len(visits_length)):\n",
861
    "    if visits_length[i] > 5:\n",
862
    "        long_visits_id_list.append(i)\n",
863
    "        print(f\"[{i}: {y[i][0][0].item()}] len:{visits_length[i]}\", end=\"    \")"
864
   ]
865
  },
866
  {
867
   "cell_type": "markdown",
868
   "metadata": {},
869
   "source": [
870
    "#### RETAIN multitask"
871
   ]
872
  },
873
  {
874
   "cell_type": "code",
875
   "execution_count": null,
876
   "metadata": {},
877
   "outputs": [],
878
   "source": [
879
    "hidden_dim=64\n",
880
    "backbone = models.RETAIN(\n",
881
    "    input_dim=75,\n",
882
    "    hidden_dim=hidden_dim,\n",
883
    "    dropout=0.0,\n",
884
    ")\n",
885
    "head = models.MultitaskHead(\n",
886
    "    hidden_dim=hidden_dim,\n",
887
    "    output_dim=1,\n",
888
    ")\n",
889
    "model = models.Model(backbone, head)\n",
890
    "los_ckpt = torch.load(\"./checkpoints/tj_multitask_retain_ep100_kf10_bs64_hid64_1_seed42.pth\", map_location=torch.device('cpu'))\n",
891
    "model.load_state_dict(los_ckpt)\n",
892
    "\n",
893
    "def inference_case(idx):\n",
894
    "    outcome_status=y[idx][0][0]\n",
895
    "    los_status=y[idx][:visits_length[idx]][:,1]\n",
896
    "    patient=x[idx][:visits_length[idx]]\n",
897
    "\n",
898
    "    outcome_status = outcome_status.unsqueeze(-1)\n",
899
    "    patient = patient.float()\n",
900
    "\n",
901
    "    patient = torch.stack((patient, patient), dim=0)\n",
902
    "    risk, out = model(patient, device, None)\n",
903
    "    out = out[0]\n",
904
    "    risk = risk[0]\n",
905
    "    los_statistics = {'los_mean': 7.7147756, 'los_std': 7.1851807}\n",
906
    "    out = torch.squeeze(out)\n",
907
    "    risk = torch.squeeze(risk)\n",
908
    "    out = out * los_statistics['los_std'] + los_statistics['los_mean']\n",
909
    "    # print(\"los gt:\", los_status)\n",
910
    "    # print(\"los pred:\", out)\n",
911
    "    # print(\"outcome:\", outcome_status)\n",
912
    "    # print(\"risk pred:\", risk)\n",
913
    "    # print(\"--------------------\")\n",
914
    "    return los_status.cpu().detach().numpy(), out.cpu().detach().numpy(), outcome_status.cpu().detach().numpy(), risk.cpu().detach().numpy()\n",
915
    "\n",
916
    "\n",
917
    "los_status, out, outcome_status, risk = inference_case(30)\n",
918
    "los_status, out, outcome_status, risk"
919
   ]
920
  },
921
  {
922
   "cell_type": "code",
923
   "execution_count": null,
924
   "metadata": {},
925
   "outputs": [],
926
   "source": [
927
    "def plot_case(los_status, out, outcome_status, risk, idx):\n",
928
    "    color = 'green' if outcome_status == 0 else 'red'\n",
929
    "    label_info = f'Alive Case #{idx}' if outcome_status == 0 else f'Dead Case #{idx}'\n",
930
    "    filename = f'case_study_los_alive_{idx}' if outcome_status == 0 else f'case_study_los_dead_{idx}'\n",
931
    "    los_status = np.negative(los_status)\n",
932
    "    fig = plt.figure()\n",
933
    "    ax = fig.add_subplot(111)\n",
934
    "    ax2 = ax.twinx()\n",
935
    "    ax.plot(los_status, out, marker='o', linewidth=1, label=label_info, color=color)\n",
936
    "    ax2.plot(los_status, risk, marker=',', linewidth=1, label='Risk', color='skyblue')\n",
937
    "    ax.set_ylim([0, 30])\n",
938
    "    ax2.set_ylim([0, 1])\n",
939
    "    ax2.grid(False)\n",
940
    "    ax.plot([0, -30], [0, 30], linestyle='-.', color='grey')\n",
941
    "    ax.set_xlabel('True Length of Stay (days)')\n",
942
    "    ax.set_ylabel('Predicted Length of Stay (days)')\n",
943
    "    ax.legend(loc='upper left')\n",
944
    "    ax2.legend(loc='lower left')\n",
945
    "\n",
946
    "    # plt.savefig(f'case_study_los_{idx}.pdf', dpi=500, format=\"pdf\", bbox_inches=\"tight\")\n",
947
    "    plt.savefig(f'cases/tjh/{filename}.png', format=\"png\", bbox_inches=\"tight\")\n",
948
    "    plt.show()\n"
949
   ]
950
  },
951
  {
952
   "cell_type": "code",
953
   "execution_count": null,
954
   "metadata": {},
955
   "outputs": [],
956
   "source": [
957
    "# for idx in long_visits_id_list:\n",
958
    "#     los_status, out, outcome_status, risk = inference_case(idx)\n",
959
    "#     plot_case(los_status, out, outcome_status, risk, idx)"
960
   ]
961
  }
962
 ],
963
 "metadata": {
964
  "kernelspec": {
965
   "display_name": "Python 3.9.5 ('pytorch')",
966
   "language": "python",
967
   "name": "python3"
968
  },
969
  "language_info": {
970
   "codemirror_mode": {
971
    "name": "ipython",
972
    "version": 3
973
   },
974
   "file_extension": ".py",
975
   "mimetype": "text/x-python",
976
   "name": "python",
977
   "nbconvert_exporter": "python",
978
   "pygments_lexer": "ipython3",
979
   "version": "3.9.5"
980
  },
981
  "orig_nbformat": 4,
982
  "vscode": {
983
   "interpreter": {
984
    "hash": "e382889b16d65b8f9d2caeea05d88db6d501b8794eac9af8ee0956d5affe33e5"
985
   }
986
  }
987
 },
988
 "nbformat": 4,
989
 "nbformat_minor": 2
990
}