Switch to unified view

a b/src/preprocess/05_data_split.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "id": "debdace9",
6
   "metadata": {},
7
   "source": [
8
    "import os\n",
9
    "import sys\n",
10
    "\n",
11
    "src_path = os.path.abspath(\"../..\")\n",
12
    "print(src_path)\n",
13
    "sys.path.append(src_path)"
14
   ],
15
   "outputs": [],
16
   "execution_count": null
17
  },
18
  {
19
   "cell_type": "code",
20
   "id": "6bad1e09",
21
   "metadata": {},
22
   "source": [
23
    "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed"
24
   ],
25
   "outputs": [],
26
   "execution_count": null
27
  },
28
  {
29
   "cell_type": "code",
30
   "id": "5d9bc78c",
31
   "metadata": {},
32
   "source": [
33
    "set_seed(seed=42)"
34
   ],
35
   "outputs": [],
36
   "execution_count": null
37
  },
38
  {
39
   "cell_type": "code",
40
   "id": "13d22a57",
41
   "metadata": {},
42
   "source": [
43
    "import pandas as pd"
44
   ],
45
   "outputs": [],
46
   "execution_count": null
47
  },
48
  {
49
   "cell_type": "code",
50
   "id": "dd9852d5",
51
   "metadata": {},
52
   "source": [
53
    "mimic_iv_path = os.path.join(raw_data_path, \"physionet.org/files/mimiciv/2.2\")\n",
54
    "output_path = os.path.join(processed_data_path, \"mimic4\")"
55
   ],
56
   "outputs": [],
57
   "execution_count": null
58
  },
59
  {
60
   "cell_type": "code",
61
   "id": "b6a27998",
62
   "metadata": {},
63
   "source": [
64
    "cohort = pd.read_csv(os.path.join(output_path, \"cohort+len.csv\"))\n",
65
    "print(cohort.shape)\n",
66
    "cohort.head()"
67
   ],
68
   "outputs": [],
69
   "execution_count": null
70
  },
71
  {
72
   "cell_type": "code",
73
   "id": "9dd92e23",
74
   "metadata": {},
75
   "source": [
76
    "cohort[\"hadm_intime\"] = pd.to_datetime(cohort[\"hadm_intime\"])\n",
77
    "cohort[\"hadm_outtime\"] = pd.to_datetime(cohort[\"hadm_outtime\"])\n",
78
    "cohort[\"stay_intime\"] = pd.to_datetime(cohort[\"stay_intime\"])\n",
79
    "cohort[\"stay_outtime\"] = pd.to_datetime(cohort[\"stay_outtime\"])"
80
   ],
81
   "outputs": [],
82
   "execution_count": null
83
  },
84
  {
85
   "cell_type": "code",
86
   "id": "8f55c793",
87
   "metadata": {},
88
   "source": [
89
    "hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
90
    "len(hadm_ids)"
91
   ],
92
   "outputs": [],
93
   "execution_count": null
94
  },
95
  {
96
   "cell_type": "code",
97
   "id": "26459eda",
98
   "metadata": {},
99
   "source": [
100
    "import ast\n",
101
    "import numpy as np\n",
102
    "\n",
103
    "\n",
104
    "def safe_literal_eval(s):\n",
105
    "    if pd.isna(s):\n",
106
    "        return np.nan\n",
107
    "    return ast.literal_eval(s)\n",
108
    "\n",
109
    "\n",
110
    "cohort.label_diagnosis = cohort.label_diagnosis.apply(safe_literal_eval)"
111
   ],
112
   "outputs": [],
113
   "execution_count": null
114
  },
115
  {
116
   "cell_type": "code",
117
   "id": "99d2b8c8",
118
   "metadata": {},
119
   "source": [
120
    "qa_note = pd.read_json(os.path.join(output_path, \"qa_note.jsonl\"), lines = True)\n",
121
    "qa_note"
122
   ],
123
   "outputs": [],
124
   "execution_count": null
125
  },
126
  {
127
   "cell_type": "code",
128
   "id": "741b6ed7",
129
   "metadata": {},
130
   "source": [
131
    "qa_note.hadm_id.nunique()"
132
   ],
133
   "outputs": [],
134
   "execution_count": null
135
  },
136
  {
137
   "cell_type": "code",
138
   "id": "7d287db0",
139
   "metadata": {},
140
   "source": [
141
    "qa_event = pd.read_json(os.path.join(output_path, \"qa_event.jsonl\"), lines = True)\n",
142
    "qa_event"
143
   ],
144
   "outputs": [],
145
   "execution_count": null
146
  },
147
  {
148
   "cell_type": "code",
149
   "id": "dbc2af95",
150
   "metadata": {},
151
   "source": [
152
    "qa_event.hadm_id.nunique()"
153
   ],
154
   "outputs": [],
155
   "execution_count": null
156
  },
157
  {
158
   "cell_type": "code",
159
   "id": "b4bd715e",
160
   "metadata": {},
161
   "source": [
162
    "qa_event.event_type.value_counts()"
163
   ],
164
   "outputs": [],
165
   "execution_count": null
166
  },
167
  {
168
   "cell_type": "markdown",
169
   "id": "5a9d60c6",
170
   "metadata": {},
171
   "source": [
172
    "## Helper imports"
173
   ]
174
  },
175
  {
176
   "cell_type": "code",
177
   "id": "5171bbae",
178
   "metadata": {},
179
   "source": [
180
    "from concurrent.futures import ThreadPoolExecutor\n",
181
    "from tqdm import tqdm\n",
182
    "from pandarallel import pandarallel"
183
   ],
184
   "outputs": [],
185
   "execution_count": null
186
  },
187
  {
188
   "cell_type": "code",
189
   "id": "5d6b9ce2",
190
   "metadata": {},
191
   "source": [
192
    "pandarallel.initialize(progress_bar=True)"
193
   ],
194
   "outputs": [],
195
   "execution_count": null
196
  },
197
  {
198
   "cell_type": "markdown",
199
   "id": "941b116b",
200
   "metadata": {},
201
   "source": [
202
    "## Cohort statistics"
203
   ]
204
  },
205
  {
206
   "cell_type": "code",
207
   "id": "503a9cbd",
208
   "metadata": {},
209
   "source": [
210
    "cohort = cohort[cohort.hadm_id.isin(qa_note.hadm_id.unique())]\n",
211
    "len(cohort)"
212
   ],
213
   "outputs": [],
214
   "execution_count": null
215
  },
216
  {
217
   "cell_type": "code",
218
   "id": "b4e602ba",
219
   "metadata": {},
220
   "source": [
221
    "cohort = cohort[cohort.hadm_id.isin(qa_event.hadm_id.unique())]\n",
222
    "len(cohort)"
223
   ],
224
   "outputs": [],
225
   "execution_count": null
226
  },
227
  {
228
   "cell_type": "code",
229
   "id": "5014c99f",
230
   "metadata": {
231
    "scrolled": false
232
   },
233
   "source": [
234
    "cohort.hadm_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
235
   ],
236
   "outputs": [],
237
   "execution_count": null
238
  },
239
  {
240
   "cell_type": "code",
241
   "id": "11f5935f",
242
   "metadata": {},
243
   "source": [
244
    "548.490833 / 24"
245
   ],
246
   "outputs": [],
247
   "execution_count": null
248
  },
249
  {
250
   "cell_type": "code",
251
   "id": "87a1b947",
252
   "metadata": {},
253
   "source": [
254
    "cohort.stay_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
255
   ],
256
   "outputs": [],
257
   "execution_count": null
258
  },
259
  {
260
   "cell_type": "code",
261
   "id": "f774a307",
262
   "metadata": {},
263
   "source": [
264
    "265.649889 / 24"
265
   ],
266
   "outputs": [],
267
   "execution_count": null
268
  },
269
  {
270
   "cell_type": "code",
271
   "id": "99d6a2f0",
272
   "metadata": {},
273
   "source": [
274
    "cohort.len_selected.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
275
   ],
276
   "outputs": [],
277
   "execution_count": null
278
  },
279
  {
280
   "cell_type": "code",
281
   "id": "622669a8",
282
   "metadata": {},
283
   "source": [
284
    "cohort_filtered = cohort[cohort.len_selected <= 1256.650000]\n",
285
    "cohort_filtered"
286
   ],
287
   "outputs": [],
288
   "execution_count": null
289
  },
290
  {
291
   "cell_type": "code",
292
   "id": "75ddbce8",
293
   "metadata": {},
294
   "source": [
295
    "cohort_filtered.hadm_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
296
   ],
297
   "outputs": [],
298
   "execution_count": null
299
  },
300
  {
301
   "cell_type": "code",
302
   "id": "aaed6fb7",
303
   "metadata": {},
304
   "source": [
305
    "cohort_filtered.stay_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
306
   ],
307
   "outputs": [],
308
   "execution_count": null
309
  },
310
  {
311
   "cell_type": "code",
312
   "id": "e2cca302",
313
   "metadata": {},
314
   "source": [
315
    "cohort_filtered.len_selected.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
316
   ],
317
   "outputs": [],
318
   "execution_count": null
319
  },
320
  {
321
   "cell_type": "code",
322
   "id": "bc283fd9",
323
   "metadata": {},
324
   "source": [
325
    "all_patients = cohort_filtered.subject_id.unique()\n",
326
    "len(all_patients)"
327
   ],
328
   "outputs": [],
329
   "execution_count": null
330
  },
331
  {
332
   "cell_type": "code",
333
   "id": "18aa0954",
334
   "metadata": {},
335
   "source": [
336
    "from sklearn.model_selection import train_test_split\n",
337
    "\n",
338
    "\n",
339
    "train_val_patients, test_patients = train_test_split(all_patients, test_size=0.1, random_state=42)\n",
340
    "# Carve the validation set out of the train+val pool only (0.111 of the remaining 90% is ~10% overall),\n",
    "# so that no held-out test patient can end up in train or val.\n",
    "train_patients, val_patients = train_test_split(train_val_patients, test_size=0.111, random_state=42)"
341
   ],
342
   "outputs": [],
343
   "execution_count": null
344
  },
345
  {
346
   "cell_type": "code",
347
   "id": "d5d0a3ac",
348
   "metadata": {},
349
   "source": [
350
    "print(train_patients.shape)\n",
351
    "print(val_patients.shape)\n",
352
    "print(test_patients.shape)"
353
   ],
354
   "outputs": [],
355
   "execution_count": null
356
  },
357
  {
358
   "cell_type": "code",
359
   "id": "57dc6d8f",
360
   "metadata": {},
361
   "source": [
362
    "train = cohort_filtered[cohort_filtered.subject_id.isin(train_patients)].reset_index(drop=True)\n",
363
    "val = cohort_filtered[cohort_filtered.subject_id.isin(val_patients)].reset_index(drop=True)\n",
364
    "test = cohort_filtered[cohort_filtered.subject_id.isin(test_patients)].reset_index(drop=True)"
365
   ],
366
   "outputs": [],
367
   "execution_count": null
368
  },
369
  {
370
   "cell_type": "code",
371
   "id": "860d6a39",
372
   "metadata": {},
373
   "source": [
374
    "print(train.shape)\n",
375
    "print(val.shape)\n",
376
    "print(test.shape)"
377
   ],
378
   "outputs": [],
379
   "execution_count": null
380
  },
381
  {
382
   "cell_type": "code",
383
   "id": "ce3c7f18",
384
   "metadata": {},
385
   "source": [
386
    "train.to_csv(os.path.join(output_path, \"cohort_train.csv\"), index=False)\n",
387
    "val.to_csv(os.path.join(output_path, \"cohort_val.csv\"), index=False)\n",
388
    "test.to_csv(os.path.join(output_path, \"cohort_test.csv\"), index=False)"
389
   ],
390
   "outputs": [],
391
   "execution_count": null
392
  },
393
  {
394
   "cell_type": "code",
395
   "id": "7b0ea788",
396
   "metadata": {},
397
   "source": [
398
    "_, test_subset_patients = train_test_split(test_patients, test_size=100, random_state=42)\n",
399
    "len(test_subset_patients)"
400
   ],
401
   "outputs": [],
402
   "execution_count": null
403
  },
404
  {
405
   "cell_type": "code",
406
   "id": "5ba82c66",
407
   "metadata": {},
408
   "source": [
409
    "# Keep exactly one randomly chosen admission per selected test patient.\n",
    "# GroupBy.sample with an explicit seed keeps the grouping column and does not depend on global RNG state.\n",
    "test_subset = test[test.subject_id.isin(test_subset_patients)].groupby(\"subject_id\").sample(n=1, random_state=42).reset_index(drop=True)\n",
410
    "test_subset"
411
   ],
412
   "outputs": [],
413
   "execution_count": null
414
  },
415
  {
416
   "metadata": {},
417
   "cell_type": "code",
418
   "source": "test_subset.len_selected.describe()",
419
   "id": "0bd73438",
420
   "outputs": [],
421
   "execution_count": null
422
  },
423
  {
424
   "cell_type": "code",
425
   "id": "a58143a2",
426
   "metadata": {},
427
   "source": [
428
    "print(test_subset.subject_id.nunique())\n",
429
    "print(test_subset.hadm_id.nunique())"
430
   ],
431
   "outputs": [],
432
   "execution_count": null
433
  },
434
  {
435
   "cell_type": "code",
436
   "id": "745f82b5",
437
   "metadata": {},
438
   "source": [
439
    "test_subset.to_csv(os.path.join(output_path, \"cohort_test_subset.csv\"), index=False)"
440
   ],
441
   "outputs": [],
442
   "execution_count": null
443
  },
444
  {
445
   "cell_type": "code",
446
   "id": "7ba15669",
447
   "metadata": {},
448
   "source": [
449
    "qa_note_test_subset = qa_note[qa_note.hadm_id.isin(test_subset.hadm_id.unique())]\n",
450
    "qa_note_test_subset"
451
   ],
452
   "outputs": [],
453
   "execution_count": null
454
  },
455
  {
456
   "cell_type": "code",
457
   "id": "040ce4fb",
458
   "metadata": {},
459
   "source": [
460
    "# Keep exactly one randomly chosen event QA row per admission in the test subset.\n",
    "# GroupBy.sample with an explicit seed keeps the hadm_id column and does not depend on global RNG state.\n",
    "qa_event_test_subset = qa_event[qa_event.hadm_id.isin(test_subset.hadm_id.unique())].groupby(\"hadm_id\").sample(n=1, random_state=42).reset_index(drop=True)\n",
461
    "qa_event_test_subset"
462
   ],
463
   "outputs": [],
464
   "execution_count": null
465
  },
466
  {
467
   "cell_type": "code",
468
   "id": "5455dbca",
469
   "metadata": {},
470
   "source": [
471
    "qa_event_test_subset.hadm_id.nunique()"
472
   ],
473
   "outputs": [],
474
   "execution_count": null
475
  },
476
  {
477
   "cell_type": "code",
478
   "id": "4c78aa45",
479
   "metadata": {},
480
   "source": [
481
    "qa_event_test_subset.event_type.value_counts()"
482
   ],
483
   "outputs": [],
484
   "execution_count": null
485
  },
486
  {
487
   "cell_type": "code",
488
   "id": "9942e684",
489
   "metadata": {},
490
   "source": [
491
    "qa_test_subset = pd.concat([qa_event_test_subset, qa_note_test_subset]).reset_index(drop=True)\n",
492
    "qa_test_subset"
493
   ],
494
   "outputs": [],
495
   "execution_count": null
496
  },
497
  {
498
   "cell_type": "code",
499
   "id": "8366ae94",
500
   "metadata": {},
501
   "source": [
502
    "qa_test_subset.to_csv(os.path.join(output_path, \"qa_test_subset.csv\"), index=False)"
503
   ],
504
   "outputs": [],
505
   "execution_count": null
506
  },
507
  {
508
   "cell_type": "code",
509
   "id": "410a34c3",
510
   "metadata": {},
511
   "source": [],
512
   "outputs": [],
513
   "execution_count": null
514
  }
515
 ],
516
 "metadata": {
517
  "kernelspec": {
518
   "display_name": "pytorch20",
519
   "language": "python",
520
   "name": "pytorch20"
521
  },
522
  "language_info": {
523
   "codemirror_mode": {
524
    "name": "ipython",
525
    "version": 3
526
   },
527
   "file_extension": ".py",
528
   "mimetype": "text/x-python",
529
   "name": "python",
530
   "nbconvert_exporter": "python",
531
   "pygments_lexer": "ipython3",
532
   "version": "3.9.19"
533
  }
534
 },
535
 "nbformat": 4,
536
 "nbformat_minor": 5
537
}