--- a
+++ b/src/preprocess/05_data_split.ipynb
@@ -0,0 +1,537 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "id": "debdace9",
+   "metadata": {},
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "src_path = os.path.abspath(\"../..\")\n",
+    "print(src_path)\n",
+    "sys.path.append(src_path)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "6bad1e09",
+   "metadata": {},
+   "source": [
+    "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "5d9bc78c",
+   "metadata": {},
+   "source": [
+    "set_seed(seed=42)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "13d22a57",
+   "metadata": {},
+   "source": [
+    "import pandas as pd"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "dd9852d5",
+   "metadata": {},
+   "source": [
+    "mimic_iv_path = os.path.join(raw_data_path, \"physionet.org/files/mimiciv/2.2\")\n",
+    "output_path = os.path.join(processed_data_path, \"mimic4\")"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "b6a27998",
+   "metadata": {},
+   "source": [
+    "cohort = pd.read_csv(os.path.join(output_path, \"cohort+len.csv\"))\n",
+    "print(cohort.shape)\n",
+    "cohort.head()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "9dd92e23",
+   "metadata": {},
+   "source": [
+    "cohort[\"hadm_intime\"] = pd.to_datetime(cohort[\"hadm_intime\"])\n",
+    "cohort[\"hadm_outtime\"] = pd.to_datetime(cohort[\"hadm_outtime\"])\n",
+    "cohort[\"stay_intime\"] = pd.to_datetime(cohort[\"stay_intime\"])\n",
+    "cohort[\"stay_outtime\"] = pd.to_datetime(cohort[\"stay_outtime\"])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "8f55c793",
+   "metadata": {},
+   "source": [
+    "hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
+    "len(hadm_ids)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "26459eda",
+   "metadata": {},
+   "source": [
+    "import ast\n",
+    "import numpy as np\n",
+    "\n",
+    "\n",
+    "def safe_literal_eval(s):\n",
+    "    if pd.isna(s):\n",
+    "        return np.nan\n",
+    "    return ast.literal_eval(s)\n",
+    "\n",
+    "\n",
+    "cohort.label_diagnosis = cohort.label_diagnosis.apply(safe_literal_eval)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "99d2b8c8",
+   "metadata": {},
+   "source": [
+    "qa_note = pd.read_json(os.path.join(output_path, \"qa_note.jsonl\"), lines = True)\n",
+    "qa_note"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "741b6ed7",
+   "metadata": {},
+   "source": [
+    "qa_note.hadm_id.nunique()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "7d287db0",
+   "metadata": {},
+   "source": [
+    "qa_event = pd.read_json(os.path.join(output_path, \"qa_event.jsonl\"), lines = True)\n",
+    "qa_event"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "dbc2af95",
+   "metadata": {},
+   "source": [
+    "qa_event.hadm_id.nunique()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "b4bd715e",
+   "metadata": {},
+   "source": [
+    "qa_event.event_type.value_counts()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5a9d60c6",
+   "metadata": {},
+   "source": [
+    "## Helpers: parallel-processing imports and setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "5171bbae",
+   "metadata": {},
+   "source": [
+    "from concurrent.futures import ThreadPoolExecutor\n",
+    "from tqdm import tqdm\n",
+    "from pandarallel import pandarallel"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "5d6b9ce2",
+   "metadata": {},
+   "source": [
+    "pandarallel.initialize(progress_bar=True)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "markdown",
+   "id": "941b116b",
+   "metadata": {},
+   "source": [
+    "## Cohort statistics, length filtering, and train/val/test split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "id": "503a9cbd",
+   "metadata": {},
+   "source": [
+    "cohort = cohort[cohort.hadm_id.isin(qa_note.hadm_id.unique())]\n",
+    "len(cohort)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "b4e602ba",
+   "metadata": {},
+   "source": [
+    "cohort = cohort[cohort.hadm_id.isin(qa_event.hadm_id.unique())]\n",
+    "len(cohort)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "5014c99f",
+   "metadata": {
+    "scrolled": false
+   },
+   "source": [
+    "cohort.hadm_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "11f5935f",
+   "metadata": {},
+   "source": [
+    "548.490833 / 24"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "87a1b947",
+   "metadata": {},
+   "source": [
+    "cohort.stay_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "f774a307",
+   "metadata": {},
+   "source": [
+    "265.649889 / 24"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "99d6a2f0",
+   "metadata": {},
+   "source": [
+    "cohort.len_selected.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "622669a8",
+   "metadata": {},
+   "source": [
+    "cohort_filtered = cohort[cohort.len_selected <= 1256.650000]\n",
+    "cohort_filtered"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "75ddbce8",
+   "metadata": {},
+   "source": [
+    "cohort_filtered.hadm_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "aaed6fb7",
+   "metadata": {},
+   "source": [
+    "cohort_filtered.stay_los.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "e2cca302",
+   "metadata": {},
+   "source": [
+    "cohort_filtered.len_selected.describe(percentiles=[.1, .25, .5, .75, .9, .95, .99])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "bc283fd9",
+   "metadata": {},
+   "source": [
+    "all_patients = cohort_filtered.subject_id.unique()\n",
+    "len(all_patients)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "18aa0954",
+   "metadata": {},
+   "source": [
+    "from sklearn.model_selection import train_test_split\n",
+    "\n",
+    "\n",
+    "train_val_patients, test_patients = train_test_split(all_patients, test_size=0.1, random_state=42)\n",
+    "train_patients, val_patients = train_test_split(all_patients, test_size=0.111, random_state=42)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "d5d0a3ac",
+   "metadata": {},
+   "source": [
+    "print(train_patients.shape)\n",
+    "print(val_patients.shape)\n",
+    "print(test_patients.shape)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "57dc6d8f",
+   "metadata": {},
+   "source": [
+    "train = cohort_filtered[cohort_filtered.subject_id.isin(train_patients)].reset_index(drop=True)\n",
+    "val = cohort_filtered[cohort_filtered.subject_id.isin(val_patients)].reset_index(drop=True)\n",
+    "test = cohort_filtered[cohort_filtered.subject_id.isin(test_patients)].reset_index(drop=True)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "860d6a39",
+   "metadata": {},
+   "source": [
+    "print(train.shape)\n",
+    "print(val.shape)\n",
+    "print(test.shape)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "ce3c7f18",
+   "metadata": {},
+   "source": [
+    "train.to_csv(os.path.join(output_path, \"cohort_train.csv\"), index=False)\n",
+    "val.to_csv(os.path.join(output_path, \"cohort_val.csv\"), index=False)\n",
+    "test.to_csv(os.path.join(output_path, \"cohort_test.csv\"), index=False)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "7b0ea788",
+   "metadata": {},
+   "source": [
+    "_, test_subset_patients = train_test_split(test_patients, test_size=100, random_state=42)\n",
+    "len(test_subset_patients)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "5ba82c66",
+   "metadata": {},
+   "source": [
+    "test_subset = test[test.subject_id.isin(test_subset_patients)].groupby(\"subject_id\").apply(lambda x: x.sample(1)).reset_index(drop=True)\n",
+    "test_subset"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "test_subset.len_selected.describe()",
+   "id": "0bd73438",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "a58143a2",
+   "metadata": {},
+   "source": [
+    "print(test_subset.subject_id.nunique())\n",
+    "print(test_subset.hadm_id.nunique())"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "745f82b5",
+   "metadata": {},
+   "source": [
+    "test_subset.to_csv(os.path.join(output_path, \"cohort_test_subset.csv\"), index=False)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "7ba15669",
+   "metadata": {},
+   "source": [
+    "qa_note_test_subset = qa_note[qa_note.hadm_id.isin(test_subset.hadm_id.unique())]\n",
+    "qa_note_test_subset"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "040ce4fb",
+   "metadata": {},
+   "source": [
+    "qa_event_test_subset = qa_event[qa_event.hadm_id.isin(test_subset.hadm_id.unique())].groupby(\"hadm_id\").apply(lambda x: x.sample(1)).reset_index(drop=True)\n",
+    "qa_event_test_subset"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "5455dbca",
+   "metadata": {},
+   "source": [
+    "qa_event_test_subset.hadm_id.nunique()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "4c78aa45",
+   "metadata": {},
+   "source": [
+    "qa_event_test_subset.event_type.value_counts()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "9942e684",
+   "metadata": {},
+   "source": [
+    "qa_test_subset = pd.concat([qa_event_test_subset, qa_note_test_subset]).reset_index(drop=True)\n",
+    "qa_test_subset"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "8366ae94",
+   "metadata": {},
+   "source": [
+    "qa_test_subset.to_csv(os.path.join(output_path, \"qa_test_subset.csv\"), index=False)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "410a34c3",
+   "metadata": {},
+   "source": [],
+   "outputs": [],
+   "execution_count": null
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "pytorch20",
+   "language": "python",
+   "name": "pytorch20"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}