355 lines (354 with data), 11.3 kB
{
"cells": [
{
"cell_type": "code",
"id": "afdeba94",
"metadata": {},
"source": [
"import os\n",
"import sys\n",
"\n",
"src_path = os.path.abspath(\"../..\")\n",
"print(src_path)\n",
"sys.path.append(src_path)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "0d5f2e19",
"metadata": {},
"source": "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e00815d2",
"metadata": {},
"source": [
"set_seed(seed=42)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "fd92d900",
"metadata": {},
"source": [
"import pandas as pd"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "model_path = os.path.join(remote_project_path, \"output\")",
"id": "4b426270718efb82",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "ef32981d",
"metadata": {},
"source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "539a6392",
"metadata": {},
"source": [
"cohort = pd.read_csv(os.path.join(output_path, \"cohort_test_subset.csv\"))\n",
"print(cohort.shape)\n",
"cohort.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "6659ecf8",
"metadata": {},
"source": [
"hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
"len(hadm_ids)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3f8cb6ae",
"metadata": {},
"source": [
"import logging\n",
"import os\n",
"\n",
"import pandas as pd\n",
"import torch\n",
"from torch.utils.data import Dataset\n",
"import re\n",
"\n",
"from src.utils import processed_data_path\n",
"\n",
"\n",
"class EvalInstructionTuningDataset(Dataset):\n",
" def __init__(self):\n",
" self.data_path = os.path.join(processed_data_path, f\"mimic4\")\n",
" qa = pd.read_csv(os.path.join(self.data_path, \"qa_test_subset.csv\"))\n",
" qa[\"source\"] = qa.event_type.apply(lambda x: \"note\" if pd.isna(x) else \"event\")\n",
" self.qa = qa\n",
" logging.warning(f\"Loaded {len(qa)} QA samples\")\n",
" \n",
" def _get_event_list(self, hadm_id):\n",
" df = pd.read_csv(os.path.join(self.data_path, f\"event_selected/event_{hadm_id}.csv\"))\n",
" event_list = []\n",
" for i, row in df.iterrows():\n",
" event_list.append((row.timestamp, row.event_type, row.event_value))\n",
" return event_list\n",
"\n",
" def _get_event_emb(self, hadm_id):\n",
" return torch.load(os.path.join(self.data_path, f\"pt_event_selected_no_time_type/event_{hadm_id}.pt\"))\n",
"\n",
" def __len__(self):\n",
" return len(self.qa)\n",
"\n",
" @staticmethod\n",
" def _extract_digits(event_tuple):\n",
" timestamp, event_type, event_value = event_tuple\n",
" try:\n",
" if event_type == \"patient demographics\":\n",
" value_match = re.search(r\"age:\\s*([\\d.]+)\", event_value)\n",
" if value_match:\n",
" value = float(value_match.group(1))\n",
" else:\n",
" value = 0\n",
" duration = 0\n",
" elif event_type == \"admission info\":\n",
" value, duration = 0, 0\n",
" elif event_type == \"diagnoses_icd\":\n",
" value, duration = 0, 0\n",
" elif event_type == \"labevents\":\n",
" value_match = re.search(r\":\\s*([\\d.]+)\", event_value)\n",
" if value_match:\n",
" value = float(value_match.group(1))\n",
" else:\n",
" value = 0\n",
" duration = 0\n",
" elif event_type == \"microbiologyevents\":\n",
" value, duration = 0, 0\n",
" elif event_type == \"prescriptions\":\n",
" value_match = re.search(r\"prescribed dose:\\s*([\\d.]+)\", event_value)\n",
" if value_match:\n",
" value = float(value_match.group(1))\n",
" else:\n",
" value = 0\n",
" duration_match = re.search(r\"duration:\\s*([\\d.]+)\", event_value)\n",
" if duration_match:\n",
" duration = float(duration_match.group(1))\n",
" else:\n",
" duration = 0\n",
" elif event_type == \"transfers\":\n",
" value, duration = 0, 0\n",
" elif event_type == \"procedureevents\":\n",
" value = 0\n",
" duration_match = re.search(r\"for\\s*([\\d.]+)\\s*hour\", event_value)\n",
" if duration_match:\n",
" duration = float(duration_match.group(1))\n",
" else:\n",
" duration = 0\n",
" else:\n",
" raise ValueError(f\"Unknown event type: {event_type}\")\n",
" except Exception as e:\n",
" value, duration = 0, 0\n",
" logging.warning(f\"Error {e} in extracting digits from event tuple: {event_tuple}\")\n",
" return value, duration\n",
"\n",
" def __getitem__(self, index):\n",
" data = self.qa.iloc[index]\n",
" q = data[\"q\"]\n",
" a = data[\"a\"]\n",
" source = data[\"source\"]\n",
" hadm_id = data[\"hadm_id\"]\n",
" event_emb = self._get_event_emb(data[\"hadm_id\"])\n",
" num_events = event_emb.shape[0]\n",
" event_list = self._get_event_list(data[\"hadm_id\"])\n",
" assert len(event_list) == num_events\n",
" time_tensor = torch.tensor([[e[0]] for e in event_list], dtype=torch.float32)\n",
" value_duration_tensor = torch.tensor([self._extract_digits(e) for e in event_list], dtype=torch.float32)\n",
" event_emb = torch.cat(\n",
" [\n",
" event_emb,\n",
" time_tensor,\n",
" value_duration_tensor,\n",
" ],\n",
" dim=1\n",
" )\n",
" final_q = \"\\n\".join([\"<image>\" * num_events, q])\n",
" return final_q, a, event_emb, source, hadm_id"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "8d5594cb",
"metadata": {},
"source": [
"dataset = EvalInstructionTuningDataset()\n",
"q, a, event_emb, source, hadm_id = dataset[0]\n",
"print(q)\n",
"print(a)\n",
"print(source)\n",
"print(hadm_id)\n",
"print(event_emb.shape)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "241e1241",
"metadata": {},
"source": [
"from src.model.modeling_llemr import LlemrForConditionalGeneration\n",
"from src.model.init_llemr import init_llemr\n",
"from transformers import AutoTokenizer\n",
"from src.model.modeling_dummy import DummyModel\n",
"from peft import PeftModel\n",
"\n",
"device = \"cuda:0\"\n",
"llm_pretrained_model_name_or_path = \"lmsys/vicuna-7b-v1.5\"\n",
"lora_name_or_path = \"zzachw12/llemr-v1\"\n",
"model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, 1027)\n",
"model.to(torch.bfloat16)\n",
"model = PeftModel.from_pretrained(model, lora_name_or_path)\n",
"model.to(device)\n",
"model.eval()\n",
"sys_prompt = \"You are an AI assistant specialized in analyzing ICU patient data.\""
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "bfd7ff8a",
"metadata": {},
"source": [
"model.dtype"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "19a04f7d",
"metadata": {},
"source": [
"from tqdm import tqdm\n",
"\n",
"\n",
"all_responses = {}\n",
"for q, a, event_emb, source, hadm_id in tqdm(dataset):\n",
" message = [\n",
" {\"role\": \"system\", \"content\": sys_prompt},\n",
" {\"role\": \"user\", \"content\": q},\n",
" ]\n",
" message = tokenizer.apply_chat_template(\n",
" message,\n",
" tokenize=False,\n",
" add_generation_prompt=True\n",
" )\n",
" inputs = tokenizer(\n",
" message,\n",
" return_tensors=\"pt\",\n",
" padding=True,\n",
" truncation=True,\n",
" add_special_tokens=False,\n",
" )\n",
" inputs = inputs.to(device)\n",
" event_emb = event_emb.unsqueeze(1).to(device)\n",
" outputs = model.generate(\n",
" input_ids=inputs[\"input_ids\"],\n",
" attention_mask=inputs[\"attention_mask\"],\n",
" pixel_values=event_emb,\n",
" max_new_tokens=256\n",
" )\n",
" generated_text = tokenizer.decode(outputs[0][len(inputs[\"input_ids\"][0]):], skip_special_tokens=True)\n",
" all_responses[(source, hadm_id)] = generated_text"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "a72e85e8",
"metadata": {},
"source": [
"print(f\"Processed {len(all_responses)} responses\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "c4cfc894",
"metadata": {},
"source": "create_directory(os.path.join(model_path, \"llemr_vicuna/qa_output\"))",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "7e65eb22",
"metadata": {},
"source": [
"import json\n",
"\n",
"\n",
"with open(os.path.join(model_path, \"llemr_vicuna/qa_output/answer.jsonl\"), \"w\") as file:\n",
" for _, data in dataset.qa.iterrows():\n",
" a_hat = all_responses.get((data.source, data.hadm_id), \"\")\n",
" json_string = json.dumps({\"hadm_id\": data.hadm_id, \"q\": data.q, \"a\": data.a, \"a_hat\": a_hat, \"source\": data.source})\n",
" file.write(json_string + '\\n')"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e4424b6a",
"metadata": {},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "llm",
"language": "python",
"name": "llm"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}