755 lines (754 with data), 20.9 kB
{
"cells": [
{
"cell_type": "code",
"id": "bf6469fe",
"metadata": {},
"source": [
"import os\n",
"import sys\n",
"\n",
"src_path = os.path.abspath('../..')\n",
"print(src_path)\n",
"sys.path.append(src_path)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "fcd74d2c",
"metadata": {},
"source": [
"from src.utils import create_directory, raw_data_path, processed_data_path, set_seed"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "4cae42b3",
"metadata": {},
"source": [
"set_seed(seed=42)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "57e86fc8",
"metadata": {},
"source": [
"import pandas as pd"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "d5a4a2f7",
"metadata": {},
"source": [
"mimic_iv_path = os.path.join(raw_data_path, \"physionet.org/files/mimiciv/2.2\")\n",
"mimic_iv_note_path = os.path.join(raw_data_path, \"physionet.org/files/mimic-iv-note/2.2\")\n",
"output_path = os.path.join(processed_data_path, \"mimic4\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3d241540",
"metadata": {},
"source": [
"cohort = pd.read_csv(os.path.join(output_path, \"cohort.csv\"))\n",
"print(cohort.shape)\n",
"cohort.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "00486178",
"metadata": {},
"source": [
"cohort[\"hadm_intime\"] = pd.to_datetime(cohort[\"hadm_intime\"])\n",
"cohort[\"hadm_outtime\"] = pd.to_datetime(cohort[\"hadm_outtime\"])\n",
"cohort[\"stay_intime\"] = pd.to_datetime(cohort[\"stay_intime\"])\n",
"cohort[\"stay_outtime\"] = pd.to_datetime(cohort[\"stay_outtime\"])"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "64cc5546",
"metadata": {},
"source": [
"hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
"len(hadm_ids)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "93be6ec0",
"metadata": {},
"source": [
"from pandarallel import pandarallel\n",
"\n",
"pandarallel.initialize()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "3b1486b0",
"metadata": {},
"source": [
"## discharge"
]
},
{
"cell_type": "code",
"id": "f3c47a56",
"metadata": {},
"source": [
"discharge = pd.read_csv(os.path.join(mimic_iv_note_path, \"note/discharge.csv.gz\"))\n",
"print(discharge.shape)\n",
"discharge.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "88ceedbd",
"metadata": {},
"source": [
"discharge = discharge[discharge.hadm_id.isin(hadm_ids)]\n",
"print(discharge.shape)\n",
"discharge.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "78db45d9",
"metadata": {},
"source": [
"print(discharge.text.iloc[0])"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e98bc9d4",
"metadata": {},
"source": [
"# https://github.com/hanyin88/DRG-LLaMA/blob/main/data/MIMIC_Preprocessing.py\n",
"import re\n",
"\n",
"\n",
"def extract_HC(dc_summary_raw):\n",
"\n",
" # Set up the regular expression to extract hospital course from discharge summary\n",
" # Of note these patterns would not caputre all hospital courses, and is indeed a convservative approach to ensure quality of data\n",
" pattern1 = re.compile(\"Brief Hospital Course.*\\n*((?:\\n.*)+?)(Medications on Admission|___ on Admission|___ on Admission)\")\n",
" pattern2 = re.compile(\"Brief Hospital Course.*\\n*((?:\\n.*)+?)Discharge Medications\")\n",
" pattern3 = re.compile(\"(Brief Hospital Course|rief Hospital Course|HOSPITAL COURSE)\\\n",
" .*\\n*((?:\\n.*)+?)\\\n",
" (Medications on Admission|Discharge Medications|DISCHARGE MEDICATIONS|DISCHARGE DIAGNOSIS|Discharge Disposition|___ Disposition|CONDITION ON DISCHARGE|DISCHARGE INSTRUCTIONS)\")\n",
" pattern4 = re.compile(\"(Mg-[12].|LACTATE-[12].|Epi-|Gap-|COUNT-|TRF-)___(.*\\n*((?:\\n.*)+?))(Medications on Admission)\")\n",
"\n",
"\n",
" # Idea here is to try more convservaite pattern first, if not work, try less conservative pattern\n",
" def split_note(note):\n",
" if re.search(pattern1, note):\n",
" return re.search(pattern1, note).group(1)\n",
" elif re.search(pattern2, note):\n",
" return re.search(pattern2, note).group(1)\n",
" elif re.search(pattern3, note):\n",
" return re.search(pattern3, note).group(2)\n",
" elif re.search(pattern4, note):\n",
" return re.search(pattern4, note).group(2)\n",
" else:\n",
" return None\n",
"\n",
" # Apply the function to dc_summary_raw to extract hospital course\n",
" dc_summary_raw[\"hospital_course\"] = dc_summary_raw[\"text\"].apply(split_note)\n",
"\n",
" # Drop those records that do not have hospital course captured with above regular expression patterns\n",
" dc_summary = dc_summary_raw[[\"hadm_id\", \"text\", \"hospital_course\"]].dropna()\n",
"\n",
" # Get the number of words for each hospital course. Note that the current method is not accurate due to presense of special characters, but it's good enough for our purpose\n",
" dc_summary[\"num_words\"] = dc_summary[\"hospital_course\"].apply(lambda x: len(x.split()))\n",
"\n",
" # Remove the notes with less than 40 words\n",
" dc_summary = dc_summary[dc_summary[\"num_words\"] > 40]\n",
"\n",
" # Remove duplicate hospital courses (but keep the first one), as most of these notes represent low quality data\n",
" dc_summary = dc_summary.drop_duplicates(subset=[\"hospital_course\"], keep=\"first\")\n",
"\n",
" # Mean number of words in the hospital course is 378\n",
" dc_summary[\"num_words\"].mean()\n",
"\n",
" dc_summary = dc_summary[[\"hadm_id\", \"text\", \"hospital_course\"]]\n",
"\n",
" return dc_summary"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "0e542a5b",
"metadata": {},
"source": [
"discharge = extract_HC(discharge)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "1611f35d",
"metadata": {},
"source": [
"discharge"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "2ccccc52",
"metadata": {},
"source": [
"# https://github.com/ji-youn-kim/EHRNoteQA/blob/master/src/preprocessing/preprocess.py\n",
"import re\n",
"\n",
"\n",
"def transform_string(s):\n",
" s = re.sub(r'(\\n\\s*|\\s*\\n)', '\\n', s)\n",
" s = re.sub(r'\\s{2,}', ' ', s)\n",
" s = s.strip()\n",
" return s"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "dff0c4d6",
"metadata": {},
"source": [
"print(transform_string(discharge.hospital_course.iloc[1]))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "31cdb2ce",
"metadata": {},
"source": [
"discharge[\"cleaned_hospital_course\"] = discharge.hospital_course.parallel_apply(transform_string)\n",
"discharge.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "28d7f06b",
"metadata": {},
"source": [
"print(discharge.cleaned_hospital_course.iloc[4])"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "a5366fa3",
"metadata": {},
"source": [
"## GPT"
]
},
{
"cell_type": "code",
"id": "a50e9b87",
"metadata": {},
"source": [
"system_content = \"\"\"You are an AI assistant specialized in analyzing ICU patients' data.\n",
"\n",
"You are provided with a discharge summary of an ICU patient, which summarizes important clinical records and serves as an essential reference for the doctor’s clinical decision-making.\n",
"\n",
"Your task is to generate a question-answer pair inquiring about the patient. \n",
"\n",
"Objective:\n",
"1. Formulate one question that a doctor will ask based on the provided discharge summary.\n",
"2. The answer should be found within the provided discharge summary. \n",
"3. Refrain from formulating questions that can be answered without referring to the provided discharge summary.\n",
"4. Avoid questions that include sensitive personal information or \"___\".\n",
"5. Do not create questions that are too easy to answer. To answer your question, someone should have the clinical expertise equivalent to a doctor and must fully understand all provided discharge summaries.\n",
"6. Arrange your output in the following format:\n",
"- Question: [Your Question]\n",
"- Answer: [Your Answer]\n",
"7. Keep both the question and answer concise (within 256 tokens).\"\"\""
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "d6f04ea1",
"metadata": {},
"source": [
"import tiktoken\n",
"\n",
"\n",
"def num_tokens_from_message(message):\n",
" encoding = tiktoken.encoding_for_model(\"gpt-35-turbo-0125\")\n",
" return len(encoding.encode(message[0][\"content\"])) + len(encoding.encode(message[1][\"content\"])) + 11"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3be44396",
"metadata": {},
"source": [
"prompts = {}\n",
"for _, row in discharge.iterrows():\n",
" messages = [{\"role\": \"system\", \"content\": system_content},\n",
" {\"role\": \"user\", \"content\": row.cleaned_hospital_course}]\n",
" prompts[row.hadm_id] = messages\n",
"len(prompts)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "b8c9366c",
"metadata": {},
"source": [
"prompts[29079034]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "7ff9059c",
"metadata": {},
"source": [
"prompts_num_tokens = {}\n",
"for k, v in prompts.items():\n",
" prompts_num_tokens[k] = num_tokens_from_message(v)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "b11a428c",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"\n",
"print(\"mean: \", np.mean(list(prompts_num_tokens.values())))\n",
"print(\"std: \", np.std(list(prompts_num_tokens.values())))\n",
"print(\"min: \", np.min(list(prompts_num_tokens.values())))\n",
"print(\"max: \", np.max(list(prompts_num_tokens.values())))\n",
"print(\"25th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 25))\n",
"print(\"50th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 50))\n",
"print(\"75th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 75))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e623c488",
"metadata": {},
"source": [
"max_response_tokens = 256\n",
"token_limit = 16384"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "1da8b4a8",
"metadata": {},
"source": [
"prompts_filtered = {}\n",
"for k, v in prompts.items():\n",
" if prompts_num_tokens[k] + max_response_tokens < token_limit:\n",
" prompts_filtered[k] = v\n",
" else:\n",
" print(f\"hadm id {k} is filtered due to length {prompts_num_tokens[k]}\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "7b0acb0f",
"metadata": {},
"source": [
"print(len(prompts))\n",
"print(len(prompts_filtered))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3f94580b",
"metadata": {},
"source": [
"prompts_filtered[29079034]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "6a21023c",
"metadata": {},
"source": [
"import asyncio\n",
"from openai import AsyncAzureOpenAI\n",
"\n",
"\n",
"# TODO: Enter your credentials\n",
"async_client = AsyncAzureOpenAI(\n",
" azure_endpoint=\"\",\n",
" api_key=\"\",\n",
" api_version=\"\"\n",
")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "5c77f0e6",
"metadata": {},
"source": [
"async def generate_chat_response(async_client, prompt):\n",
" chat_params = {\n",
" \"model\": \"gpt-3.5-turbo\",\n",
" \"messages\": prompt,\n",
" \"max_tokens\": max_response_tokens,\n",
" \"temperature\": 0.0,\n",
" }\n",
" try:\n",
" response = await async_client.chat.completions.create(**chat_params)\n",
" except Exception as e:\n",
" print(f\"Error in call_async: {e}\")\n",
" time.sleep(10)\n",
" print(f\"Sleep for 10s...\")\n",
" return -1\n",
" return response.choices[0].message.content"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "4aee59b7",
"metadata": {},
"source": [
"import time\n",
"\n",
"\n",
"async def process_prompts(prompts):\n",
" # Gather all the futures together and wait for them to complete\n",
" responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts)) \n",
" return responses"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "14fc6ecc",
"metadata": {},
"source": [
"len(prompts_filtered)"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"prompts_filtered_subset = {k: prompts_filtered[k] for k in list(prompts_filtered.keys())[:10]}\n",
"len(prompts_filtered_subset)"
],
"id": "7ed1aa1f",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e352d5df",
"metadata": {},
"source": [
"def chunk_list(lst, chunk_size):\n",
" \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
" for i in range(0, len(lst), chunk_size):\n",
" yield lst[i:i + chunk_size]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "c034d5e7",
"metadata": {},
"source": [
"from tqdm.asyncio import tqdm\n",
"\n",
"\n",
"async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
" all_responses = {}\n",
" \n",
" for i in range(repeat):\n",
" \n",
" print(f\"round {i}\")\n",
" prev_n_responses = len(all_responses)\n",
" \n",
" prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
"\n",
" # Chunk the prompts into batches\n",
" prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
"\n",
" for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
" batch_v = [prompts[k] for k in batch_k]\n",
" responses = await process_prompts(batch_v)\n",
" all_responses |= {k: v for k, v in zip(batch_k, responses) if type(v) is str}\n",
" print(f\"get {len(all_responses) - prev_n_responses} new responses\")\n",
" \n",
" return all_responses"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "80eedf1e",
"metadata": {},
"source": [
"# Choose an appropriate batch size\n",
"batch_size = 10 # Adjust based on your system and API limits\n",
"\n",
"# Assuming we are in an async environment\n",
"responses = await process_prompts_in_batches(prompts_filtered_subset, batch_size)\n",
"print(f\"Processed {len(responses)} responses\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "ee178c00",
"metadata": {},
"source": [
"responses"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "69b5ff64",
"metadata": {},
"source": [
"len(prompts_filtered)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "9dfc47c4",
"metadata": {},
"source": [
"# Choose an appropriate batch size\n",
"batch_size = 10 # Adjust based on your system and API limits\n",
"\n",
"# Assuming we are in an async environment\n",
"responses = await process_prompts_in_batches(prompts_filtered, batch_size)\n",
"print(f\"Processed {len(responses)} responses\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "5c76d81f",
"metadata": {},
"source": [
"responses[28369884]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "68e59ca1",
"metadata": {},
"source": [
"c = 0\n",
"for qa in responses.values():\n",
" if \"___\" in qa:\n",
" c += 1\n",
"c"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "a08d474e",
"metadata": {},
"source": [
"import re\n",
"\n",
"\n",
"def split_qa(qa, verbose=False):\n",
" if verbose:\n",
" print(qa)\n",
" pattern1 = r\"-?\\s*Question: (.*)\\s*-?\\s*Answer: (.*)\"\n",
" pattern2 = r\"\\*\\*Question:\\*\\* (.*)\\s*-?\\s*\\*\\*Answer:\\*\\* (.*)\"\n",
" match = re.search(pattern1, qa)\n",
" if match is None:\n",
" match = re.search(pattern2, qa)\n",
" question = match.group(1)\n",
" answer = match.group(2)\n",
" if verbose:\n",
" print(\"Question:\", question)\n",
" print(\"Answer:\", answer)\n",
" return question, answer"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "76f5f1e2",
"metadata": {},
"source": [
"responses_split = {}\n",
"for hadm_id, qa in responses.items():\n",
" try:\n",
" responses_split[hadm_id] = split_qa(qa)\n",
" except AttributeError:\n",
" print(qa)\n",
" print(\"=====================\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "efe61c28",
"metadata": {},
"source": [
"responses_split[28369884]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "04677935",
"metadata": {},
"source": [
"len(responses_split)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "10e1ff99",
"metadata": {},
"source": [
"import json\n",
"\n",
"\n",
"with open(os.path.join(output_path, \"qa_note_orig.jsonl\"), \"w\") as file:\n",
" for hadm_id, qa in responses.items():\n",
" # Convert the dictionary to a JSON string and write it to the file\n",
" json_string = json.dumps({\"hadm_id\": hadm_id, \"qa\": qa})\n",
" file.write(json_string + '\\n')"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "8a8aa622",
"metadata": {},
"source": [
"import json\n",
"\n",
"\n",
"with open(os.path.join(output_path, \"qa_note.jsonl\"), \"w\") as file:\n",
" for hadm_id, (q, a) in responses_split.items():\n",
" # Convert the dictionary to a JSON string and write it to the file\n",
" json_string = json.dumps({\"hadm_id\": hadm_id, \"q\": q, \"a\": a})\n",
" file.write(json_string + '\\n')"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "526f7d22",
"metadata": {},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch20",
"language": "python",
"name": "pytorch20"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}