Diff of /src/eval/query_gpt4.ipynb [000000] .. [780764]

{
 "cells": [
  {
   "cell_type": "code",
   "id": "afdeba94",
   "metadata": {},
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "src_path = os.path.abspath(\"../..\")\n",
    "print(src_path)\n",
    "sys.path.append(src_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "0d5f2e19",
   "metadata": {},
   "source": "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e00815d2",
   "metadata": {},
   "source": [
    "set_seed(seed=42)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fd92d900",
   "metadata": {},
   "source": [
    "import pandas as pd"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4e04fa3e4c08145",
   "metadata": {},
   "source": "model_path = os.path.join(remote_project_path, \"output\")",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ef32981d",
   "metadata": {},
   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "539a6392",
   "metadata": {},
   "source": [
    "cohort = pd.read_csv(os.path.join(output_path, \"cohort_test_subset.csv\"))\n",
    "print(cohort.shape)\n",
    "cohort.head()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6659ecf8",
   "metadata": {},
   "source": [
    "hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
    "len(hadm_ids)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4090a70d",
   "metadata": {},
   "source": [
    "qa = pd.read_csv(os.path.join(output_path, \"qa_test_subset.csv\"))\n",
    "qa[\"source\"] = qa.event_type.apply(lambda x: \"note\" if pd.isna(x) else \"event\")\n",
    "qa"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "dc67f44e",
   "metadata": {},
   "source": [
    "def get_events(hadm_id):\n",
    "    \"\"\"Render one admission's selected events, one event per line.\"\"\"\n",
    "    df = pd.read_csv(os.path.join(output_path, f\"event_selected/event_{hadm_id}.csv\"))\n",
    "    text = []\n",
    "    for _, row in df.iterrows():\n",
    "        text.append(f\"{row.timestamp:.2f} hour, {row.event_type}, {row.event_value}\")\n",
    "    return \"\\n\".join(text)"
   ],
   "outputs": [],
   "execution_count": null
  },
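  {
   "cell_type": "code",
   "id": "ed-demo-event-format",
   "metadata": {},
   "source": [
    "# Editor's illustration (hypothetical data, not part of the original pipeline):\n",
    "# the lines get_events feeds to the model look like\n",
    "# \"<hours since admission> hour, <event type>, <event value>\".\n",
    "demo = pd.DataFrame({\"timestamp\": [0.0, 1.5],\n",
    "                     \"event_type\": [\"labevent\", \"labevent\"],\n",
    "                     \"event_value\": [\"WBC 9.8\", \"Hgb 11.2\"]})\n",
    "print(\"\\n\".join(f\"{row.timestamp:.2f} hour, {row.event_type}, {row.event_value}\" for row in demo.itertuples()))"
   ],
   "outputs": [],
   "execution_count": null
  },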
  {
   "cell_type": "code",
   "id": "1f080932",
   "metadata": {},
   "source": [
    "print(get_events(qa.iloc[2].hadm_id))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "53830063",
   "metadata": {},
   "source": [
    "system_content = \"\"\"You are an AI assistant specialized in analyzing ICU patient data.\n",
    "You are given a sequence of clinical events from an ICU patient's hospital admission.\n",
    "Each event is formatted as follows: {time elapsed after admission (in hours)}, {event type}, {event value}.\n",
    "Based on this sequence of events, provide a concise and accurate answer to the question below.\n",
    "Keep your response within 256 tokens.\"\"\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3fd4ea61",
   "metadata": {},
   "source": [
    "messages = [{\"role\": \"system\", \"content\": system_content},\n",
    "            {\"role\": \"user\", \"content\": f\"{qa.iloc[0].q}\\n\\n\" + get_events(qa.iloc[0].hadm_id)}]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "dfd9801f",
   "metadata": {},
   "source": [
    "print(messages[0][\"content\"])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "248cd830",
   "metadata": {},
   "source": [
    "print(messages[1][\"content\"])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "f064720f",
   "metadata": {},
   "source": [
    "prompts = {}\n",
    "for _, data in qa.iterrows():\n",
    "    messages = [{\"role\": \"system\", \"content\": system_content},\n",
    "                {\"role\": \"user\", \"content\": f\"{data.q}\\n\\n\" + get_events(data.hadm_id)}]\n",
    "    prompts[(data.source, data.hadm_id)] = messages\n",
    "len(prompts)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3c1734d5",
   "metadata": {},
   "source": [
    "prompts[(\"note\", qa.iloc[0].hadm_id)]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "52c4c16f",
   "metadata": {},
   "source": [
    "import tiktoken\n",
    "\n",
    "\n",
    "def num_tokens_from_message(message):\n",
    "    encoding = tiktoken.encoding_for_model(\"gpt-4\")\n",
    "    # + 11 is a fixed allowance for the chat-format overhead tokens added around the two messages\n",
    "    return len(encoding.encode(message[0][\"content\"])) + len(encoding.encode(message[1][\"content\"])) + 11"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e90c113b",
   "metadata": {},
   "source": [
    "num_tokens_from_message(messages)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b42067ab",
   "metadata": {},
   "source": [
    "prompts_num_tokens = {}\n",
    "for k, v in prompts.items():\n",
    "    prompts_num_tokens[k] = num_tokens_from_message(v)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d80f78b7",
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "token_counts = list(prompts_num_tokens.values())\n",
    "print(\"mean: \", np.mean(token_counts))\n",
    "print(\"std: \", np.std(token_counts))\n",
    "print(\"min: \", np.min(token_counts))\n",
    "print(\"max: \", np.max(token_counts))\n",
    "print(\"25th Quantile: \", np.percentile(token_counts, 25))\n",
    "print(\"50th Quantile: \", np.percentile(token_counts, 50))\n",
    "print(\"75th Quantile: \", np.percentile(token_counts, 75))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "40ea7c46",
   "metadata": {},
   "source": [
    "max_response_tokens = 256  # budget reserved for the model's answer\n",
    "token_limit = 128000  # total context window: prompt + response must fit"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b99ad70e",
   "metadata": {},
   "source": [
    "import copy\n",
    "\n",
    "\n",
    "def trim_message(message):\n",
    "    trimmed_message = copy.deepcopy(message)\n",
    "    encoding = tiktoken.encoding_for_model(\"gpt-4\")\n",
    "    system_tokens = len(encoding.encode(message[0][\"content\"]))\n",
    "    user_tokens = len(encoding.encode(message[1][\"content\"]))\n",
    "\n",
    "    # If the total tokens are within the limit, no trimming is needed\n",
    "    if system_tokens + user_tokens + 11 + max_response_tokens <= token_limit:\n",
    "        return trimmed_message\n",
    "\n",
    "    # Otherwise, trim the user message content\n",
    "    available_tokens = token_limit - system_tokens - 11 - max_response_tokens\n",
    "    trimmed_user_content = encoding.decode(encoding.encode(message[1][\"content\"])[:available_tokens])\n",
    "\n",
    "    # Update the message with the trimmed content\n",
    "    trimmed_message[1][\"content\"] = trimmed_user_content\n",
    "    return trimmed_message"
   ],
   "outputs": [],
   "execution_count": null
  },
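  {
   "cell_type": "code",
   "id": "ed-trim-sanity-check",
   "metadata": {},
   "source": [
    "# Editor's sketch (not in the original notebook): sanity-check trim_message on a\n",
    "# deliberately oversized toy prompt; the repeated filler text is hypothetical.\n",
    "toy = [{\"role\": \"system\", \"content\": system_content},\n",
    "       {\"role\": \"user\", \"content\": \"event \" * 200000}]\n",
    "trimmed = trim_message(toy)\n",
    "# After trimming, prompt tokens plus the response budget should fit the context window\n",
    "print(num_tokens_from_message(trimmed) + max_response_tokens <= token_limit)"
   ],
   "outputs": [],
   "execution_count": null
  },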
  {
   "cell_type": "code",
   "id": "c836f4cf",
   "metadata": {},
   "source": [
    "trimmed_prompts = {}\n",
    "for k, v in prompts.items():\n",
    "    trimmed_v = trim_message(v)\n",
    "    if trimmed_v != v:\n",
    "        print(f\"{k} is trimmed\")\n",
    "    trimmed_prompts[k] = trimmed_v\n",
    "len(trimmed_prompts)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "59cf2588",
   "metadata": {},
   "source": [
    "import asyncio\n",
    "from openai import AsyncAzureOpenAI\n",
    "\n",
    "\n",
    "# TODO: Enter your credentials\n",
    "async_client = AsyncAzureOpenAI(\n",
    "    azure_endpoint=\"\",\n",
    "    api_key=\"\",\n",
    "    api_version=\"\"\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6d3b1c82",
   "metadata": {},
   "source": [
    "async def generate_chat_response(async_client, prompt):\n",
    "    chat_params = {\n",
    "        \"model\": \"gpt-4\",\n",
    "        \"messages\": prompt,\n",
    "        \"max_tokens\": max_response_tokens,\n",
    "        \"temperature\": 0.0,\n",
    "    }\n",
    "    try:\n",
    "        response = await async_client.chat.completions.create(**chat_params)\n",
    "    except Exception as e:\n",
    "        print(f\"Error in generate_chat_response: {e}\")\n",
    "        print(\"Sleeping for 10s...\")\n",
    "        # Non-blocking sleep, so other in-flight requests keep running\n",
    "        await asyncio.sleep(10)\n",
    "        return -1\n",
    "    return response.choices[0].message.content"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b4ebea20",
   "metadata": {},
   "source": [
    "async def process_prompts(prompts):\n",
    "    # Gather all the futures together and wait for them to complete\n",
    "    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))\n",
    "    return responses"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "aae93763",
   "metadata": {},
   "source": [
    "def chunk_list(lst, chunk_size):\n",
    "    \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
    "    for i in range(0, len(lst), chunk_size):\n",
    "        yield lst[i:i + chunk_size]"
   ],
   "outputs": [],
   "execution_count": null
  },
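  {
   "cell_type": "code",
   "id": "ed-chunk-list-demo",
   "metadata": {},
   "source": [
    "# Editor's illustration (not in the original notebook): chunk_list yields\n",
    "# consecutive fixed-size batches, with a shorter final batch if needed.\n",
    "list(chunk_list(list(range(7)), 3))  # -> [[0, 1, 2], [3, 4, 5], [6]]"
   ],
   "outputs": [],
   "execution_count": null
  },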
  {
   "cell_type": "code",
   "id": "d4ff8a13",
   "metadata": {},
   "source": [
    "from tqdm.asyncio import tqdm\n",
    "\n",
    "\n",
    "async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
    "    all_responses = {}\n",
    "\n",
    "    # Run up to `repeat` rounds, re-querying only prompts that failed in earlier rounds\n",
    "    for i in range(repeat):\n",
    "        print(f\"round {i}\")\n",
    "        prev_n_responses = len(all_responses)\n",
    "\n",
    "        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
    "\n",
    "        # Chunk the prompts into batches\n",
    "        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
    "\n",
    "        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
    "            batch_v = [prompts[k] for k in batch_k]\n",
    "            responses = await process_prompts(batch_v)\n",
    "            # Failed calls return -1, so keep only string responses\n",
    "            all_responses |= {k: v for k, v in zip(batch_k, responses) if isinstance(v, str)}\n",
    "        print(f\"got {len(all_responses) - prev_n_responses} new responses\")\n",
    "\n",
    "    return all_responses"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "dfac8357",
   "metadata": {},
   "source": [
    "# Choose an appropriate batch size\n",
    "batch_size = 10  # Adjust based on your system and API limits\n",
    "\n",
    "# Assuming we are in an async environment\n",
    "responses = await process_prompts_in_batches(trimmed_prompts, batch_size)\n",
    "print(f\"Processed {len(responses)} responses\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7e65eb22",
   "metadata": {},
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Make sure the output directory exists before writing\n",
    "os.makedirs(os.path.join(model_path, \"gpt4/qa_output\"), exist_ok=True)\n",
    "\n",
    "with open(os.path.join(model_path, \"gpt4/qa_output/answer.jsonl\"), \"w\") as file:\n",
    "    for _, data in qa.iterrows():\n",
    "        a_hat = responses.get((data.source, data.hadm_id), \"\")\n",
    "        json_string = json.dumps({\"hadm_id\": data.hadm_id, \"q\": data.q, \"a\": data.a, \"a_hat\": a_hat, \"source\": data.source})\n",
    "        file.write(json_string + '\\n')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e4424b6a",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "llm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}