[780764]: / src / eval / query_llemr.ipynb

Download this file

355 lines (354 with data), 11.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "id": "afdeba94",
   "metadata": {},
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the repository root importable so `from src...` works from src/eval/.\n",
    "# Guarded so re-running the cell does not keep growing sys.path, and\n",
    "# prepended so this checkout's `src` package takes precedence over any\n",
    "# same-named package already on the path.\n",
    "src_path = os.path.abspath(\"../..\")\n",
    "print(src_path)\n",
    "if src_path not in sys.path:\n",
    "    sys.path.insert(0, src_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "0d5f2e19",
   "metadata": {},
   "source": [
    "# Project helpers used throughout this notebook.\n",
    "from src.utils import (\n",
    "    create_directory,\n",
    "    processed_data_path,\n",
    "    raw_data_path,\n",
    "    remote_project_path,\n",
    "    set_seed,\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e00815d2",
   "metadata": {},
   "source": [
    "# Seed via the project helper for reproducible runs.\n",
    "set_seed(seed=42)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fd92d900",
   "metadata": {},
   "source": [
    "import pandas as pd"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Root directory under which the generated answers are written.\n",
    "model_path = os.path.join(remote_project_path, \"output\")"
   ],
   "id": "4b426270718efb82",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ef32981d",
   "metadata": {},
   "source": [
    "# Processed MIMIC-IV directory holding the cohort and QA test subsets.\n",
    "output_path = os.path.join(processed_data_path, \"mimic4\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "539a6392",
   "metadata": {},
   "source": [
    "cohort = pd.read_csv(\n",
    "    os.path.join(output_path, \"cohort_test_subset.csv\")\n",
    ")\n",
    "print(cohort.shape)\n",
    "cohort.head()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6659ecf8",
   "metadata": {},
   "source": [
    "# Distinct admissions in the test cohort.\n",
    "hadm_ids = set(cohort[\"hadm_id\"].unique().tolist())\n",
    "len(hadm_ids)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3f8cb6ae",
   "metadata": {},
   "source": [
    "import logging\n",
    "import os\n",
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "from src.utils import processed_data_path\n",
    "\n",
    "\n",
    "# Unsigned decimal such as 12, 12.5 or .5. The former pattern [\\d.]+ also\n",
    "# matched a bare \".\" or \"1.2.3\"; float() then raised and the except branch in\n",
    "# _extract_digits zeroed BOTH value and duration for that event.\n",
    "_NUMBER = r\"(\\d*\\.?\\d+)\"\n",
    "\n",
    "\n",
    "def _search_number(prefix, text, suffix=\"\"):\n",
    "    \"\"\"Return the number captured by ``prefix + _NUMBER + suffix`` in ``text`` as a float, or 0 if absent.\"\"\"\n",
    "    match = re.search(prefix + _NUMBER + suffix, text)\n",
    "    return float(match.group(1)) if match else 0\n",
    "\n",
    "\n",
    "class EvalInstructionTuningDataset(Dataset):\n",
    "    \"\"\"QA evaluation dataset over the processed MIMIC-IV test subset.\n",
    "\n",
    "    Each item pairs one row of ``qa_test_subset.csv`` with that admission's\n",
    "    per-event features (loaded embedding + timestamp + value + duration).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.data_path = os.path.join(processed_data_path, \"mimic4\")\n",
    "        qa = pd.read_csv(os.path.join(self.data_path, \"qa_test_subset.csv\"))\n",
    "        # Rows with no event_type are note-based questions; the rest are event-based.\n",
    "        qa[\"source\"] = qa.event_type.apply(lambda x: \"note\" if pd.isna(x) else \"event\")\n",
    "        self.qa = qa\n",
    "        logging.warning(f\"Loaded {len(qa)} QA samples\")\n",
    "\n",
    "    def _get_event_list(self, hadm_id):\n",
    "        \"\"\"Return ``[(timestamp, event_type, event_value), ...]`` for one admission, in CSV row order.\"\"\"\n",
    "        df = pd.read_csv(os.path.join(self.data_path, f\"event_selected/event_{hadm_id}.csv\"))\n",
    "        return list(zip(df[\"timestamp\"], df[\"event_type\"], df[\"event_value\"]))\n",
    "\n",
    "    def _get_event_emb(self, hadm_id):\n",
    "        \"\"\"Load the saved per-event embedding tensor for one admission.\"\"\"\n",
    "        # map_location=\"cpu\" keeps the torch.cat in __getitem__ working even if the\n",
    "        # tensor was saved from a GPU; the time/value tensors are built on the CPU.\n",
    "        return torch.load(\n",
    "            os.path.join(self.data_path, f\"pt_event_selected_no_time_type/event_{hadm_id}.pt\"),\n",
    "            map_location=\"cpu\",\n",
    "        )\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.qa)\n",
    "\n",
    "    @staticmethod\n",
    "    def _extract_digits(event_tuple):\n",
    "        \"\"\"Extract numeric ``(value, duration)`` features from one event tuple.\n",
    "\n",
    "        Args:\n",
    "            event_tuple: ``(timestamp, event_type, event_value)`` as produced by\n",
    "                ``_get_event_list``.\n",
    "\n",
    "        Returns:\n",
    "            ``(value, duration)``: age for demographics, the number after a colon\n",
    "            for labevents, dose and duration for prescriptions, hours for\n",
    "            procedureevents; 0 where absent or not applicable. Unknown event types\n",
    "            and parse errors (e.g. a non-string event_value) are logged as a best\n",
    "            effort and yield ``(0, 0)``.\n",
    "        \"\"\"\n",
    "        _, event_type, event_value = event_tuple\n",
    "        try:\n",
    "            if event_type == \"patient demographics\":\n",
    "                value, duration = _search_number(r\"age:\\s*\", event_value), 0\n",
    "            elif event_type == \"labevents\":\n",
    "                value, duration = _search_number(r\":\\s*\", event_value), 0\n",
    "            elif event_type == \"prescriptions\":\n",
    "                value = _search_number(r\"prescribed dose:\\s*\", event_value)\n",
    "                duration = _search_number(r\"duration:\\s*\", event_value)\n",
    "            elif event_type == \"procedureevents\":\n",
    "                value = 0\n",
    "                duration = _search_number(r\"for\\s*\", event_value, suffix=r\"\\s*hour\")\n",
    "            elif event_type in (\"admission info\", \"diagnoses_icd\", \"microbiologyevents\", \"transfers\"):\n",
    "                # No numeric payload is extracted for these event types.\n",
    "                value, duration = 0, 0\n",
    "            else:\n",
    "                raise ValueError(f\"Unknown event type: {event_type}\")\n",
    "        except Exception as e:\n",
    "            value, duration = 0, 0\n",
    "            logging.warning(f\"Error {e} in extracting digits from event tuple: {event_tuple}\")\n",
    "        return value, duration\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(prompt, answer, event_features, source, hadm_id)`` for QA row ``index``.\n",
    "\n",
    "        ``prompt`` is one ``<image>`` placeholder per event, a newline, then the question;\n",
    "        ``event_features`` concatenates the loaded embedding with timestamp, value and duration.\n",
    "        \"\"\"\n",
    "        data = self.qa.iloc[index]\n",
    "        q = data[\"q\"]\n",
    "        a = data[\"a\"]\n",
    "        source = data[\"source\"]\n",
    "        hadm_id = data[\"hadm_id\"]\n",
    "        event_emb = self._get_event_emb(hadm_id)\n",
    "        num_events = event_emb.shape[0]\n",
    "        event_list = self._get_event_list(hadm_id)\n",
    "        assert len(event_list) == num_events, (\n",
    "            f\"hadm_id {hadm_id}: {len(event_list)} CSV events vs {num_events} embeddings\"\n",
    "        )\n",
    "        # Explicit reshapes keep (num_events, 1) and (num_events, 2) even when\n",
    "        # num_events == 0, where torch.tensor([]) would be 1-D and break torch.cat.\n",
    "        time_tensor = torch.tensor([e[0] for e in event_list], dtype=torch.float32).reshape(num_events, 1)\n",
    "        value_duration_tensor = torch.tensor(\n",
    "            [self._extract_digits(e) for e in event_list], dtype=torch.float32\n",
    "        ).reshape(num_events, 2)\n",
    "        # Per-event features: [embedding | timestamp | value | duration].\n",
    "        event_emb = torch.cat([event_emb, time_tensor, value_duration_tensor], dim=1)\n",
    "        final_q = \"\\n\".join([\"<image>\" * num_events, q])\n",
    "        return final_q, a, event_emb, source, hadm_id"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "8d5594cb",
   "metadata": {},
   "source": [
    "# Build the evaluation set and eyeball the first sample:\n",
    "# prompt, reference answer, source, admission id, feature shape.\n",
    "dataset = EvalInstructionTuningDataset()\n",
    "q, a, event_emb, source, hadm_id = dataset[0]\n",
    "for shown in (q, a, source, hadm_id, event_emb.shape):\n",
    "    print(shown)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "241e1241",
   "metadata": {},
   "source": [
    "from src.model.modeling_llemr import LlemrForConditionalGeneration\n",
    "from src.model.init_llemr import init_llemr\n",
    "from transformers import AutoTokenizer\n",
    "from src.model.modeling_dummy import DummyModel\n",
    "from peft import PeftModel\n",
    "\n",
    "# NOTE(review): LlemrForConditionalGeneration, AutoTokenizer and DummyModel are not used in this notebook.\n",
    "\n",
    "# Inference setup: init_llemr builds the model/tokenizer from the Vicuna checkpoint,\n",
    "# the model is cast to bfloat16, the LoRA adapter from lora_name_or_path is attached,\n",
    "# and everything is moved to the GPU in eval mode.\n",
    "device = \"cuda:0\"\n",
    "llm_pretrained_model_name_or_path = \"lmsys/vicuna-7b-v1.5\"\n",
    "lora_name_or_path = \"zzachw12/llemr-v1\"\n",
    "# NOTE(review): 1027 looks like the per-event feature width produced by\n",
    "# EvalInstructionTuningDataset (embedding width + timestamp + value + duration);\n",
    "# confirm against init_llemr before changing the embedding files.\n",
    "model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, 1027)\n",
    "model.to(torch.bfloat16)\n",
    "model = PeftModel.from_pretrained(model, lora_name_or_path)\n",
    "model.to(device)\n",
    "model.eval()\n",
    "# System prompt prepended to every chat in the generation loop.\n",
    "sys_prompt = \"You are an AI assistant specialized in analyzing ICU patient data.\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "bfd7ff8a",
   "metadata": {},
   "source": [
    "# Check the parameter dtype after the bfloat16 cast and LoRA attachment.\n",
    "model.dtype"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "19a04f7d",
   "metadata": {},
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# Responses are keyed by the QA row's position in dataset.qa (== dataset index,\n",
    "# since __getitem__ uses qa.iloc). A (source, hadm_id) key would let a later\n",
    "# question for the same admission and source overwrite an earlier answer, and\n",
    "# the writer below would then attach that one answer to every such question.\n",
    "all_responses = {}\n",
    "for idx in tqdm(range(len(dataset))):\n",
    "    q, a, event_emb, source, hadm_id = dataset[idx]\n",
    "    message = [\n",
    "        {\"role\": \"system\", \"content\": sys_prompt},\n",
    "        {\"role\": \"user\", \"content\": q},\n",
    "    ]\n",
    "    prompt = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)\n",
    "    inputs = tokenizer(\n",
    "        prompt,\n",
    "        return_tensors=\"pt\",\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        add_special_tokens=False,\n",
    "    ).to(device)\n",
    "    # NOTE(review): passed as pixel_values with an extra axis inserted at dim 1;\n",
    "    # confirm the expected layout against LlemrForConditionalGeneration.\n",
    "    event_emb = event_emb.unsqueeze(1).to(device)\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"],\n",
    "        attention_mask=inputs[\"attention_mask\"],\n",
    "        pixel_values=event_emb,\n",
    "        max_new_tokens=256,\n",
    "    )\n",
    "    # Decode only the newly generated tokens, skipping the prompt prefix.\n",
    "    prompt_len = inputs[\"input_ids\"].shape[1]\n",
    "    generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)\n",
    "    all_responses[idx] = generated_text"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a72e85e8",
   "metadata": {},
   "source": [
    "print(f\"Processed {len(all_responses)} responses\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "c4cfc894",
   "metadata": {},
   "source": [
    "qa_output_dir = os.path.join(model_path, \"llemr_vicuna/qa_output\")\n",
    "create_directory(qa_output_dir)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7e65eb22",
   "metadata": {},
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# One JSON line per QA row, in dataset order; a_hat is \"\" for rows with no response.\n",
    "with open(os.path.join(qa_output_dir, \"answer.jsonl\"), \"w\", encoding=\"utf-8\") as file:\n",
    "    for idx, (_, data) in enumerate(dataset.qa.iterrows()):\n",
    "        record = {\n",
    "            \"hadm_id\": int(data.hadm_id),  # cast: json cannot serialize numpy integers\n",
    "            \"q\": data.q,\n",
    "            \"a\": data.a,\n",
    "            \"a_hat\": all_responses.get(idx, \"\"),\n",
    "            \"source\": data.source,\n",
    "        }\n",
    "        file.write(json.dumps(record) + \"\\n\")"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}