Switch to unified view

a b/src/eval/query_llemr.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "id": "afdeba94",
6
   "metadata": {},
7
   "source": [
8
    "import os\n",
9
    "import sys\n",
10
    "\n",
11
    "src_path = os.path.abspath(\"../..\")\n",
12
    "print(src_path)\n",
13
    "sys.path.append(src_path)"
14
   ],
15
   "outputs": [],
16
   "execution_count": null
17
  },
18
  {
19
   "cell_type": "code",
20
   "id": "0d5f2e19",
21
   "metadata": {},
22
   "source": "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path",
23
   "outputs": [],
24
   "execution_count": null
25
  },
26
  {
27
   "cell_type": "code",
28
   "id": "e00815d2",
29
   "metadata": {},
30
   "source": [
31
    "set_seed(seed=42)"
32
   ],
33
   "outputs": [],
34
   "execution_count": null
35
  },
36
  {
37
   "cell_type": "code",
38
   "id": "fd92d900",
39
   "metadata": {},
40
   "source": [
41
    "import pandas as pd"
42
   ],
43
   "outputs": [],
44
   "execution_count": null
45
  },
46
  {
47
   "metadata": {},
48
   "cell_type": "code",
49
   "source": "model_path = os.path.join(remote_project_path, \"output\")",
50
   "id": "4b426270718efb82",
51
   "outputs": [],
52
   "execution_count": null
53
  },
54
  {
55
   "cell_type": "code",
56
   "id": "ef32981d",
57
   "metadata": {},
58
   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
59
   "outputs": [],
60
   "execution_count": null
61
  },
62
  {
63
   "cell_type": "code",
64
   "id": "539a6392",
65
   "metadata": {},
66
   "source": [
67
    "cohort = pd.read_csv(os.path.join(output_path, \"cohort_test_subset.csv\"))\n",
68
    "print(cohort.shape)\n",
69
    "cohort.head()"
70
   ],
71
   "outputs": [],
72
   "execution_count": null
73
  },
74
  {
75
   "cell_type": "code",
76
   "id": "6659ecf8",
77
   "metadata": {},
78
   "source": [
79
    "hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
80
    "len(hadm_ids)"
81
   ],
82
   "outputs": [],
83
   "execution_count": null
84
  },
85
  {
86
   "cell_type": "code",
87
   "id": "3f8cb6ae",
88
   "metadata": {},
89
   "source": [
90
    "import logging\n",
91
    "import os\n",
92
    "\n",
93
    "import pandas as pd\n",
94
    "import torch\n",
95
    "from torch.utils.data import Dataset\n",
96
    "import re\n",
97
    "\n",
98
    "from src.utils import processed_data_path\n",
99
    "\n",
100
    "\n",
101
    "class EvalInstructionTuningDataset(Dataset):\n",
102
    "    def __init__(self):\n",
103
    "        self.data_path = os.path.join(processed_data_path, f\"mimic4\")\n",
104
    "        qa = pd.read_csv(os.path.join(self.data_path, \"qa_test_subset.csv\"))\n",
105
    "        qa[\"source\"] = qa.event_type.apply(lambda x: \"note\" if pd.isna(x) else \"event\")\n",
106
    "        self.qa = qa\n",
107
    "        logging.warning(f\"Loaded {len(qa)} QA samples\")\n",
108
    "    \n",
109
    "    def _get_event_list(self, hadm_id):\n",
110
    "        df = pd.read_csv(os.path.join(self.data_path, f\"event_selected/event_{hadm_id}.csv\"))\n",
111
    "        event_list = []\n",
112
    "        for i, row in df.iterrows():\n",
113
    "            event_list.append((row.timestamp, row.event_type, row.event_value))\n",
114
    "        return event_list\n",
115
    "\n",
116
    "    def _get_event_emb(self, hadm_id):\n",
117
    "        return torch.load(os.path.join(self.data_path, f\"pt_event_selected_no_time_type/event_{hadm_id}.pt\"))\n",
118
    "\n",
119
    "    def __len__(self):\n",
120
    "        return len(self.qa)\n",
121
    "\n",
122
    "    @staticmethod\n",
123
    "    def _extract_digits(event_tuple):\n",
124
    "        timestamp, event_type, event_value = event_tuple\n",
125
    "        try:\n",
126
    "            if event_type == \"patient demographics\":\n",
127
    "                value_match = re.search(r\"age:\\s*([\\d.]+)\", event_value)\n",
128
    "                if value_match:\n",
129
    "                    value = float(value_match.group(1))\n",
130
    "                else:\n",
131
    "                    value = 0\n",
132
    "                duration = 0\n",
133
    "            elif event_type == \"admission info\":\n",
134
    "                value, duration = 0, 0\n",
135
    "            elif event_type == \"diagnoses_icd\":\n",
136
    "                value, duration = 0, 0\n",
137
    "            elif event_type == \"labevents\":\n",
138
    "                value_match = re.search(r\":\\s*([\\d.]+)\", event_value)\n",
139
    "                if value_match:\n",
140
    "                    value = float(value_match.group(1))\n",
141
    "                else:\n",
142
    "                    value = 0\n",
143
    "                duration = 0\n",
144
    "            elif event_type == \"microbiologyevents\":\n",
145
    "                value, duration = 0, 0\n",
146
    "            elif event_type == \"prescriptions\":\n",
147
    "                value_match = re.search(r\"prescribed dose:\\s*([\\d.]+)\", event_value)\n",
148
    "                if value_match:\n",
149
    "                    value = float(value_match.group(1))\n",
150
    "                else:\n",
151
    "                    value = 0\n",
152
    "                duration_match = re.search(r\"duration:\\s*([\\d.]+)\", event_value)\n",
153
    "                if duration_match:\n",
154
    "                    duration = float(duration_match.group(1))\n",
155
    "                else:\n",
156
    "                    duration = 0\n",
157
    "            elif event_type == \"transfers\":\n",
158
    "                value, duration = 0, 0\n",
159
    "            elif event_type == \"procedureevents\":\n",
160
    "                value = 0\n",
161
    "                duration_match = re.search(r\"for\\s*([\\d.]+)\\s*hour\", event_value)\n",
162
    "                if duration_match:\n",
163
    "                    duration = float(duration_match.group(1))\n",
164
    "                else:\n",
165
    "                    duration = 0\n",
166
    "            else:\n",
167
    "                raise ValueError(f\"Unknown event type: {event_type}\")\n",
168
    "        except Exception as e:\n",
169
    "            value, duration = 0, 0\n",
170
    "            logging.warning(f\"Error {e} in extracting digits from event tuple: {event_tuple}\")\n",
171
    "        return value, duration\n",
172
    "\n",
173
    "    def __getitem__(self, index):\n",
174
    "        data = self.qa.iloc[index]\n",
175
    "        q = data[\"q\"]\n",
176
    "        a = data[\"a\"]\n",
177
    "        source = data[\"source\"]\n",
178
    "        hadm_id = data[\"hadm_id\"]\n",
179
    "        event_emb = self._get_event_emb(data[\"hadm_id\"])\n",
180
    "        num_events = event_emb.shape[0]\n",
181
    "        event_list = self._get_event_list(data[\"hadm_id\"])\n",
182
    "        assert len(event_list) == num_events\n",
183
    "        time_tensor = torch.tensor([[e[0]] for e in event_list], dtype=torch.float32)\n",
184
    "        value_duration_tensor = torch.tensor([self._extract_digits(e) for e in event_list], dtype=torch.float32)\n",
185
    "        event_emb = torch.cat(\n",
186
    "            [\n",
187
    "                event_emb,\n",
188
    "                time_tensor,\n",
189
    "                value_duration_tensor,\n",
190
    "            ],\n",
191
    "            dim=1\n",
192
    "        )\n",
193
    "        final_q = \"\\n\".join([\"<image>\" * num_events, q])\n",
194
    "        return final_q, a, event_emb, source, hadm_id"
195
   ],
196
   "outputs": [],
197
   "execution_count": null
198
  },
199
  {
200
   "cell_type": "code",
201
   "id": "8d5594cb",
202
   "metadata": {},
203
   "source": [
204
    "dataset = EvalInstructionTuningDataset()\n",
205
    "q, a, event_emb, source, hadm_id = dataset[0]\n",
206
    "print(q)\n",
207
    "print(a)\n",
208
    "print(source)\n",
209
    "print(hadm_id)\n",
210
    "print(event_emb.shape)"
211
   ],
212
   "outputs": [],
213
   "execution_count": null
214
  },
215
  {
216
   "cell_type": "code",
217
   "id": "241e1241",
218
   "metadata": {},
219
   "source": [
220
    "from src.model.modeling_llemr import LlemrForConditionalGeneration\n",
221
    "from src.model.init_llemr import init_llemr\n",
222
    "from transformers import AutoTokenizer\n",
223
    "from src.model.modeling_dummy import DummyModel\n",
224
    "from peft import PeftModel\n",
225
    "\n",
226
    "device = \"cuda:0\"\n",
227
    "llm_pretrained_model_name_or_path = \"lmsys/vicuna-7b-v1.5\"\n",
228
    "lora_name_or_path = \"zzachw12/llemr-v1\"\n",
229
    "model, tokenizer = init_llemr(llm_pretrained_model_name_or_path, 1027)\n",
230
    "model.to(torch.bfloat16)\n",
231
    "model = PeftModel.from_pretrained(model, lora_name_or_path)\n",
232
    "model.to(device)\n",
233
    "model.eval()\n",
234
    "sys_prompt = \"You are an AI assistant specialized in analyzing ICU patient data.\""
235
   ],
236
   "outputs": [],
237
   "execution_count": null
238
  },
239
  {
240
   "cell_type": "code",
241
   "id": "bfd7ff8a",
242
   "metadata": {},
243
   "source": [
244
    "model.dtype"
245
   ],
246
   "outputs": [],
247
   "execution_count": null
248
  },
249
  {
250
   "cell_type": "code",
251
   "id": "19a04f7d",
252
   "metadata": {},
253
   "source": [
254
    "from tqdm import tqdm\n",
255
    "\n",
256
    "\n",
257
    "all_responses = {}\n",
258
    "for q, a, event_emb, source, hadm_id in tqdm(dataset):\n",
259
    "    message = [\n",
260
    "        {\"role\": \"system\", \"content\": sys_prompt},\n",
261
    "        {\"role\": \"user\", \"content\": q},\n",
262
    "    ]\n",
263
    "    message = tokenizer.apply_chat_template(\n",
264
    "        message,\n",
265
    "        tokenize=False,\n",
266
    "        add_generation_prompt=True\n",
267
    "    )\n",
268
    "    inputs = tokenizer(\n",
269
    "        message,\n",
270
    "        return_tensors=\"pt\",\n",
271
    "        padding=True,\n",
272
    "        truncation=True,\n",
273
    "        add_special_tokens=False,\n",
274
    "    )\n",
275
    "    inputs = inputs.to(device)\n",
276
    "    event_emb = event_emb.unsqueeze(1).to(device)\n",
277
    "    outputs = model.generate(\n",
278
    "        input_ids=inputs[\"input_ids\"],\n",
279
    "        attention_mask=inputs[\"attention_mask\"],\n",
280
    "        pixel_values=event_emb,\n",
281
    "        max_new_tokens=256\n",
282
    "    )\n",
283
    "    generated_text = tokenizer.decode(outputs[0][len(inputs[\"input_ids\"][0]):], skip_special_tokens=True)\n",
284
    "    all_responses[(source, hadm_id)] = generated_text"
285
   ],
286
   "outputs": [],
287
   "execution_count": null
288
  },
289
  {
290
   "cell_type": "code",
291
   "id": "a72e85e8",
292
   "metadata": {},
293
   "source": [
294
    "print(f\"Processed {len(all_responses)} responses\")"
295
   ],
296
   "outputs": [],
297
   "execution_count": null
298
  },
299
  {
300
   "cell_type": "code",
301
   "id": "c4cfc894",
302
   "metadata": {},
303
   "source": "create_directory(os.path.join(model_path, \"llemr_vicuna/qa_output\"))",
304
   "outputs": [],
305
   "execution_count": null
306
  },
307
  {
308
   "cell_type": "code",
309
   "id": "7e65eb22",
310
   "metadata": {},
311
   "source": [
312
    "import json\n",
313
    "\n",
314
    "\n",
315
    "with open(os.path.join(model_path, \"llemr_vicuna/qa_output/answer.jsonl\"), \"w\") as file:\n",
316
    "    for _, data in dataset.qa.iterrows():\n",
317
    "        a_hat = all_responses.get((data.source, data.hadm_id), \"\")\n",
318
    "        json_string = json.dumps({\"hadm_id\": data.hadm_id, \"q\": data.q, \"a\": data.a, \"a_hat\": a_hat, \"source\": data.source})\n",
319
    "        file.write(json_string + '\\n')"
320
   ],
321
   "outputs": [],
322
   "execution_count": null
323
  },
324
  {
325
   "cell_type": "code",
326
   "id": "e4424b6a",
327
   "metadata": {},
328
   "source": [],
329
   "outputs": [],
330
   "execution_count": null
331
  }
332
 ],
333
 "metadata": {
334
  "kernelspec": {
335
   "display_name": "llm",
336
   "language": "python",
337
   "name": "llm"
338
  },
339
  "language_info": {
340
   "codemirror_mode": {
341
    "name": "ipython",
342
    "version": 3
343
   },
344
   "file_extension": ".py",
345
   "mimetype": "text/x-python",
346
   "name": "python",
347
   "nbconvert_exporter": "python",
348
   "pygments_lexer": "ipython3",
349
   "version": "3.9.19"
350
  }
351
 },
352
 "nbformat": 4,
353
 "nbformat_minor": 5
354
}