Diff of /src/eval/eval.ipynb [000000] .. [780764]

b/src/eval/eval.ipynb
1
{
2
 "cells": [
3
  {
4
   "cell_type": "code",
5
   "id": "afdeba94",
6
   "metadata": {},
7
   "source": [
8
    "import os\n",
9
    "import sys\n",
10
    "\n",
11
    "src_path = os.path.abspath(\"../..\")\n",
12
    "print(src_path)\n",
13
    "sys.path.append(src_path)"
14
   ],
15
   "outputs": [],
16
   "execution_count": null
17
  },
18
  {
19
   "cell_type": "code",
20
   "id": "0d5f2e19",
21
   "metadata": {},
22
   "source": "from src.utils import processed_data_path, set_seed, remote_project_path",
23
   "outputs": [],
24
   "execution_count": null
25
  },
26
  {
27
   "cell_type": "code",
28
   "id": "e00815d2",
29
   "metadata": {},
30
   "source": [
31
    "set_seed(seed=42)"
32
   ],
33
   "outputs": [],
34
   "execution_count": null
35
  },
36
  {
37
   "metadata": {},
38
   "cell_type": "code",
39
   "outputs": [],
40
   "execution_count": null,
41
   "source": "import pandas as pd",
42
   "id": "2267de3ef42c4424"
43
  },
44
  {
45
   "metadata": {},
46
   "cell_type": "code",
47
   "source": "answer_filename = \"llemr_vicuna\"",
48
   "id": "14e9cc774ec18d0f",
49
   "outputs": [],
50
   "execution_count": null
51
  },
52
  {
53
   "metadata": {},
54
   "cell_type": "code",
55
   "source": "model_path = os.path.join(remote_project_path, \"output\")",
56
   "id": "921de6717a04852e",
57
   "outputs": [],
58
   "execution_count": null
59
  },
60
  {
61
   "cell_type": "code",
62
   "id": "ef32981d",
63
   "metadata": {},
64
   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
65
   "outputs": [],
66
   "execution_count": null
67
  },
68
  {
69
   "cell_type": "code",
70
   "id": "16f23fa5",
71
   "metadata": {},
72
   "source": [
73
    "b_answer = pd.read_json(os.path.join(model_path, f\"gpt4/qa_output/answer.jsonl\"), lines=True)\n",
74
    "b_answer.a_hat = b_answer.a_hat.replace(\"\", float(\"nan\"))\n",
75
    "b_answer = b_answer.dropna()\n",
76
    "b_answer"
77
   ],
78
   "outputs": [],
79
   "execution_count": null
80
  },
81
  {
82
   "cell_type": "code",
83
   "id": "539a6392",
84
   "metadata": {},
85
   "source": [
86
    "answer = pd.read_json(os.path.join(model_path, f\"{answer_filename}/qa_output/answer.jsonl\"), lines=True)\n",
87
    "answer"
88
   ],
89
   "outputs": [],
90
   "execution_count": null
91
  },
92
  {
93
   "cell_type": "code",
94
   "id": "821203ad",
95
   "metadata": {},
96
   "source": [
97
    "answer = b_answer.merge(answer, on=[\"hadm_id\", \"q\", \"a\", \"source\"])\n",
98
    "answer"
99
   ],
100
   "outputs": [],
101
   "execution_count": null
102
  },
103
  {
104
   "cell_type": "code",
105
   "id": "53830063",
106
   "metadata": {},
107
   "source": [
108
    "system_content = \"\"\"You are a helpful and precise assistant for evaluating the quality of responses.\n",
109
    "\n",
110
    "Please assess the performance of two clinical AI assistants based on the question and the ground-truth answer provided below.\n",
111
    "\n",
112
    "Your evaluation should consider helpfulness, relevance, accuracy, and level of detail.\n",
113
    "\n",
114
    "Rate each AI assistant's response with a single score on a scale of 1 to 10, where 10 represents excellent performance.\n",
115
    "\n",
116
    "Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\n",
117
    "\n",
118
    "In the subsequent line, provide a concise explanation of your evaluation.\n",
119
    "\n",
120
    "Avoid any potential bias and ensure that the order in which the responses were presented does not affect your judgment.\"\"\""
121
   ],
122
   "outputs": [],
123
   "execution_count": null
124
  },
125
  {
126
   "cell_type": "code",
127
   "id": "0a0d8d6c",
128
   "metadata": {},
129
   "source": [
130
    "def generate_user_content(q, a, a_hat_1, a_hat_2):\n",
131
    "    return f\"\"\"[Question]\n",
132
    "{q}\n",
133
    "[End of Question]\n",
134
    "    \n",
135
    "[Ground-truth Answer]\n",
136
    "{a}\n",
137
    "[End of Ground-truth Answer]\n",
138
    "\n",
139
    "[Assistant 1 Answer]\n",
140
    "{a_hat_1}\n",
141
    "[End of Assistant 1 Answer]\n",
142
    "\n",
143
    "[Assistant 2 Answer]\n",
144
    "{a_hat_2}\n",
145
    "[End of Assistant 2 Answer]\"\"\""
146
   ],
147
   "outputs": [],
148
   "execution_count": null
149
  },
150
  {
151
   "cell_type": "code",
152
   "id": "f064720f",
153
   "metadata": {},
154
   "source": [
155
    "prompts = {}\n",
156
    "for _, data in answer.iterrows():\n",
157
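    "    # after the merge, a_hat_x is the GPT-4 baseline answer (Assistant 1) and a_hat_y is the evaluated model's answer (Assistant 2)\n",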
    "    messages = [{\"role\": \"system\", \"content\": system_content},\n",
158
    "                {\"role\": \"user\", \"content\": generate_user_content(data.q, data.a, data.a_hat_x, data.a_hat_y)}]\n",
159
    "    prompts[(data.source, data.hadm_id)] = messages\n",
160
    "len(prompts)"
161
   ],
162
   "outputs": [],
163
   "execution_count": null
164
  },
165
  {
166
   "cell_type": "code",
167
   "id": "59cf2588",
168
   "metadata": {},
169
   "source": [
170
    "import asyncio\n",
171
    "from openai import AsyncAzureOpenAI\n",
172
    "\n",
173
    "# TODO: Enter your credentials\n",
174
    "async_client = AsyncAzureOpenAI(\n",
175
    "    azure_endpoint=\"\",\n",
176
    "    api_key=\"\",\n",
177
    "    api_version=\"\"\n",
178
    ")"
179
   ],
180
   "outputs": [],
181
   "execution_count": null
182
  },
183
  {
184
   "cell_type": "code",
185
   "id": "6d3b1c82",
186
   "metadata": {},
187
   "source": [
188
    "async def generate_chat_response(async_client, prompt):\n",
189
    "    chat_params = {\n",
190
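    "        # NOTE: with AsyncAzureOpenAI, \"model\" is typically the Azure deployment name\n",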
    "        \"model\": \"gpt-3.5-turbo\",\n",
191
    "        \"messages\": prompt,\n",
192
    "        \"max_tokens\": 512,\n",
193
    "        \"temperature\": 0.0,\n",
194
    "    }\n",
195
    "    try:\n",
196
    "        response = await async_client.chat.completions.create(**chat_params)\n",
197
    "    except Exception as e:\n",
198
    "        print(f\"Error in call_async: {e}\")\n",
199
    "        time.sleep(10)\n",
200
    "        print(f\"Sleep for 10s...\")\n",
201
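    "        # return a non-string sentinel; callers keep only string responses\n",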
    "        return -1\n",
202
    "    return response.choices[0].message.content"
203
   ],
204
   "outputs": [],
205
   "execution_count": null
206
  },
207
  {
208
   "cell_type": "code",
209
   "id": "b4ebea20",
210
   "metadata": {},
211
   "source": [
212
    "import time\n",
213
    "\n",
214
    "\n",
215
    "async def process_prompts(prompts):\n",
216
    "    # Gather all the futures together and wait for them to complete\n",
217
    "    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))\n",
218
    "    return responses"
219
   ],
220
   "outputs": [],
221
   "execution_count": null
222
  },
223
  {
224
   "cell_type": "code",
225
   "id": "aae93763",
226
   "metadata": {},
227
   "source": [
228
    "def chunk_list(lst, chunk_size):\n",
229
    "    \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
230
    "    for i in range(0, len(lst), chunk_size):\n",
231
    "        yield lst[i:i + chunk_size]"
232
   ],
233
   "outputs": [],
234
   "execution_count": null
235
  },
236
  {
237
   "cell_type": "code",
238
   "id": "d4ff8a13",
239
   "metadata": {},
240
   "source": [
241
    "from tqdm.asyncio import tqdm\n",
242
    "\n",
243
    "\n",
244
    "async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
245
    "    all_responses = {}\n",
246
    "\n",
247
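    "    # re-query for up to `repeat` rounds, each round re-sending only the prompts that still lack a response\n",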
    "    for i in range(repeat):\n",
248
    "\n",
249
    "        print(f\"round {i}\")\n",
250
    "        prev_n_responses = len(all_responses)\n",
251
    "\n",
252
    "        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
253
    "\n",
254
    "        # Chunk the prompts into batches\n",
255
    "        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
256
    "\n",
257
    "        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
258
    "            batch_v = [prompts[k] for k in batch_k]\n",
259
    "            responses = await process_prompts(batch_v)\n",
260
    "            all_responses |= {k: v for k, v in zip(batch_k, responses) if type(v) is str}\n",
261
    "        print(f\"get {len(all_responses) - prev_n_responses} new responses\")\n",
262
    "\n",
263
    "    return all_responses"
264
   ],
265
   "outputs": [],
266
   "execution_count": null
267
  },
268
  {
269
   "cell_type": "code",
270
   "id": "dfac8357",
271
   "metadata": {},
272
   "source": [
273
    "# Choose an appropriate batch size\n",
274
    "batch_size = 10  # Adjust based on your system and API limits\n",
275
    "\n",
276
    "# Assuming we are in an async environment\n",
277
    "responses = await process_prompts_in_batches(prompts, batch_size)\n",
278
    "print(f\"Processed {len(responses)} responses\")"
279
   ],
280
   "outputs": [],
281
   "execution_count": null
282
  },
283
  {
284
   "cell_type": "code",
285
   "id": "36317063",
286
   "metadata": {},
287
   "source": [
288
    "def split_responase(r, verbose=False):\n",
289
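    "    # expected judge reply: first line \"<score_1> <score_2>\", remaining lines = explanation\n",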
    "    if verbose:\n",
290
    "        print(r)\n",
291
    "    split_text = r.split(\"\\n\", 1)\n",
292
    "    scores = split_text[0].split(\" \")\n",
293
    "    base_score = float(scores[0])\n",
294
    "    score = float(scores[1])\n",
295
    "    comment = split_text[1].strip() if len(split_text) > 1 else \"\"\n",
296
    "    if verbose:\n",
297
    "        print(\"scores:\", scores)\n",
298
    "        print(\"comment:\", comment)\n",
299
    "    return base_score, score, comment"
300
   ],
301
   "outputs": [],
302
   "execution_count": null
303
  },
304
  {
305
   "cell_type": "code",
306
   "id": "f910eb36",
307
   "metadata": {},
308
   "source": [
309
    "responses_split = {}\n",
310
    "for k, r in responses.items():\n",
311
    "    responses_split[k] = split_responase(r)"
312
   ],
313
   "outputs": [],
314
   "execution_count": null
315
  },
316
  {
317
   "cell_type": "code",
318
   "id": "7e65eb22",
319
   "metadata": {},
320
   "source": [
321
    "import json\n",
322
    "\n",
323
    "with open(os.path.join(model_path, f\"{answer_filename}/qa_output/answer_eval.jsonl\"), \"w\") as file:\n",
324
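    "    # one JSON line per question with a judge response: evaluated model fields plus base_* fields for the GPT-4 baseline\n",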
    "    c = 0\n",
325
    "    for _, data in answer.iterrows():\n",
326
    "        if (data.source, data.hadm_id) in responses_split:\n",
327
    "            base_score, score, comment = responses_split[(data.source, data.hadm_id)]\n",
328
    "            json_string = json.dumps({\n",
329
    "                \"hadm_id\": data.hadm_id,\n",
330
    "                \"q\": data.q,\n",
331
    "                \"a\": data.a,\n",
332
    "                \"a_hat\": data.a_hat_y,\n",
333
    "                \"score\": score,\n",
334
    "                \"base_a_hat\": data.a_hat_x,\n",
335
    "                \"base_score\": base_score,\n",
336
    "                \"comment\": comment,\n",
337
    "                \"source\": data.source\n",
338
    "            })\n",
339
    "            file.write(json_string + '\\n')\n",
340
    "            c += 1\n",
341
    "c"
342
   ],
343
   "outputs": [],
344
   "execution_count": null
345
  },
346
  {
347
   "metadata": {},
348
   "cell_type": "code",
349
   "outputs": [],
350
   "execution_count": null,
351
   "source": "",
352
   "id": "41cde3512e408fbc"
353
  }
354
 ],
355
 "metadata": {
356
  "kernelspec": {
357
   "display_name": "llm",
358
   "language": "python",
359
   "name": "llm"
360
  },
361
  "language_info": {
362
   "codemirror_mode": {
363
    "name": "ipython",
364
    "version": 3
365
   },
366
   "file_extension": ".py",
367
   "mimetype": "text/x-python",
368
   "name": "python",
369
   "nbconvert_exporter": "python",
370
   "pygments_lexer": "ipython3",
371
   "version": "3.9.19"
372
  }
373
 },
374
 "nbformat": 4,
375
 "nbformat_minor": 5
376
}