a b/datasets/tjh/preprocess.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "execution_count": null,
6
   "metadata": {
7
    "slideshow": {
8
     "slide_type": "-"
9
    }
10
   },
11
   "outputs": [],
12
   "source": [
13
    "# Import necessary packages\n",
14
    "import numpy as np\n",
15
    "import pandas as pd\n",
16
    "import torch"
17
   ]
18
  },
19
  {
20
   "cell_type": "code",
21
   "execution_count": null,
22
   "metadata": {},
23
   "outputs": [],
24
   "source": [
25
    "# Read raw data\n",
26
    "df_train: pd.DataFrame = pd.read_excel('./raw_data/time_series_375_prerpocess_en.xlsx')"
27
   ]
28
  },
29
  {
30
   "cell_type": "markdown",
31
   "metadata": {},
32
   "source": [
33
    "Steps:\n",
34
    "\n",
35
    "- fill `patient_id`\n",
36
    "- keep only the date part (y-m-d) of the `RE_DATE` column\n",
37
    "- merge lab tests of the same (patient_id, date)\n",
38
    "- calculate and save features' statistics information (demographic and lab test data are calculated separately)\n",
39
    "- normalize data\n",
40
    "- feature selection\n",
41
    "- fill missing data (our filling strategy will be described below)\n",
42
    "- combine above data to time series data (one patient one record)\n",
43
    "- export to python pickle file"
44
   ]
45
  },
46
  {
47
   "cell_type": "code",
48
   "execution_count": null,
49
   "metadata": {},
50
   "outputs": [],
51
   "source": [
52
    "# forward-fill missing `PATIENT_ID` values; assign back instead of a chained `inplace=True` call\n",
53
    "df_train['PATIENT_ID'] = df_train['PATIENT_ID'].ffill()\n",
54
    "\n",
55
    "# gender transformation: 1--male, 0--female (raw value 2 is recoded to 0)\n",
56
    "df_train['gender'] = df_train['gender'].replace(2, 0)\n",
57
    "\n",
58
    "# keep only the date part (y-m-d) of the `RE_DATE` and `Discharge time` columns\n",
59
    "df_train['RE_DATE'] = df_train['RE_DATE'].dt.strftime('%Y-%m-%d')\n",
60
    "df_train['Discharge time'] = df_train['Discharge time'].dt.strftime('%Y-%m-%d')\n"
61
   ]
62
  },
63
  {
64
   "cell_type": "code",
65
   "execution_count": null,
66
   "metadata": {},
67
   "outputs": [],
68
   "source": [
69
    "df_train = df_train.dropna(subset = ['PATIENT_ID', 'RE_DATE', 'Discharge time'], how='any')"
70
   ]
71
  },
72
  {
73
   "cell_type": "code",
74
   "execution_count": null,
75
   "metadata": {},
76
   "outputs": [],
77
   "source": [
78
    "# calculate raw data's los interval\n",
79
    "df_grouped = df_train.groupby('PATIENT_ID')\n",
80
    "\n",
81
    "los_interval_list = []\n",
82
    "los_interval_alive_list = []\n",
83
    "los_interval_dead_list = []\n",
84
    "\n",
85
    "for name, group in df_grouped:\n",
86
    "    sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
87
    "    # print(sorted_group['outcome'])\n",
88
    "    # print('---')\n",
89
    "    # print(type(sorted_group))\n",
90
    "    intervals = sorted_group['RE_DATE'].tolist()\n",
91
    "    outcome = sorted_group['outcome'].tolist()[0]\n",
92
    "    cur_visits_len = len(intervals)\n",
93
    "    # print(cur_visits_len)\n",
94
    "    if cur_visits_len == 1:\n",
95
    "        continue\n",
96
    "    for i in range(1, len(intervals)):\n",
97
    "        los_interval_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
98
    "        if outcome == 0:\n",
99
    "            los_interval_alive_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
100
    "        else:\n",
101
    "            los_interval_dead_list.append((pd.to_datetime(intervals[i])-pd.to_datetime(intervals[i-1])).days)\n",
102
    "\n",
103
    "los_interval_list = np.array(los_interval_list)\n",
104
    "los_interval_alive_list = np.array(los_interval_alive_list)\n",
105
    "los_interval_dead_list = np.array(los_interval_dead_list)\n",
106
    "\n",
107
    "output = {\n",
108
    "    'overall': los_interval_list,\n",
109
    "    'alive': los_interval_alive_list,\n",
110
    "    'dead': los_interval_dead_list,\n",
111
    "}\n",
112
    "# pd.to_pickle(output, 'raw_tjh_los_interval_list.pkl')\n"
113
   ]
114
  },
115
  {
116
   "cell_type": "code",
117
   "execution_count": null,
118
   "metadata": {},
119
   "outputs": [],
120
   "source": [
121
    "# we have 2 types of prediction tasks: 1) predict mortality outcome, 2) length of stay\n",
122
    "\n",
123
    "# below are all lab test features\n",
124
    "labtest_features_str = \"\"\"\n",
125
    "Hypersensitive cardiac troponinI\themoglobin\tSerum chloride\tProthrombin time\tprocalcitonin\teosinophils(%)\tInterleukin 2 receptor\tAlkaline phosphatase\talbumin\tbasophil(%)\tInterleukin 10\tTotal bilirubin\tPlatelet count\tmonocytes(%)\tantithrombin\tInterleukin 8\tindirect bilirubin\tRed blood cell distribution width \tneutrophils(%)\ttotal protein\tQuantification of Treponema pallidum antibodies\tProthrombin activity\tHBsAg\tmean corpuscular volume\thematocrit\tWhite blood cell count\tTumor necrosis factorα\tmean corpuscular hemoglobin concentration\tfibrinogen\tInterleukin 1β\tUrea\tlymphocyte count\tPH value\tRed blood cell count\tEosinophil count\tCorrected calcium\tSerum potassium\tglucose\tneutrophils count\tDirect bilirubin\tMean platelet volume\tferritin\tRBC distribution width SD\tThrombin time\t(%)lymphocyte\tHCV antibody quantification\tD-D dimer\tTotal cholesterol\taspartate aminotransferase\tUric acid\tHCO3-\tcalcium\tAmino-terminal brain natriuretic peptide precursor(NT-proBNP)\tLactate dehydrogenase\tplatelet large cell ratio \tInterleukin 6\tFibrin degradation products\tmonocytes count\tPLT distribution width\tglobulin\tγ-glutamyl transpeptidase\tInternational standard ratio\tbasophil count(#)\t2019-nCoV nucleic acid detection\tmean corpuscular hemoglobin \tActivation of partial thromboplastin time\tHypersensitive c-reactive protein\tHIV antibody quantification\tserum sodium\tthrombocytocrit\tESR\tglutamic-pyruvic transaminase\teGFR\tcreatinine\n",
126
    "\"\"\"\n",
127
    "\n",
128
    "# below are 2 demographic features\n",
129
    "demographic_features_str = \"\"\"\n",
130
    "age\tgender\n",
131
    "\"\"\"\n",
132
    "\n",
133
    "labtest_features = [f for f in labtest_features_str.strip().split('\\t')]\n",
134
    "demographic_features = [f for f in demographic_features_str.strip().split('\\t')]\n",
135
    "target_features = ['outcome', 'LOS']\n",
136
    "\n",
137
    "# from our observation, `2019-nCoV nucleic acid detection` feature (in lab test) are all -1 value\n",
138
    "# so we remove this feature here\n",
139
    "labtest_features.remove('2019-nCoV nucleic acid detection')"
140
   ]
141
  },
142
  {
143
   "cell_type": "code",
144
   "execution_count": null,
145
   "metadata": {},
146
   "outputs": [],
147
   "source": [
148
    "# if some values are negative, set them to NaN\n",
149
    "df_train[df_train[demographic_features + labtest_features]<0] = np.nan"
150
   ]
151
  },
152
  {
153
   "cell_type": "code",
154
   "execution_count": null,
155
   "metadata": {},
156
   "outputs": [],
157
   "source": [
158
    "# merge lab tests of the same (patient_id, date)\n",
159
    "df_train = df_train.groupby(['PATIENT_ID', 'RE_DATE', 'Discharge time'], dropna=True, as_index = False).mean()"
160
   ]
161
  },
162
  {
163
   "cell_type": "code",
164
   "execution_count": null,
165
   "metadata": {},
166
   "outputs": [],
167
   "source": [
168
    "# calculate length-of-stay label\n",
169
    "df_train['LOS'] = (pd.to_datetime(df_train['Discharge time']) - pd.to_datetime(df_train['RE_DATE'])).dt.days"
170
   ]
171
  },
172
  {
173
   "cell_type": "code",
174
   "execution_count": null,
175
   "metadata": {},
176
   "outputs": [],
177
   "source": [
178
    "# if LOS values are negative, set them to 0\n",
179
    "df_train['LOS'] = df_train['LOS'].clip(lower=0)"
180
   ]
181
  },
182
  {
183
   "cell_type": "code",
184
   "execution_count": null,
185
   "metadata": {},
186
   "outputs": [],
187
   "source": [
188
    "# save features' statistics information\n",
189
    "\n",
190
    "def calculate_statistic_info(df, features):\n",
191
    "    \"\"\"Per-feature summary over all rows of `df`: count, missing %, mean, max, min, median, std.\"\"\"\n",
192
    "    n_rows, summary = len(df), {}\n",
193
    "    for feature in features:\n",
194
    "        col = df[feature]\n",
195
    "        summary[feature] = {\n",
196
    "            'count': int(col.count()),\n",
197
    "            'missing': str(round(float((100-col.count()*100/n_rows)),3))+\"%\",\n",
198
    "            'mean': float(col.mean()),\n",
199
    "            'max': float(col.max()),\n",
200
    "            'min': float(col.min()),\n",
201
    "            'median': float(col.median()),\n",
202
    "            'std': float(col.std()),\n",
203
    "        }\n",
204
    "    return summary\n",
205
    "\n",
206
    "def calculate_middle_part_statistic_info(df, features):\n",
207
    "    \"\"\"Per-feature summary using only the values inside that feature's own [5%, 95%] quantile range.\"\"\"\n",
208
    "    statistic_info = {}\n",
209
    "    len_df = len(df)\n",
210
    "    # quantiles over `features` only: a whole-frame quantile raises on non-numeric columns (pandas >= 2.0)\n",
211
    "    middle_part_df_info = df[features].quantile([.05, .95])\n",
212
    "    for e in features:\n",
213
    "        low_value = middle_part_df_info[e][.05]\n",
214
    "        high_value = middle_part_df_info[e][.95]\n",
215
    "        # NaN compares False on both sides, so missing values are left out as well\n",
216
    "        middle_part_df_element = df.loc[(df[e] >= low_value) & (df[e] <= high_value)][e]\n",
217
    "        statistic_info[e] = {\n",
218
    "            'count': int(middle_part_df_element.count()),\n",
219
    "            'missing': str(round(float((100-middle_part_df_element.count()*100/len_df)),3))+\"%\",\n",
220
    "            'mean': float(middle_part_df_element.mean()),\n",
221
    "            'max': float(middle_part_df_element.max()),\n",
222
    "            'min': float(middle_part_df_element.min()),\n",
223
    "            'median': float(middle_part_df_element.median()),\n",
224
    "            'std': float(middle_part_df_element.std()),\n",
225
    "        }\n",
226
    "    return statistic_info\n",
227
    "\n",
228
    "# labtest_statistic_info = calculate_statistic_info(df_train, labtest_features)\n",
229
    "\n",
230
    "\n",
231
    "# group by patient_id, then calculate lab test/demographic features' statistics information\n",
232
    "groupby_patientid_df = df_train.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean()\n",
233
    "\n",
234
    "\n",
235
    "# calculate statistic info (all values calculated)\n",
236
    "labtest_patientwise_statistic_info = calculate_statistic_info(groupby_patientid_df, labtest_features)\n",
237
    "demographic_statistic_info = calculate_statistic_info(groupby_patientid_df, demographic_features) # it's also patient-wise\n",
238
    "\n",
239
    "# calculate statistic info (5% ~ 95% only)\n",
240
    "demographic_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, demographic_features) \n",
241
    "labtest_patientwise_statistic_info_2 = calculate_middle_part_statistic_info(groupby_patientid_df, labtest_features) \n",
242
    "\n",
243
    "# take 2 statistics information's union\n",
244
    "statistic_info = labtest_patientwise_statistic_info_2 | demographic_statistic_info_2\n"
245
   ]
246
  },
247
  {
248
   "cell_type": "code",
249
   "execution_count": null,
250
   "metadata": {},
251
   "outputs": [],
252
   "source": [
253
    "# observe features, export to csv file [optional]\n",
254
    "to_export_dict = {'name': [], 'missing_rate': [], 'count': [], 'mean': [], 'max': [], 'min': [], 'median': [], 'std': []}\n",
255
    "for key in statistic_info:\n",
256
    "    detail = statistic_info[key]\n",
257
    "    to_export_dict['name'].append(key)\n",
258
    "    to_export_dict['count'].append(detail['count'])\n",
259
    "    to_export_dict['missing_rate'].append(detail['missing'])\n",
260
    "    to_export_dict['mean'].append(detail['mean'])\n",
261
    "    to_export_dict['max'].append(detail['max'])\n",
262
    "    to_export_dict['min'].append(detail['min'])\n",
263
    "    to_export_dict['median'].append(detail['median'])\n",
264
    "    to_export_dict['std'].append(detail['std'])\n",
265
    "to_export_df = pd.DataFrame.from_dict(to_export_dict)\n",
266
    "# to_export_df.to_csv('statistic_info.csv')"
267
   ]
268
  },
269
  {
270
   "cell_type": "code",
271
   "execution_count": null,
272
   "metadata": {},
273
   "outputs": [],
274
   "source": [
275
    "# normalize data\n",
276
    "def normalize_data(df, features, statistic_info):\n",
277
    "    \"\"\"Z-score `features` with the given mean/std and return them next to the id/label columns.\"\"\"\n",
278
    "    def zscore(col):\n",
279
    "        return (col - statistic_info[col.name]['mean']) / (statistic_info[col.name]['std']+1e-12)\n",
280
    "    kept = df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']]\n",
281
    "    return pd.concat([kept, df[features].apply(zscore)], axis=1)\n",
282
    "df_train = normalize_data(df_train, ['age'] + labtest_features, statistic_info) # gender don't need to be normalized"
283
   ]
284
  },
285
  {
286
   "cell_type": "code",
287
   "execution_count": null,
288
   "metadata": {},
289
   "outputs": [],
290
   "source": [
291
    "# filter outliers\n",
292
    "def filter_data(df, features, bar=3):\n",
293
    "    for f in features:\n",
294
    "        df[f] = df[f].mask(df[f].abs().gt(bar))\n",
295
    "    return df\n",
296
    "df_train = filter_data(df_train, demographic_features + labtest_features, bar=3)"
297
   ]
298
  },
299
  {
300
   "cell_type": "code",
301
   "execution_count": null,
302
   "metadata": {},
303
   "outputs": [],
304
   "source": [
305
    "# drop rows if all labtest_features are recorded nan\n",
306
    "df_train = df_train.dropna(subset = labtest_features, how='all')"
307
   ]
308
  },
309
  {
310
   "cell_type": "code",
311
   "execution_count": null,
312
   "metadata": {},
313
   "outputs": [],
314
   "source": [
315
    "# Calculate data statistics after preprocessing steps (before imputation)\n",
316
    "\n",
317
    "# Step 1: reverse z-score normalization operation\n",
318
    "df_reverse = df_train\n",
319
    "# reverse normalize data\n",
320
    "def reverse_normalize_data(df, features, statistic_info):\n",
321
    "    \"\"\"Undo the z-score of `features` (x * (std+1e-12) + mean) and keep the id/label columns.\"\"\"\n",
322
    "    restored = df[features].apply(lambda x: x * (statistic_info[x.name]['std']+1e-12) + statistic_info[x.name]['mean'])\n",
323
    "    kept = df[['PATIENT_ID', 'gender', 'RE_DATE', 'outcome', 'LOS']]\n",
324
    "    return pd.concat([kept, restored], axis=1)\n",
325
    "df_reverse = reverse_normalize_data(df_reverse, ['age'] + labtest_features, statistic_info) # gender don't need to be normalized\n",
326
    "\n",
327
    "statistics = {}\n",
328
    "\n",
329
    "for f in demographic_features+labtest_features:\n",
330
    "    statistics[f]={}\n",
331
    "\n",
332
    "def calculate_quantile_statistic_info(df, features, case):\n",
333
    "    \"\"\"Store in the global `statistics[feature][case]` either 'median [Q1, Q3]' or, for gender, '% Male'.\"\"\"\n",
334
    "    for e in features:\n",
335
    "        if e == 'gender':\n",
336
    "            unique, count = np.unique(df[e], return_counts=True)\n",
337
    "            data_count = dict(zip(unique, count)) # key = 1 male, 0 female\n",
338
    "            print(data_count)\n",
339
    "            # use .get: a subgroup holding a single gender used to raise KeyError on the absent key\n",
340
    "            n_male, n_female = data_count.get(1.0, 0), data_count.get(0.0, 0)\n",
341
    "            male_percentage = n_male*100/(n_male+n_female) if (n_male+n_female) else float('nan')\n",
342
    "            statistics[e][case] = f\"{male_percentage:.2f}% Male\"\n",
343
    "            print(statistics[e][case])\n",
344
    "        else:\n",
345
    "            # quartiles are computed with NaN entries ignored\n",
346
    "            lo, mi, hi = (round(np.nanpercentile(df[e], q), 2) for q in (25, 50, 75))\n",
347
    "            statistics[e][case] = f\"{mi:.2f} [{lo:.2f}, {hi:.2f}]\"\n",
348
    "\n",
349
    "def calculate_missing_rate(df, features, case='missing_rate'):\n",
350
    "    for _, e in enumerate(features):\n",
351
    "        missing_rate = round(float(df[e].isnull().sum()*100/df[e].shape[0]), 2)\n",
352
    "        statistics[e][case] = f\"{missing_rate:.2f}%\"\n",
353
    "\n",
354
    "tmp_groupby_pid = df_reverse.groupby(['PATIENT_ID'], dropna=True, as_index = False).mean()\n",
355
    "\n",
356
    "calculate_quantile_statistic_info(tmp_groupby_pid, demographic_features, 'overall')\n",
357
    "calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==0], demographic_features, 'alive')\n",
358
    "calculate_quantile_statistic_info(tmp_groupby_pid[tmp_groupby_pid['outcome']==1], demographic_features, 'dead')\n",
359
    "\n",
360
    "calculate_quantile_statistic_info(df_reverse, labtest_features, 'overall')\n",
361
    "calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==0], labtest_features, 'alive')\n",
362
    "calculate_quantile_statistic_info(df_reverse[df_reverse['outcome']==1], labtest_features, 'dead')\n",
363
    "\n",
364
    "calculate_missing_rate(df_reverse, demographic_features+labtest_features, 'missing_rate')\n",
365
    "\n",
366
    "export_quantile_statistics = {'Characteristics':[], 'Overall':[], 'Alive':[], 'Dead':[], 'Missing Rate':[]}\n",
367
    "for f in demographic_features+labtest_features:\n",
368
    "    export_quantile_statistics['Characteristics'].append(f)\n",
369
    "    export_quantile_statistics['Overall'].append(statistics[f]['overall'])\n",
370
    "    export_quantile_statistics['Alive'].append(statistics[f]['alive'])\n",
371
    "    export_quantile_statistics['Dead'].append(statistics[f]['dead'])\n",
372
    "    export_quantile_statistics['Missing Rate'].append(statistics[f]['missing_rate'])\n",
373
    "\n",
374
    "# pd.DataFrame.from_dict(export_quantile_statistics).to_csv('statistics.csv')"
375
   ]
376
  },
377
  {
378
   "cell_type": "code",
379
   "execution_count": null,
380
   "metadata": {},
381
   "outputs": [],
382
   "source": [
383
    "def calculate_data_existing_length(data):\n",
384
    "    \"\"\"Count the entries of `data` that are not missing according to `pd.isna`.\"\"\"\n",
385
    "    present = 0\n",
386
    "    for value in data:\n",
387
    "        present += 0 if pd.isna(value) else 1\n",
388
    "    return present\n",
389
    "# elements in data are sorted in time ascending order\n",
390
    "def fill_missing_value(data, to_fill_value=0):\n",
391
    "    \"\"\"Impute `data` in place and return it.\n",
392
    "\n",
393
    "    Leading gaps (or a fully missing series) get `to_fill_value`; later gaps copy the previous value.\n",
394
    "    \"\"\"\n",
395
    "    total = len(data)\n",
396
    "    observed = calculate_data_existing_length(data)\n",
397
    "    if observed == total:\n",
398
    "        # nothing is missing\n",
399
    "        return data\n",
400
    "    if observed == 0:\n",
401
    "        # no observation at all: every position takes the default\n",
402
    "        for pos in range(total):\n",
403
    "            data[pos] = to_fill_value\n",
404
    "        return data\n",
405
    "    # position of the first observed entry (one exists because observed > 0)\n",
406
    "    first_seen = next(pos for pos in range(total) if not pd.isna(data[pos]))\n",
407
    "    # positions before the first observation take the default\n",
408
    "    for pos in range(first_seen):\n",
409
    "        data[pos] = to_fill_value\n",
410
    "    # every remaining gap copies its already-filled predecessor (forward fill)\n",
411
    "    for pos in range(1, total):\n",
412
    "        if pd.isna(data[pos]):\n",
413
    "            data[pos] = data[pos-1]\n",
414
    "    return data"
415
   ]
416
  },
417
  {
418
   "cell_type": "code",
419
   "execution_count": null,
420
   "metadata": {},
421
   "outputs": [],
422
   "source": [
423
    "# fill missing data using our strategy and convert to time series records\n",
424
    "grouped = df_train.groupby('PATIENT_ID')\n",
425
    "\n",
426
    "all_x_demographic = []\n",
427
    "all_x_labtest = []\n",
428
    "all_y = []\n",
429
    "all_missing_mask = []\n",
430
    "\n",
431
    "for name, group in grouped:\n",
432
    "    sorted_group = group.sort_values(by=['RE_DATE'], ascending=True)\n",
433
    "    feature_columns = demographic_features+labtest_features\n",
434
    "\n",
435
    "    # build the mask BEFORE imputing (1 = observed, 0 = missing); after the fill every value is finite\n",
436
    "    all_missing_mask.append((np.isfinite(sorted_group[feature_columns].to_numpy())).astype(int))\n",
437
    "\n",
438
    "    for f in feature_columns:\n",
439
    "        # default fill value: the feature's median expressed as a z-score\n",
440
    "        # NOTE(review): `gender` is never z-scored by normalize_data but still gets a z-scored default here -- confirm\n",
441
    "        to_fill_value = (statistic_info[f]['median'] - statistic_info[f]['mean'])/(statistic_info[f]['std']+1e-12)\n",
442
    "        # assign the result back: mutating `.values` in place only works while pandas hands out a writable view\n",
443
    "        sorted_group[f] = fill_missing_value(sorted_group[f].to_numpy(copy=True), to_fill_value)\n",
444
    "\n",
445
    "    patient_demographic = []\n",
446
    "    patient_labtest = []\n",
447
    "    patient_y = []\n",
448
    "    for _, v in sorted_group.iterrows():\n",
449
    "        patient_y.append([v['outcome'], v['LOS']])\n",
450
    "        demo = [v[f] for f in demographic_features]\n",
451
    "        lab = [v[f] for f in labtest_features]\n",
452
    "        patient_labtest.append(lab)\n",
453
    "        patient_demographic.append(demo)\n",
454
    "\n",
455
    "    all_y.append(patient_y)\n",
456
    "    # only the demographic values of the patient's last record are kept\n",
457
    "    all_x_demographic.append(patient_demographic[-1])\n",
458
    "    all_x_labtest.append(patient_labtest)\n",
459
    "\n",
460
    "# all_x_demographic (2 dim, record each patients' demographic features)\n",
461
    "# all_x_labtest (3 dim, record each patients' lab test features)\n",
462
    "# all_y (3 dim, patients' outcome/los of all visits)"
463
   ]
464
  },
465
  {
466
   "cell_type": "code",
467
   "execution_count": null,
468
   "metadata": {},
469
   "outputs": [],
470
   "source": [
471
    "all_x_labtest = np.array(all_x_labtest, dtype=object)\n",
472
    "x_lab_length = [len(_) for _ in all_x_labtest]\n",
473
    "x_lab_length = torch.tensor(x_lab_length, dtype=torch.int)\n",
474
    "max_length = int(x_lab_length.max())\n",
475
    "all_x_labtest = [torch.tensor(_) for _ in all_x_labtest]\n",
476
    "# pad lab test sequence to the same shape\n",
477
    "all_x_labtest = torch.nn.utils.rnn.pad_sequence((all_x_labtest), batch_first=True)\n",
478
    "\n",
479
    "all_x_demographic = torch.tensor(all_x_demographic)\n",
480
    "batch_size, demo_dim = all_x_demographic.shape\n",
481
    "# repeat demographic tensor\n",
482
    "all_x_demographic = torch.reshape(all_x_demographic.repeat(1, max_length), (batch_size, max_length, demo_dim))\n",
483
    "# demographic tensor concat with lab test tensor\n",
484
    "all_x = torch.cat((all_x_demographic, all_x_labtest), 2)\n",
485
    "\n",
486
    "all_y = np.array(all_y, dtype=object)\n",
487
    "all_y = [torch.Tensor(_) for _ in all_y]\n",
488
    "# pad [outcome/los] sequence as well\n",
489
    "all_y = torch.nn.utils.rnn.pad_sequence((all_y), batch_first=True)\n",
490
    "\n",
491
    "all_missing_mask = np.array(all_missing_mask, dtype=object)\n",
492
    "all_missing_mask = [torch.tensor(_) for _ in all_missing_mask]\n",
493
    "all_missing_mask = torch.nn.utils.rnn.pad_sequence((all_missing_mask), batch_first=True)"
494
   ]
495
  },
496
  {
497
   "cell_type": "code",
498
   "execution_count": null,
499
   "metadata": {},
500
   "outputs": [],
501
   "source": [
502
    "# save pickle format dataset (export torch tensor)\n",
503
    "pd.to_pickle(all_x, f'./processed_data/x.pkl')\n",
504
    "pd.to_pickle(all_y, f'./processed_data/y.pkl')\n",
505
    "pd.to_pickle(x_lab_length, f'./processed_data/visits_length.pkl')\n",
506
    "pd.to_pickle(all_missing_mask, f'./processed_data/missing_mask.pkl')"
507
   ]
508
  },
509
  {
510
   "cell_type": "code",
511
   "execution_count": null,
512
   "metadata": {},
513
   "outputs": [],
514
   "source": [
515
    "# Calculate patients' outcome statistics (patients-wise)\n",
516
    "outcome_list = []\n",
517
    "y_outcome = all_y[:, :, 0]\n",
518
    "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
519
    "for i in indices:\n",
520
    "    outcome_list.append(y_outcome[i][0].item())\n",
521
    "outcome_list = np.array(outcome_list)\n",
522
    "print(len(outcome_list))\n",
523
    "unique, count=np.unique(outcome_list,return_counts=True)\n",
524
    "data_count=dict(zip(unique,count))\n",
525
    "print(data_count)"
526
   ]
527
  },
528
  {
529
   "cell_type": "code",
530
   "execution_count": null,
531
   "metadata": {},
532
   "outputs": [],
533
   "source": [
534
    "# Calculate patients' outcome statistics (records-wise)\n",
535
    "outcome_records_list = []\n",
536
    "y_outcome = all_y[:, :, 0]\n",
537
    "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
538
    "for i in indices:\n",
539
    "    outcome_records_list.extend(y_outcome[i][0:x_lab_length[i]].tolist())\n",
540
    "outcome_records_list = np.array(outcome_records_list)\n",
541
    "print(len(outcome_records_list))\n",
542
    "unique, count=np.unique(outcome_records_list,return_counts=True)\n",
543
    "data_count=dict(zip(unique,count))\n",
544
    "print(data_count)"
545
   ]
546
  },
547
  {
548
   "cell_type": "code",
549
   "execution_count": null,
550
   "metadata": {},
551
   "outputs": [],
552
   "source": [
553
    "# Calculate patients' mean los and 95% percentile los\n",
554
    "los_list = []\n",
555
    "y_los = all_y[:, :, 1]\n",
556
    "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
557
    "for i in indices:\n",
558
    "    # los_list.extend(y_los[i][: x_lab_length[i].long()].tolist())\n",
559
    "    los_list.append(y_los[i][0].item())\n",
560
    "los_list = np.array(los_list)\n",
561
    "print(los_list.mean() * 0.5)\n",
562
    "print(np.median(los_list) * 0.5)\n",
563
    "print(np.percentile(los_list, 95))\n",
564
    "\n",
565
    "print('median:', np.median(los_list))\n",
566
    "print('Q1:', np.percentile(los_list, 25))\n",
567
    "print('Q3:', np.percentile(los_list, 75))"
568
   ]
569
  },
570
  {
571
   "cell_type": "code",
572
   "execution_count": null,
573
   "metadata": {},
574
   "outputs": [],
575
   "source": [
576
    "los_alive_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 0])\n",
577
    "los_dead_list = np.array([los_list[i] for i in range(len(los_list)) if outcome_list[i] == 1])\n",
578
    "print(len(los_alive_list))\n",
579
    "print(len(los_dead_list))\n",
580
    "\n",
581
    "print('[Alive]')\n",
582
    "print('median:', np.median(los_alive_list))\n",
583
    "print('Q1:', np.percentile(los_alive_list, 25))\n",
584
    "print('Q3:', np.percentile(los_alive_list, 75))\n",
585
    "\n",
586
    "print('[Dead]')\n",
587
    "print('median:', np.median(los_dead_list))\n",
588
    "print('Q1:', np.percentile(los_dead_list, 25))\n",
589
    "print('Q3:', np.percentile(los_dead_list, 75))"
590
   ]
591
  },
592
  {
593
   "cell_type": "code",
594
   "execution_count": null,
595
   "metadata": {},
596
   "outputs": [],
597
   "source": [
598
    "tjh_los_statistics = {\n",
599
    "    'overall': los_list,\n",
600
    "    'alive': los_alive_list,\n",
601
    "    'dead': los_dead_list\n",
602
    "}\n",
603
    "# pd.to_pickle(tjh_los_statistics, 'tjh_los_statistics.pkl')"
604
   ]
605
  },
606
  {
607
   "cell_type": "code",
608
   "execution_count": null,
609
   "metadata": {},
610
   "outputs": [],
611
   "source": [
612
    "# calculate visits length Median [Q1, Q3]\n",
613
    "visits_list = np.array(x_lab_length)\n",
614
    "visits_alive_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 0])\n",
615
    "visits_dead_list = np.array([x_lab_length[i] for i in range(len(x_lab_length)) if outcome_list[i] == 1])\n",
616
    "print(len(visits_alive_list))\n",
617
    "print(len(visits_dead_list))\n",
618
    "\n",
619
    "print('[Total]')\n",
620
    "print('median:', np.median(visits_list))\n",
621
    "print('Q1:', np.percentile(visits_list, 25))\n",
622
    "print('Q3:', np.percentile(visits_list, 75))\n",
623
    "\n",
624
    "print('[Alive]')\n",
625
    "print('median:', np.median(visits_alive_list))\n",
626
    "print('Q1:', np.percentile(visits_alive_list, 25))\n",
627
    "print('Q3:', np.percentile(visits_alive_list, 75))\n",
628
    "\n",
629
    "print('[Dead]')\n",
630
    "print('median:', np.median(visits_dead_list))\n",
631
    "print('Q1:', np.percentile(visits_dead_list, 25))\n",
632
    "print('Q3:', np.percentile(visits_dead_list, 75))"
633
   ]
634
  },
635
  {
636
   "cell_type": "code",
637
   "execution_count": null,
638
   "metadata": {},
639
   "outputs": [],
640
   "source": [
641
    "# Length-of-stay interval (overall/alive/dead)\n",
642
    "los_interval_list = []\n",
643
    "los_interval_alive_list = []\n",
644
    "los_interval_dead_list = []\n",
645
    "\n",
646
    "y_los = all_y[:, :, 1]\n",
647
    "indices = torch.arange(len(x_lab_length), dtype=torch.int64)\n",
648
    "for i in indices:\n",
649
    "    cur_visits_len = x_lab_length[i]\n",
650
    "    if cur_visits_len == 1:\n",
651
    "        continue\n",
652
    "    for j in range(1, cur_visits_len):\n",
653
    "        los_interval_list.append(y_los[i][j-1]-y_los[i][j])\n",
654
    "        if outcome_list[i] == 0:\n",
655
    "            los_interval_alive_list.append(y_los[i][j-1]-y_los[i][j])\n",
656
    "        else:\n",
657
    "            los_interval_dead_list.append(y_los[i][j-1]-y_los[i][j])\n",
658
    "\n",
659
    "los_interval_list = np.array(los_interval_list)\n",
660
    "los_interval_alive_list = np.array(los_interval_alive_list)\n",
661
    "los_interval_dead_list = np.array(los_interval_dead_list)\n",
662
    "\n",
663
    "output = {\n",
664
    "    'overall': los_interval_list,\n",
665
    "    'alive': los_interval_alive_list,\n",
666
    "    'dead': los_interval_dead_list,\n",
667
    "}\n",
668
    "# pd.to_pickle(output, 'tjh_los_interval_list.pkl')"
669
   ]
670
  },
671
  {
672
   "cell_type": "code",
673
   "execution_count": null,
674
   "metadata": {},
675
   "outputs": [],
676
   "source": [
677
    "len(los_interval_list), len(los_interval_alive_list), len(los_interval_dead_list)"
678
   ]
679
  },
680
  {
681
   "cell_type": "code",
682
   "execution_count": null,
683
   "metadata": {},
684
   "outputs": [],
685
   "source": [
686
    "def check_nan(x):\n",
687
    "    \"\"\"Print whether tensor `x` contains at least one NaN.\"\"\"\n",
688
    "    # test element-wise: summing first also yields NaN for input that mixes +inf and -inf\n",
689
    "    has_nan = bool(torch.isnan(x).any())\n",
690
    "    print(\"some values from input are nan\" if has_nan else \"no nan\")"
691
   ]
692
  },
693
  {
694
   "cell_type": "markdown",
695
   "metadata": {},
696
   "source": [
697
    "# Draw Charts"
698
   ]
699
  },
700
  {
701
   "cell_type": "code",
702
   "execution_count": null,
703
   "metadata": {},
704
   "outputs": [],
705
   "source": [
706
    "import matplotlib.pyplot as plt\n",
707
    "from matplotlib.ticker import PercentFormatter\n",
708
    "import matplotlib.font_manager as font_manager\n",
709
    "import pandas as pd\n",
710
    "import numpy as np\n",
711
    "\n",
712
    "plt.style.use('seaborn-whitegrid')\n",
713
    "color = 'cornflowerblue'\n",
714
    "ec = 'None'\n",
715
    "alpha=0.5\n",
716
    "alive_color = 'olivedrab'\n",
717
    "dead_color = 'orchid'"
718
   ]
719
  },
720
  {
721
   "cell_type": "code",
722
   "execution_count": null,
723
   "metadata": {},
724
   "outputs": [],
725
   "source": [
726
    "tj_overall = pd.read_csv('./tjh_data_raw.csv')\n",
727
    "tj_overall.head()"
728
   ]
729
  },
730
  {
731
   "cell_type": "code",
732
   "execution_count": null,
733
   "metadata": {},
734
   "outputs": [],
735
   "source": [
736
    "tj_alive = tj_overall.loc[tj_overall['outcome'] == 0]\n",
737
    "tj_dead = tj_overall.loc[tj_overall['outcome'] == 1]"
738
   ]
739
  },
740
  {
741
   "cell_type": "code",
742
   "execution_count": null,
743
   "metadata": {},
744
   "outputs": [],
745
   "source": [
746
    "tj_overall.describe().to_csv('tjh_describe.csv', index=False)"
747
   ]
748
  },
749
  {
750
   "cell_type": "code",
751
   "execution_count": null,
752
   "metadata": {},
753
   "outputs": [],
754
   "source": [
755
    "# Relies on the setup cell under 'Draw Charts' for plt, np, PercentFormatter,\n",
    "# font_manager, the plot style and the colour constants (color, ec, alpha,\n",
    "# alive_color, dead_color), and on tj_overall / tj_alive / tj_dead from the cells above.\n",
    "\n",
    "# Each feature is trimmed to its [limit, 1 - limit] quantile range of the overall table.\n",
    "limit = 0.05\n",
    "\n",
    "font = 'Times New Roman'\n",
    "csfont = {'fontname': font, 'fontsize': 18}\n",
    "\n",
    "# (column name, x-axis label) for every panel, in the order the 4x4 grid is filled.\n",
    "features = [\n",
    "    ('age', 'Age'),\n",
    "    ('White blood cell count', 'White blood cell count'),\n",
    "    ('Red blood cell count', 'Red blood cell count'),\n",
    "    ('neutrophils(%)', 'neutrophils %'),\n",
    "    ('(%)lymphocyte', 'lymphocyte %'),\n",
    "    ('monocytes(%)', 'monocytes(%)'),\n",
    "    ('Platelet count', 'Platelet count'),\n",
    "    ('lymphocyte count', 'Lymphocyte count'),\n",
    "    ('hemoglobin', 'Hemoglobin'),\n",
    "    ('calcium', 'Calcium'),\n",
    "    ('hematocrit', 'Hematocrit'),\n",
    "    ('albumin', 'Albumin'),\n",
    "    ('neutrophils count', 'Neutrophils count'),\n",
    "    ('monocytes count', 'Monocytes count'),\n",
    "    ('basophil count(#)', 'Basophil count'),\n",
    "    ('eosinophils(%)', 'Eosinophils %'),\n",
    "]\n",
    "\n",
    "# (legend label, source table, hist() styling) for the three overlaid histograms.\n",
    "step_style = {'color': 'green', 'alpha': 1, 'histtype': 'step', 'linewidth': 2}\n",
    "cohorts = [\n",
    "    ('overall', tj_overall, {'color': color, 'ec': ec, 'alpha': alpha}),\n",
    "    ('alive', tj_alive, dict(step_style, ec=alive_color)),\n",
    "    ('dead', tj_dead, dict(step_style, ec=dead_color)),\n",
    "]\n",
    "\n",
    "\n",
    "def plot_feature(ax, key, xlabel):\n",
    "    \"\"\"Draw the overall / alive / dead histograms of column `key` on `ax`.\n",
    "\n",
    "    Rows are kept only when `key` lies between the `limit` and `1 - limit`\n",
    "    quantiles of `tj_overall[key]`; the same bounds are applied to all three tables.\n",
    "    Bar heights are fractions of each (trimmed) table, shown as percentages.\n",
    "\n",
    "    Args:\n",
    "        ax: matplotlib Axes to draw on.\n",
    "        key: column name present in tj_overall, tj_alive and tj_dead.\n",
    "        xlabel: text for the x-axis label.\n",
    "    \"\"\"\n",
    "    low = tj_overall[key].quantile(limit)\n",
    "    high = tj_overall[key].quantile(1 - limit)\n",
    "    for label, frame, style in cohorts:\n",
    "        values = frame.loc[frame[key].between(low, high), key]\n",
    "        # Equal weights that sum to 1 turn the counts into fractions of the cohort.\n",
    "        weights = np.full(len(values), 1 / max(len(values), 1))\n",
    "        ax.hist(values, bins=20, weights=weights, label=label, **style)\n",
    "    ax.set_xlabel(xlabel, **csfont)\n",
    "    ax.set_ylabel('Percentage', **csfont)\n",
    "    ax.yaxis.set_major_formatter(PercentFormatter(1))\n",
    "    # plt.xticks / plt.yticks act on the current axes, so make `ax` current first.\n",
    "    plt.sca(ax)\n",
    "    plt.xticks(**csfont)\n",
    "    plt.yticks(**csfont)\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(16, 12), dpi=500, facecolor='w', edgecolor='k')\n",
    "\n",
    "axes = []\n",
    "for idx, (key, xlabel) in enumerate(features, start=1):\n",
    "    ax = plt.subplot(4, 4, idx)\n",
    "    plot_feature(ax, key, xlabel)\n",
    "    axes.append(ax)\n",
    "\n",
    "# One legend for the whole figure, built from the first panel's handles.\n",
    "# The font size is carried by `prop`; a separate `fontsize` is ignored when `prop` is given.\n",
    "handles, labels = axes[0].get_legend_handles_labels()\n",
    "plt.figlegend(handles, labels, loc='upper center', ncol=5, bbox_to_anchor=(0.5, 1.05),\n",
    "              prop=font_manager.FontProperties(family=font, style='normal', size=18))\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
1058
   ]
1059
  }
1060
 ],
1061
 "metadata": {
1062
  "kernelspec": {
1063
   "display_name": "Python 3.7.11 ('python37')",
1064
   "language": "python",
1065
   "name": "python3"
1066
  },
1067
  "language_info": {
1068
   "codemirror_mode": {
1069
    "name": "ipython",
1070
    "version": 3
1071
   },
1072
   "file_extension": ".py",
1073
   "mimetype": "text/x-python",
1074
   "name": "python",
1075
   "nbconvert_exporter": "python",
1076
   "pygments_lexer": "ipython3",
1077
   "version": "3.7.11"
1078
  },
1079
  "vscode": {
1080
   "interpreter": {
1081
    "hash": "a10b846bdc9fc41ee38835cbc29d70b69dd5fd54e1341ea2c410a7804a50447a"
1082
   }
1083
  }
1084
 },
1085
 "nbformat": 4,
1086
 "nbformat_minor": 2
1087
}