[780764]: / src / preprocess / 04_template_qa_event.ipynb

Download this file

1143 lines (1142 with data), 42.8 kB

{
 "cells": [
  {
   "cell_type": "code",
   "id": "bf6469fe",
   "metadata": {},
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "src_path = os.path.abspath('../..')\n",
    "print(src_path)\n",
    "sys.path.append(src_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fcd74d2c",
   "metadata": {},
   "source": [
    "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4cae42b3",
   "metadata": {},
   "source": [
    "set_seed(seed=42)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "57e86fc8",
   "metadata": {},
   "source": [
    "import pandas as pd"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d5a4a2f7",
   "metadata": {},
   "source": [
    "mimic_iv_path = os.path.join(raw_data_path, \"physionet.org/files/mimiciv/2.2\")\n",
    "mimic_iv_note_path = os.path.join(raw_data_path, \"physionet.org/files/mimic-iv-note/2.2\")\n",
    "output_path = os.path.join(processed_data_path, \"mimic4\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3d241540",
   "metadata": {},
   "source": [
    "cohort = pd.read_csv(os.path.join(output_path, \"cohort.csv\"))\n",
    "print(cohort.shape)\n",
    "cohort.head()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "00486178",
   "metadata": {},
   "source": [
    "# Parse the four cohort time columns into pandas datetimes.\n",
    "for time_col in (\"hadm_intime\", \"hadm_outtime\", \"stay_intime\", \"stay_outtime\"):\n",
    "    cohort[time_col] = pd.to_datetime(cohort[time_col])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "64cc5546",
   "metadata": {},
   "source": [
    "hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
    "len(hadm_ids)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "3b1486b0",
   "metadata": {},
   "source": [
    "## event"
   ]
  },
  {
   "cell_type": "code",
   "id": "78028fd9",
   "metadata": {},
   "source": [
    "# hadm_id -> hadm_los lookup; sample_time_period reads it as the stay length.\n",
    "hadm_id_to_max_hours = cohort.set_index(\"hadm_id\")[\"hadm_los\"].to_dict()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f5352959",
   "metadata": {},
   "source": [
    "import random"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "433c240c",
   "metadata": {},
   "source": [
    "def read_event_df(hadm_id, event_type):\n",
    "    \"\"\"Load the per-admission CSV of one event type.\n",
    "\n",
    "    Reads ``event_<event_type>/event_<hadm_id>.csv`` under ``output_path``\n",
    "    and returns its contents as a DataFrame.\n",
    "    \"\"\"\n",
    "    csv_path = os.path.join(output_path, f\"event_{event_type}/event_{hadm_id}.csv\")\n",
    "    return pd.read_csv(csv_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b4832885",
   "metadata": {},
   "source": [
    "hadm_id = 28141610"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a0238627",
   "metadata": {},
   "source": [
    "hadm_id_to_max_hours[hadm_id]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6974f963",
   "metadata": {},
   "source": [
    "def generate_qa_patient_demographics(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs from the single demographics row of an admission.\n",
    "\n",
    "    Gender, age (as a string), race and insurance are always asked; marital\n",
    "    status is asked only when it is not missing.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``patient_demographics`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    demographics = read_event_df(hadm_id, \"patient_demographics\")\n",
    "    assert len(demographics) == 1\n",
    "    patient = demographics.iloc[0]\n",
    "\n",
    "    qa_pairs = [\n",
    "        (\"What was the gender of the patient?\", patient.meta_gender),\n",
    "        (\"What was the age of the patient?\", str(patient.meta_age)),\n",
    "        (\"What was the race of the patient?\", patient.meta_race),\n",
    "        (\"What was the insurance of the patient?\", patient.meta_insurance),\n",
    "    ]\n",
    "    if not pd.isna(patient.meta_marital_status):\n",
    "        qa_pairs.append((\"What was the marital status of the patient?\", patient.meta_marital_status))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a79c3d0f",
   "metadata": {},
   "source": [
    "generate_qa_patient_demographics(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ee200179",
   "metadata": {},
   "source": [
    "def generate_qa_admission_info(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs from the single admission-info row of an admission.\n",
    "\n",
    "    Admission type and admission location are always asked; the chief\n",
    "    complaint is asked only when it is not missing.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``admission_info`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    df = read_event_df(hadm_id, \"admission_info\")\n",
    "    assert len(df) == 1\n",
    "    row = df.iloc[0]\n",
    "    available_qa = [\n",
    "        (\"What was the admission type of the patient?\", row.meta_admission_type),\n",
    "        (\"What was the admission location of the patient?\", row.meta_admission_location),\n",
    "    ]\n",
    "    if not pd.isna(row.meta_chief_complaint):\n",
    "        # Question text previously read \"What aws ...\" (typo).\n",
    "        available_qa.append((\"What was the chief complaint of the patient?\", row.meta_chief_complaint))\n",
    "    if return_one:\n",
    "        return random.choice(available_qa)\n",
    "    return available_qa"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c697d217",
   "metadata": {},
   "source": [
    "generate_qa_admission_info(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "08e3f593",
   "metadata": {},
   "source": [
    "def generate_qa_diagnoses_icd(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs listing the billed diagnoses of an admission.\n",
    "\n",
    "    Answers join ``meta_long_title`` values with \"; \" in file order: all of\n",
    "    them, the first one, the first three and the first five.\n",
    "    NOTE(review): \"first\"/\"top\" assume the file rows are already in billing\n",
    "    order -- confirm against the step that writes ``event_diagnoses_icd``.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``diagnoses_icd`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    df = read_event_df(hadm_id, \"diagnoses_icd\")\n",
    "    titles = df.meta_long_title.tolist()\n",
    "    # Question texts previously misspelled \"billed\" and used \"diagnose\" as a noun.\n",
    "    available_qa = [\n",
    "        (\"What were the billed diagnoses of the patient?\", \"; \".join(titles)),\n",
    "        (\"What was the first billed diagnosis of the patient?\", \"; \".join(titles[:1])),\n",
    "        (\"What were the top three billed diagnoses of the patient?\", \"; \".join(titles[:3])),\n",
    "        (\"What were the top five billed diagnoses of the patient?\", \"; \".join(titles[:5])),\n",
    "    ]\n",
    "    if return_one:\n",
    "        return random.choice(available_qa)\n",
    "    return available_qa"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d4f8e4ea",
   "metadata": {},
   "source": [
    "generate_qa_diagnoses_icd(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "166c4663",
   "metadata": {},
   "source": [
    "import warnings\n",
    "\n",
    "\n",
    "def sample_time_period(df, enforce_non_empty=True):\n",
    "    \"\"\"Randomly pick a time-period phrase and the rows of ``df`` inside it.\n",
    "\n",
    "    Candidate periods are the first/last 12, 24 and 48 hours of the admission\n",
    "    plus up to three randomly drawn 24-hour \"during day <d>\" windows.\n",
    "\n",
    "    Args:\n",
    "        df: Non-empty event frame with ``hadm_id`` and ``timestamp`` columns;\n",
    "            timestamps are compared against hour thresholds.\n",
    "        enforce_non_empty: When True, redraw until the selected rows are\n",
    "            non-empty; after more than 100 failed draws a warning is issued\n",
    "            and the whole frame is returned instead.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(phrase, rows)``, e.g. ``(\"during the first 12 hours\", ...)``,\n",
    "        or ``(\"during the entire stay\", df)`` as the fallback.\n",
    "    \"\"\"\n",
    "    # Take the stay length from the cohort lookup rather than df.timestamp.max():\n",
    "    # df may be only a sub-sequence of the admission's events.\n",
    "    max_hours = hadm_id_to_max_hours[df.hadm_id.iloc[0]]\n",
    "    max_days = int(max_hours // 24 + 1)\n",
    "\n",
    "    available = {\n",
    "        \"during the first 12 hours\": lambda x: x < 12,\n",
    "        \"during the first 24 hours\": lambda x: x < 24,\n",
    "        \"during the first 48 hours\": lambda x: x < 48,\n",
    "        \"during the last 12 hours\": lambda x: x >= max_hours - 12,\n",
    "        \"during the last 24 hours\": lambda x: x >= max_hours - 24,\n",
    "        \"during the last 48 hours\": lambda x: x >= max_hours - 48,\n",
    "    }\n",
    "    for _ in range(3):\n",
    "        d = random.choice(range(1, max_days + 1))\n",
    "        # Bind d as a default argument so each lambda keeps its own day.\n",
    "        available[f\"during day {d}\"] = lambda x, d=d: ((d - 1) * 24) <= x < (d * 24)\n",
    "\n",
    "    labels = list(available.keys())  # loop-invariant, built once\n",
    "    n_tries = 0\n",
    "    while True:\n",
    "        s = random.choice(labels)\n",
    "        tmp = df[df.timestamp.apply(available[s])]\n",
    "        if not enforce_non_empty or len(tmp) > 0:\n",
    "            return s, tmp\n",
    "        n_tries += 1\n",
    "        if n_tries > 100:\n",
    "            warnings.warn(f\"Too many tries to enforce non-empty return: len={len(df)}\")\n",
    "            return \"during the entire stay\", df"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "250a3b71",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def generate_qa_labevents(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about one randomly chosen lab event of an admission.\n",
    "\n",
    "    An event type is drawn at random, then one of its rows. Questions cover\n",
    "    that exact row, the rows sharing its timestamp/fluid/category, and the\n",
    "    same event type within time periods drawn by ``sample_time_period``.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``labevents`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    labs = read_event_df(hadm_id, \"labevents\")\n",
    "    qa_pairs = []\n",
    "\n",
    "    # Derived columns: a readable event name and a \"value [unit]\" answer string.\n",
    "    labs[\"event\"] = labs.apply(lambda r: f\"{r.meta_fluid} {r.meta_label} {r.meta_category}\", axis=1)\n",
    "    labs[\"value\"] = labs.apply(\n",
    "        lambda r: f\"{r.meta_value}\" if pd.isna(r.meta_valueuom) else f\"{r.meta_value} {r.meta_valueuom}\",\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    # Draw an event type first, then one of its rows.\n",
    "    target_event = random.choice(labs.event.unique())\n",
    "    picked = labs[labs.event == target_event].sample().iloc[0]\n",
    "    hour = f\"{picked.timestamp:.2f}\"\n",
    "\n",
    "    # Exact-time questions about the picked row; a missing flag counts as normal.\n",
    "    qa_pairs.append((f\"What was the {picked.event} measurement at the {hour} hour?\", picked.value))\n",
    "    qa_pairs.append((\n",
    "        f\"Was the {picked.event} measurement at the {hour} hour normal?\",\n",
    "        \"Yes\" if pd.isna(picked.meta_flag) else \"No\",\n",
    "    ))\n",
    "\n",
    "    # Rows sharing the picked row's timestamp, fluid and category.\n",
    "    same_panel = labs[\n",
    "        (labs.timestamp == picked.timestamp)\n",
    "        & (labs.meta_fluid == picked.meta_fluid)\n",
    "        & (labs.meta_category == picked.meta_category)\n",
    "    ]\n",
    "    qa_pairs.append((\n",
    "        f\"What {picked.meta_category} measurements were performed on the {picked.meta_fluid} specimen at the {hour} hour?\",\n",
    "        \", \".join(same_panel.meta_label.tolist()),\n",
    "    ))\n",
    "\n",
    "    # Of those, the rows carrying a flag.\n",
    "    flagged = same_panel[~pd.isna(same_panel.meta_flag)]\n",
    "    flagged_labels = \", \".join(flagged.meta_label.tolist())\n",
    "    qa_pairs.append((\n",
    "        f\"What {picked.meta_category} measurements on the {picked.meta_fluid} specimen were abnormal at the {hour} hour?\",\n",
    "        flagged_labels if len(flagged_labels) > 0 else \"All lab tests were normal\",\n",
    "    ))\n",
    "\n",
    "    # Same event type restricted to a sampled, non-empty time period.\n",
    "    period, window = sample_time_period(labs[labs.event == picked.event])\n",
    "    first_row, last_row = window.iloc[0], window.iloc[-1]\n",
    "    n_rows = len(window)\n",
    "    qa_pairs.extend([\n",
    "        (f\"What was the first {picked.event} measurement {period}?\", first_row.value),\n",
    "        (f\"What was the last {picked.event} measurement {period}?\", last_row.value),\n",
    "        (f\"When was the first {picked.event} measurement {period}?\", f\"At the {first_row.timestamp:.2f} hour\"),\n",
    "        (f\"When was the last {picked.event} measurement {period}?\", f\"At the {last_row.timestamp:.2f} hour\"),\n",
    "        (\n",
    "            f\"How many times did the patient have the {picked.event} measurement {period}?\",\n",
    "            f\"{n_rows} times\" if n_rows > 1 else \"1 time\",\n",
    "        ),\n",
    "    ])\n",
    "\n",
    "    # Aggregates only when the period has one non-missing unit and every value parses as float.\n",
    "    units = window.meta_valueuom.unique()\n",
    "    if len(units) == 1 and not pd.isna(units[0]):\n",
    "        try:\n",
    "            numbers = [float(v) for v in window.meta_value]\n",
    "            qa_pairs.append((\n",
    "                f\"What was the maximum {picked.event} measurement {period}?\",\n",
    "                f\"{np.max(numbers)} {units[0]}\",\n",
    "            ))\n",
    "            qa_pairs.append((\n",
    "                f\"What was the minimum {picked.event} measurement {period}?\",\n",
    "                f\"{np.min(numbers)} {units[0]}\",\n",
    "            ))\n",
    "            qa_pairs.append((\n",
    "                f\"What was the average {picked.event} measurement {period}?\",\n",
    "                f\"{np.mean(numbers):.2f} {units[0]}\",\n",
    "            ))\n",
    "        except ValueError:\n",
    "            # Non-numeric values: skip the aggregate questions.\n",
    "            pass\n",
    "\n",
    "    # Presence question over an independently drawn period that may be empty.\n",
    "    period, window = sample_time_period(labs, enforce_non_empty=False)\n",
    "    qa_pairs.append((\n",
    "        f\"Did the patient have any {picked.event} measurement {period}?\",\n",
    "        \"Yes\" if picked.event in window.event.unique() else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "43f51535",
   "metadata": {},
   "source": [
    "generate_qa_labevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "2037b993",
   "metadata": {},
   "source": [
    "def generate_qa_microbiologyevents(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about one randomly sampled microbiology row.\n",
    "\n",
    "    Questions cover the tests run on the sampled specimen at the sampled hour,\n",
    "    the organisms reported for it, the antibiotic results for one sampled\n",
    "    organism (when any organism is reported), and whether the specimen type\n",
    "    appears in a randomly drawn time period.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``microbiologyevents`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    micro = read_event_df(hadm_id, \"microbiologyevents\")\n",
    "    qa_pairs = []\n",
    "\n",
    "    row = micro.sample().iloc[0]\n",
    "\n",
    "    # Tests run on the sampled specimen type at the sampled hour.\n",
    "    same_specimen = micro[\n",
    "        (micro.timestamp == row.timestamp) & (micro.meta_spec_type_desc == row.meta_spec_type_desc)\n",
    "    ]\n",
    "    qa_pairs.append((\n",
    "        f\"What microbiology tests were performed on the {row.meta_spec_type_desc} specimen at the {row.timestamp:.2f} hour?\",\n",
    "        \", \".join(same_specimen.meta_test_name.unique()),\n",
    "    ))\n",
    "\n",
    "    # Organisms reported for that specimen (rows with an organism name).\n",
    "    with_organism = same_specimen[~pd.isna(same_specimen.meta_org_name)]\n",
    "    organisms = \", \".join(with_organism.meta_org_name.unique())\n",
    "    qa_pairs.append((\n",
    "        f\"What organisms were found on the {row.meta_spec_type_desc} specimen at the {row.timestamp:.2f} hour?\",\n",
    "        organisms if len(organisms) > 0 else \"No growth was found\",\n",
    "    ))\n",
    "\n",
    "    # Antibiotic results for one sampled organism, when any organism was reported.\n",
    "    if len(with_organism) > 0:\n",
    "        row = with_organism.sample().iloc[0]\n",
    "        organism_rows = with_organism[with_organism.meta_org_name == row.meta_org_name]\n",
    "        tested = organism_rows[~pd.isna(organism_rows.meta_ab_name)]\n",
    "        results = \", \".join(\n",
    "            [f\"{ab}: {res}\" for ab, res in zip(tested.meta_ab_name.tolist(), tested.meta_interpretation.tolist())]\n",
    "        )\n",
    "        qa_pairs.append((\n",
    "            f\"What were the antibiotics test results against the {row.meta_org_name} on the {row.meta_spec_type_desc} specimen at the {row.timestamp:.2f} hour?\",\n",
    "            results if len(results) > 0 else \"No antibiotics was tested\",\n",
    "        ))\n",
    "\n",
    "    # Presence question over a randomly drawn period that may be empty.\n",
    "    period, window = sample_time_period(micro, enforce_non_empty=False)\n",
    "    qa_pairs.append((\n",
    "        f\"Did the patient have any microbiology test on the {row.meta_spec_type_desc} specimen {period}?\",\n",
    "        \"Yes\" if row.meta_spec_type_desc in window.meta_spec_type_desc.unique() else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f66b166b",
   "metadata": {},
   "source": [
    "hadm_id = 24248394"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "25258690",
   "metadata": {},
   "source": [
    "generate_qa_microbiologyevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7133c917",
   "metadata": {},
   "source": [
    "def generate_qa_prescriptions(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about one randomly sampled prescription row.\n",
    "\n",
    "    Questions cover the sampled row itself (composition, dose, route,\n",
    "    duration), all drugs prescribed at the same hour, the first/last\n",
    "    prescription of the same drug within a sampled time period, how often the\n",
    "    drug was prescribed in that period, and whether it appears in another\n",
    "    randomly drawn period.\n",
    "\n",
    "    Composition and dose questions are emitted only when the row that\n",
    "    supplies the answer has a non-missing value.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``prescriptions`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    df = read_event_df(hadm_id, \"prescriptions\")\n",
    "    available_qa = []\n",
    "\n",
    "    def dose_text(row):\n",
    "        \"\"\"Format a row's dose as \"<value> <unit>\", omitting a missing unit.\"\"\"\n",
    "        text = f\"{row.meta_dose_val_rx}\"\n",
    "        if not pd.isna(row.meta_dose_unit_rx):\n",
    "            text += f\" {row.meta_dose_unit_rx}\"\n",
    "        return text\n",
    "\n",
    "    x = df.sample().iloc[0]\n",
    "    at_hour = f\"at the {x.timestamp:.2f} hour\"\n",
    "\n",
    "    # Exact-time questions about the sampled row.\n",
    "    if not pd.isna(x.meta_prod_strength):\n",
    "        q = f\"What was the composition of the prescribed {x.meta_drug} {at_hour}?\"\n",
    "        available_qa.append((q, x.meta_prod_strength))\n",
    "\n",
    "    if not pd.isna(x.meta_dose_val_rx):\n",
    "        q = f\"What was the dose of the prescribed {x.meta_drug} {at_hour}?\"\n",
    "        available_qa.append((q, dose_text(x)))\n",
    "\n",
    "    q = f\"What was the administration route of the prescribed {x.meta_drug} {at_hour}?\"\n",
    "    available_qa.append((q, x.meta_route))\n",
    "\n",
    "    q = f\"What was the administration duration of the prescribed {x.meta_drug} {at_hour}?\"\n",
    "    available_qa.append((q, f\"{x.meta_duration:.2f} hours\"))\n",
    "\n",
    "    # All drugs prescribed at the sampled hour.\n",
    "    same_hour = df[df.timestamp == x.timestamp]\n",
    "    q = f\"What drugs were prescribed {at_hour}?\"\n",
    "    available_qa.append((q, \", \".join(same_hour.meta_drug.unique())))\n",
    "\n",
    "    # First/last prescription of the same drug within a sampled, non-empty period.\n",
    "    s, df_tmp = sample_time_period(df[df.meta_drug == x.meta_drug])\n",
    "    ends = ((\"first\", df_tmp.iloc[0]), (\"last\", df_tmp.iloc[-1]))\n",
    "\n",
    "    # Guard on the row that supplies the answer. The sampled row x may be a\n",
    "    # different prescription of the same drug, so checking x could let a\n",
    "    # missing value through as the answer.\n",
    "    for which, row in ends:\n",
    "        if not pd.isna(row.meta_prod_strength):\n",
    "            q = f\"What was the composition of the {which} prescribed {x.meta_drug} {s}?\"\n",
    "            available_qa.append((q, row.meta_prod_strength))\n",
    "\n",
    "    for which, row in ends:\n",
    "        if not pd.isna(row.meta_dose_val_rx):\n",
    "            q = f\"What was the dose of the {which} prescribed {x.meta_drug} {s}?\"\n",
    "            available_qa.append((q, dose_text(row)))\n",
    "\n",
    "    for which, row in ends:\n",
    "        q = f\"What was the administration route of the {which} prescribed {x.meta_drug} {s}?\"\n",
    "        available_qa.append((q, row.meta_route))\n",
    "\n",
    "    for which, row in ends:\n",
    "        q = f\"What was the administration duration of the {which} prescribed {x.meta_drug} {s}?\"\n",
    "        available_qa.append((q, f\"{row.meta_duration:.2f} hours\"))\n",
    "\n",
    "    for which, row in ends:\n",
    "        q = f\"When was the {which} {x.meta_drug} prescription {s}?\"\n",
    "        available_qa.append((q, f\"At the {row.timestamp:.2f} hour\"))\n",
    "\n",
    "    # Number of prescriptions of the drug in the period.\n",
    "    q = f\"How many times did the patient have the {x.meta_drug} prescription {s}?\"\n",
    "    a = f\"{len(df_tmp)} times\" if len(df_tmp) > 1 else \"1 time\"\n",
    "    available_qa.append((q, a))\n",
    "\n",
    "    # Presence question over an independently drawn period that may be empty.\n",
    "    s, df_tmp = sample_time_period(df, enforce_non_empty=False)\n",
    "    q = f\"Was the patient prescribed with any {x.meta_drug} {s}?\"\n",
    "    a = \"Yes\" if x.meta_drug in df_tmp.meta_drug.unique() else \"No\"\n",
    "    available_qa.append((q, a))\n",
    "\n",
    "    if return_one:\n",
    "        return random.choice(available_qa)\n",
    "    return available_qa"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d399e7b6",
   "metadata": {},
   "source": [
    "hadm_id = 27262979"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d4606917",
   "metadata": {},
   "source": [
    "generate_qa_prescriptions(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7856ae48",
   "metadata": {},
   "source": [
    "def generate_qa_transfers(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about transfers and discharge of an admission.\n",
    "\n",
    "    Asks for the care unit of one randomly sampled transfer row (when a care\n",
    "    unit is recorded), the discharge time, and the length of stay in hours\n",
    "    and in days (same question text, two answers).\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``transfers`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    transfers = read_event_df(hadm_id, \"transfers\")\n",
    "    qa_pairs = []\n",
    "\n",
    "    picked = transfers.sample().iloc[0]\n",
    "    if not pd.isna(picked.meta_careunit):\n",
    "        qa_pairs.append((\n",
    "            f\"Which unit was the patient transferred to at the {picked.timestamp:.2f} hour?\",\n",
    "            picked.meta_careunit,\n",
    "        ))\n",
    "\n",
    "    # Exactly one discharge row is required; its timestamp is also used as the length of stay.\n",
    "    discharge_rows = transfers[transfers.event_value == \"discharge\"]\n",
    "    assert len(discharge_rows) == 1\n",
    "    discharge_hour = discharge_rows.iloc[0].timestamp\n",
    "\n",
    "    los_question = \"How long was the length of hospital stay of the patient?\"\n",
    "    qa_pairs.extend([\n",
    "        (\"When was the patient discharged from the hospital?\", f\"At the {discharge_hour:.2f} hour\"),\n",
    "        (los_question, f\"{discharge_hour:.2f} hours\"),\n",
    "        (los_question, f\"{discharge_hour / 24:.2f} days\"),\n",
    "    ])\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "95240755",
   "metadata": {},
   "source": [
    "hadm_id = 29622279"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "340851db",
   "metadata": {},
   "source": [
    "generate_qa_transfers(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7c09d3aa",
   "metadata": {},
   "source": [
    "def generate_qa_chartevents(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about one randomly chosen chart event of an admission.\n",
    "\n",
    "    An event label is drawn at random, then one of its rows. Questions cover\n",
    "    that exact row and the same label within time periods drawn by\n",
    "    ``sample_time_period`` (first/last value and time, count, max/min/average,\n",
    "    presence).\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``chartevents`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    charts = read_event_df(hadm_id, \"chartevents\")\n",
    "    qa_pairs = []\n",
    "\n",
    "    # Derived columns: the event name and a \"value [unit]\" answer string.\n",
    "    charts[\"event\"] = charts.meta_label\n",
    "    charts[\"value\"] = charts.apply(\n",
    "        lambda r: f\"{r.meta_value}\" if pd.isna(r.meta_valueuom) else f\"{r.meta_value} {r.meta_valueuom}\",\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    # Draw an event label first, then one of its rows.\n",
    "    target_event = random.choice(charts.event.unique())\n",
    "    picked = charts[charts.event == target_event].sample().iloc[0]\n",
    "\n",
    "    # Exact-time question about the picked row.\n",
    "    qa_pairs.append((f\"What was the {picked.event} value at the {picked.timestamp:.2f} hour?\", picked.value))\n",
    "\n",
    "    # Same event restricted to a sampled, non-empty time period.\n",
    "    period, window = sample_time_period(charts[charts.event == picked.event])\n",
    "    first_row, last_row = window.iloc[0], window.iloc[-1]\n",
    "    n_rows = len(window)\n",
    "    qa_pairs.extend([\n",
    "        (f\"What was the first {picked.event} value {period}?\", first_row.value),\n",
    "        (f\"What was the last {picked.event} value {period}?\", last_row.value),\n",
    "        (f\"When was the first {picked.event} value {period}?\", f\"At the {first_row.timestamp:.2f} hour\"),\n",
    "        (f\"When was the last {picked.event} value {period}?\", f\"At the {last_row.timestamp:.2f} hour\"),\n",
    "        (\n",
    "            f\"How many times did the patient have the {picked.event} value {period}?\",\n",
    "            f\"{n_rows} times\" if n_rows > 1 else \"1 time\",\n",
    "        ),\n",
    "    ])\n",
    "\n",
    "    # Aggregates only when the period has one non-missing unit and every value parses as float.\n",
    "    units = window.meta_valueuom.unique()\n",
    "    if len(units) == 1 and not pd.isna(units[0]):\n",
    "        try:\n",
    "            numbers = [float(v) for v in window.meta_value]\n",
    "            qa_pairs.append((\n",
    "                f\"What was the maximum {picked.event} value {period}?\",\n",
    "                f\"{np.max(numbers)} {units[0]}\",\n",
    "            ))\n",
    "            qa_pairs.append((\n",
    "                f\"What was the minimum {picked.event} value {period}?\",\n",
    "                f\"{np.min(numbers)} {units[0]}\",\n",
    "            ))\n",
    "            qa_pairs.append((\n",
    "                f\"What was the average {picked.event} value {period}?\",\n",
    "                f\"{np.mean(numbers):.2f} {units[0]}\",\n",
    "            ))\n",
    "        except ValueError:\n",
    "            # Non-numeric values: skip the aggregate questions.\n",
    "            pass\n",
    "\n",
    "    # Presence question over an independently drawn period that may be empty.\n",
    "    period, window = sample_time_period(charts, enforce_non_empty=False)\n",
    "    qa_pairs.append((\n",
    "        f\"Did the patient have any {picked.event} value {period}?\",\n",
    "        \"Yes\" if picked.event in window.event.unique() else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "40cfc8c7",
   "metadata": {},
   "source": [
    "generate_qa_chartevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "df99155c",
   "metadata": {},
   "source": [
    "def generate_qa_inputevents(hadm_id, return_one=True):\n",
    "    \"\"\"Build template QA pairs about one randomly sampled input-event row.\n",
    "\n",
    "    Questions cover the sampled row (amount, duration), all labels sharing\n",
    "    its timestamp, the first/last row of the same label within a sampled time\n",
    "    period, how often the label occurs in that period, and whether it appears\n",
    "    in another randomly drawn period.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: Admission whose ``inputevents`` file is read.\n",
    "        return_one: If True return one random (question, answer) tuple,\n",
    "            otherwise return the full list.\n",
    "    \"\"\"\n",
    "    inputs = read_event_df(hadm_id, \"inputevents\")\n",
    "    qa_pairs = []\n",
    "\n",
    "    picked = inputs.sample().iloc[0]\n",
    "    label = picked.meta_label\n",
    "    at_hour = f\"at the {picked.timestamp:.2f} hour\"\n",
    "\n",
    "    # Exact-time questions about the sampled row.\n",
    "    qa_pairs.append((\n",
    "        f\"What was the amount of the IV administration {label} {at_hour}?\",\n",
    "        f\"{picked.meta_amount:.2f} {picked.meta_amountuom}\",\n",
    "    ))\n",
    "    qa_pairs.append((\n",
    "        f\"What was the duration of IV administration {label} {at_hour}?\",\n",
    "        f\"{picked.meta_duration:.2f} hours\",\n",
    "    ))\n",
    "\n",
    "    # Every label recorded at the sampled hour.\n",
    "    same_hour = inputs[inputs.timestamp == picked.timestamp]\n",
    "    qa_pairs.append((\n",
    "        f\"What drugs were administered through IV {at_hour}?\",\n",
    "        \", \".join(same_hour.meta_label.unique()),\n",
    "    ))\n",
    "\n",
    "    # Same label restricted to a sampled, non-empty time period.\n",
    "    period, window = sample_time_period(inputs[inputs.meta_label == label])\n",
    "    first_row, last_row = window.iloc[0], window.iloc[-1]\n",
    "    n_rows = len(window)\n",
    "    qa_pairs.extend([\n",
    "        (\n",
    "            f\"What was the amount of the first IV administration {label} {period}?\",\n",
    "            f\"{first_row.meta_amount:.2f} {first_row.meta_amountuom}\",\n",
    "        ),\n",
    "        (\n",
    "            f\"What was the amount of the last IV administration {label} {period}?\",\n",
    "            f\"{last_row.meta_amount:.2f} {last_row.meta_amountuom}\",\n",
    "        ),\n",
    "        (\n",
    "            f\"What was the duration of the first IV administration {label} {period}?\",\n",
    "            f\"{first_row.meta_duration:.2f} hours\",\n",
    "        ),\n",
    "        (\n",
    "            f\"What was the duration of the last IV administration {label} {period}?\",\n",
    "            f\"{last_row.meta_duration:.2f} hours\",\n",
    "        ),\n",
    "        (f\"When was the first {label} IV administration {period}?\", f\"At the {first_row.timestamp:.2f} hour\"),\n",
    "        (f\"When was the last {label} IV administration {period}?\", f\"At the {last_row.timestamp:.2f} hour\"),\n",
    "        (\n",
    "            f\"How many times did the patient have the {label} IV administration {period}?\",\n",
    "            f\"{n_rows} times\" if n_rows > 1 else \"1 time\",\n",
    "        ),\n",
    "    ])\n",
    "\n",
    "    # Presence question over an independently drawn period that may be empty.\n",
    "    period, window = sample_time_period(inputs, enforce_non_empty=False)\n",
    "    qa_pairs.append((\n",
    "        f\"Was the patient administered with any {label} through IV {period}?\",\n",
    "        \"Yes\" if label in window.meta_label.unique() else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a3faa4d4",
   "metadata": {},
   "source": [
    "# Example admission id passed to the generator try-out calls in the following cells.\n",
    "hadm_id = 24248394"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "386804c8",
   "metadata": {},
   "source": [
    "# Try-out call: return_one=False returns the full list of QA pairs instead of one random pick.\n",
    "generate_qa_inputevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "cb95d057",
   "metadata": {},
   "source": [
    "def generate_qa_outputevents(hadm_id, return_one=True):\n",
    "    \"\"\"Generate templated question/answer pairs about an admission's output events.\n",
    "\n",
    "    A single row is drawn at random from the admission's output events; its\n",
    "    `meta_label` is the subject of every question produced here.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: admission identifier passed to `read_event_df`.\n",
    "        return_one: if True, return one (question, answer) tuple picked with\n",
    "            `random.choice`; otherwise return the whole list.\n",
    "\n",
    "    Returns:\n",
    "        A (question, answer) tuple of strings, or a list of such tuples.\n",
    "    \"\"\"\n",
    "    events = read_event_df(hadm_id, \"outputevents\")\n",
    "    picked = events.sample().iloc[0]\n",
    "    label = picked.meta_label\n",
    "\n",
    "    # what was the {value} of the {event} at the {time_exact} hour?\n",
    "    qa_pairs = [(\n",
    "        f\"What was the amount of the output {label} at the {picked.timestamp:.2f} hour?\",\n",
    "        f\"{picked.meta_value:.2f} {picked.meta_valueuom}\",\n",
    "    )]\n",
    "\n",
    "    # what was the {value} of the {time_select} {event} during {time_period}?\n",
    "    period, window = sample_time_period(events[events.meta_label == label])\n",
    "    # \"first\"/\"last\" are positional rows of the sampled window\n",
    "    # (assumes the frame is in chronological order -- TODO confirm).\n",
    "    endpoints = ((\"first\", window.iloc[0]), (\"last\", window.iloc[-1]))\n",
    "    for position, row in endpoints:\n",
    "        qa_pairs.append((\n",
    "            f\"What was the amount of the {position} output {label} {period}?\",\n",
    "            f\"{row.meta_value:.2f} {row.meta_valueuom}\",\n",
    "        ))\n",
    "\n",
    "    # Aggregates are only asked when every row in the window shares one unit.\n",
    "    units = window.meta_valueuom.unique()\n",
    "    if len(units) == 1:\n",
    "        amounts = window.meta_value\n",
    "        aggregates = (\n",
    "            (\"total\", amounts.sum()),\n",
    "            (\"maximum\", amounts.max()),\n",
    "            (\"minimum\", amounts.min()),\n",
    "            (\"average\", amounts.mean()),\n",
    "        )\n",
    "        for stat_name, stat_value in aggregates:\n",
    "            qa_pairs.append((\n",
    "                f\"What was the {stat_name} amount of the output {label} {period}?\",\n",
    "                f\"{stat_value:.2f} {units[0]}\",\n",
    "            ))\n",
    "\n",
    "    # when was the {time_select} {event} during {time_period}?\n",
    "    for position, row in endpoints:\n",
    "        qa_pairs.append((\n",
    "            f\"When was the {position} output {label} {period}?\",\n",
    "            f\"At the {row.timestamp:.2f} hour\",\n",
    "        ))\n",
    "\n",
    "    # count the number of times the patient had {event} during {time_period}?\n",
    "    n_events = len(window)\n",
    "    qa_pairs.append((\n",
    "        f\"How many times did the patient have the {label} output {period}?\",\n",
    "        \"1 time\" if n_events <= 1 else f\"{n_events} times\",\n",
    "    ))\n",
    "\n",
    "    # was the patient having {event} during {time_period}?\n",
    "    # The window is re-sampled over all labels and may be empty, so \"No\" is possible.\n",
    "    period, window = sample_time_period(events, enforce_non_empty=False)\n",
    "    had_event = label in window.meta_label.unique()\n",
    "    qa_pairs.append((\n",
    "        f\"Did the patient have any {label} output {period}?\",\n",
    "        \"Yes\" if had_event else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c1caa229",
   "metadata": {},
   "source": [
    "# Try-out call: return_one=False returns the full list of QA pairs instead of one random pick.\n",
    "generate_qa_outputevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "1733836b",
   "metadata": {},
   "source": [
    "def generate_qa_procedureevents(hadm_id, return_one=True):\n",
    "    \"\"\"Generate templated question/answer pairs about an admission's procedure events.\n",
    "\n",
    "    A single row is drawn at random from the admission's procedure events; its\n",
    "    timestamp and `meta_label` anchor the questions produced here.\n",
    "\n",
    "    Args:\n",
    "        hadm_id: admission identifier passed to `read_event_df`.\n",
    "        return_one: if True, return one (question, answer) tuple picked with\n",
    "            `random.choice`; otherwise return the whole list.\n",
    "\n",
    "    Returns:\n",
    "        A (question, answer) tuple of strings, or a list of such tuples.\n",
    "    \"\"\"\n",
    "    events = read_event_df(hadm_id, \"procedureevents\")\n",
    "    picked = events.sample().iloc[0]\n",
    "    qa_pairs = []\n",
    "\n",
    "    # what procedures were performed at the {time_exact} hour?\n",
    "    concurrent = events[events.timestamp == picked.timestamp]\n",
    "    qa_pairs.append((\n",
    "        f\"What procedures were performed at the {picked.timestamp:.2f} hour?\",\n",
    "        \", \".join(concurrent.meta_label.unique()),\n",
    "    ))\n",
    "\n",
    "    # what procedures were performed during {time_period}?\n",
    "    period, window = sample_time_period(events)\n",
    "    qa_pairs.append((\n",
    "        f\"What procedures were performed {period}?\",\n",
    "        \", \".join(window.meta_label.unique()),\n",
    "    ))\n",
    "\n",
    "    # what was the {value} of the {time_select} {event} during {time_period}?\n",
    "    label = picked.meta_label\n",
    "    period, window = sample_time_period(events[events.meta_label == label])\n",
    "    # \"first\"/\"last\" are positional rows of the sampled window\n",
    "    # (assumes the frame is in chronological order -- TODO confirm).\n",
    "    endpoints = ((\"first\", window.iloc[0]), (\"last\", window.iloc[-1]))\n",
    "    for position, row in endpoints:\n",
    "        qa_pairs.append((\n",
    "            f\"What was the duration of the {position} {label} procedure {period}?\",\n",
    "            f\"{row.meta_duration:.2f} hours\",\n",
    "        ))\n",
    "\n",
    "    # when was the {time_select} {event} during {time_period}?\n",
    "    for position, row in endpoints:\n",
    "        qa_pairs.append((\n",
    "            f\"When was the {position} {label} procedure {period}?\",\n",
    "            f\"At the {row.timestamp:.2f} hour\",\n",
    "        ))\n",
    "\n",
    "    # count the number of times the patient had {event} during {time_period}?\n",
    "    n_events = len(window)\n",
    "    qa_pairs.append((\n",
    "        f\"How many times did the patient undergo the {label} procedure {period}?\",\n",
    "        \"1 time\" if n_events <= 1 else f\"{n_events} times\",\n",
    "    ))\n",
    "\n",
    "    # was the patient having {event} during {time_period}?\n",
    "    # The window is re-sampled over all labels and may be empty, so \"No\" is possible.\n",
    "    period, window = sample_time_period(events, enforce_non_empty=False)\n",
    "    had_event = label in window.meta_label.unique()\n",
    "    qa_pairs.append((\n",
    "        f\"Did the patient undergo any {label} procedure {period}?\",\n",
    "        \"Yes\" if had_event else \"No\",\n",
    "    ))\n",
    "\n",
    "    return random.choice(qa_pairs) if return_one else qa_pairs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "2431bbc6",
   "metadata": {},
   "source": [
    "# Try-out call: return_one=False returns the full list of QA pairs instead of one random pick.\n",
    "generate_qa_procedureevents(hadm_id, return_one=False)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "df04ca7e",
   "metadata": {},
   "source": [
    "# Registry of QA generators keyed by event type; the generation loop picks\n",
    "# one key at random per question and calls the mapped function with a hadm_id.\n",
    "events_selected = dict(\n",
    "    patient_demographics=generate_qa_patient_demographics,\n",
    "    admission_info=generate_qa_admission_info,\n",
    "    diagnoses_icd=generate_qa_diagnoses_icd,\n",
    "    labevents=generate_qa_labevents,\n",
    "    procedureevents=generate_qa_procedureevents,\n",
    "    microbiologyevents=generate_qa_microbiologyevents,\n",
    "    prescriptions=generate_qa_prescriptions,\n",
    "    transfers=generate_qa_transfers,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f05a006b",
   "metadata": {},
   "source": [
    "# Number of admission ids the generation loop will iterate over.\n",
    "len(hadm_ids)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "2ef88427",
   "metadata": {},
   "source": [
    "# Evaluates to 5. Presumably a target of 350000 QA pairs spread over 59513\n",
    "# admissions -- TODO confirm both figures.\n",
    "# NOTE(review): the generation loop draws 6 questions per admission, not 5.\n",
    "350000 // 59513"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "613c47b6",
   "metadata": {},
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# Number of QA pairs drawn for each admission (previously a bare literal in the loop).\n",
    "N_QA_PER_ADMISSION = 6\n",
    "# Built once: the candidate list does not change between draws.\n",
    "event_types = list(events_selected.keys())\n",
    "\n",
    "qa = []\n",
    "# Warnings are silenced only inside this block. catch_warnings() restores the\n",
    "# previous filter state on exit, including when the loop raises, instead of\n",
    "# forcing the global filter to 'default' afterwards.\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter('ignore')\n",
    "    for hadm_id in tqdm(hadm_ids):\n",
    "        for _ in range(N_QA_PER_ADMISSION):\n",
    "            # Event types already found to have no file for this admission.\n",
    "            missing = set()\n",
    "            # Redraw from the full list until a generator succeeds (same random\n",
    "            # stream as a plain retry loop), but give up once every type is\n",
    "            # known to be missing so the loop cannot spin forever.\n",
    "            while len(missing) < len(event_types):\n",
    "                event_type = random.choice(event_types)\n",
    "                event_f = events_selected[event_type]\n",
    "                try:\n",
    "                    # Each record is (hadm_id, question, answer, event_type).\n",
    "                    qa.append((hadm_id, *event_f(hadm_id), event_type))\n",
    "                    break\n",
    "                except FileNotFoundError:\n",
    "                    # No data of this type for the admission; try another type.\n",
    "                    missing.add(event_type)\n",
    "\n",
    "print(f\"Processed {len(qa)} responses\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Write the generated QA tuples as JSON Lines: one object per line with the\n",
    "# keys hadm_id / q / a / event_type.\n",
    "jsonl_path = os.path.join(output_path, \"qa_event_template.jsonl\")\n",
    "with open(jsonl_path, \"w\") as file:\n",
    "    for hadm_id, q, a, e in qa:\n",
    "        record = dict(hadm_id=hadm_id, q=q, a=a, event_type=e)\n",
    "        # print() appends the newline that terminates each record.\n",
    "        print(json.dumps(record), file=file)"
   ],
   "id": "52bbb4e5",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "b2512fb3b6287bb7"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch20",
   "language": "python",
   "name": "pytorch20"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}