Diff of /src/eval/eval.ipynb [000000] .. [780764]

--- a
+++ b/src/eval/eval.ipynb
@@ -0,0 +1,376 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "id": "afdeba94",
+   "metadata": {},
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "src_path = os.path.abspath(\"../..\")\n",
+    "print(src_path)\n",
+    "sys.path.append(src_path)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "0d5f2e19",
+   "metadata": {},
+   "source": "from src.utils import processed_data_path, set_seed, remote_project_path",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "e00815d2",
+   "metadata": {},
+   "source": [
+    "set_seed(seed=42)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "import pandas as pd",
+   "id": "2267de3ef42c4424"
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "answer_filename = \"llemr_vicuna\"",
+   "id": "14e9cc774ec18d0f",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "model_path = os.path.join(remote_project_path, \"output\")",
+   "id": "921de6717a04852e",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "ef32981d",
+   "metadata": {},
+   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "16f23fa5",
+   "metadata": {},
+   "source": [
+    "b_answer = pd.read_json(os.path.join(model_path, f\"gpt4/qa_output/answer.jsonl\"), lines=True)\n",
+    "b_answer.a_hat = b_answer.a_hat.replace(\"\", float(\"nan\"))\n",
+    "b_answer = b_answer.dropna()\n",
+    "b_answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "539a6392",
+   "metadata": {},
+   "source": [
+    "answer = pd.read_json(os.path.join(model_path, f\"{answer_filename}/qa_output/answer.jsonl\"), lines=True)\n",
+    "answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "821203ad",
+   "metadata": {},
+   "source": [
+    "answer = b_answer.merge(answer, on=[\"hadm_id\", \"q\", \"a\", \"source\"])\n",
+    "answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "53830063",
+   "metadata": {},
+   "source": [
+    "system_content = \"\"\"You are a helpful and precise assistant for evaluating the quality of responses.\n",
+    "\n",
+    "Please assess the performance of two clinical AI assistants based on the question and the ground-truth answer provided below.\n",
+    "\n",
+    "Your evaluation should consider helpfulness, relevance, accuracy, and level of detail.\n",
+    "\n",
+    "Rate each AI assistant's response with a single score on a scale of 1 to 10, where 10 represents excellent performance.\n",
+    "\n",
+    "Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\n",
+    "\n",
+    "In the subsequent line, provide a concise explanation of your evaluation.\n",
+    "\n",
+    "Avoid any potential bias and ensure that the order in which the responses were presented does not affect your judgment.\"\"\""
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "0a0d8d6c",
+   "metadata": {},
+   "source": [
+    "def generate_user_content(q, a, a_hat_1, a_hat_2):\n",
+    "    return f\"\"\"[Question]\n",
+    "{q}\n",
+    "[End of Question]\n",
+    "    \n",
+    "[Ground-truth Answer]\n",
+    "{a}\n",
+    "[End of Ground-truth Answer]\n",
+    "\n",
+    "[Assistant 1 Answer]\n",
+    "{a_hat_1}\n",
+    "[End of Assistant 1 Answer]\n",
+    "\n",
+    "[Assistant 2 Answer]\n",
+    "{a_hat_2}\n",
+    "[End of Assistant 2 Answer]\"\"\""
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "f064720f",
+   "metadata": {},
+   "source": [
+    "prompts = {}\n",
+    "for _, data in answer.iterrows():\n",
+    "    messages = [{\"role\": \"system\", \"content\": system_content},\n",
+    "                {\"role\": \"user\", \"content\": generate_user_content(data.q, data.a, data.a_hat_x, data.a_hat_y)}]\n",
+    "    prompts[(data.source, data.hadm_id)] = messages\n",
+    "len(prompts)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
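+  {
+   "cell_type": "code",
+   "id": "3f7a2b9c1d4e5a60",
+   "metadata": {},
+   "source": [
+    "# Sanity check (optional, not part of the pipeline): eyeball one assembled\n",
+    "# prompt before spending API calls; nothing here is sent to the API.\n",
+    "example_key = next(iter(prompts))\n",
+    "for message in prompts[example_key]:\n",
+    "    print(message[\"role\"], \"->\", message[\"content\"][:200])"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },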
+  {
+   "cell_type": "code",
+   "id": "59cf2588",
+   "metadata": {},
+   "source": [
+    "import asyncio\n",
+    "from openai import AsyncAzureOpenAI\n",
+    "\n",
+    "# TODO: Enter your credentials\n",
+    "async_client = AsyncAzureOpenAI(\n",
+    "    azure_endpoint=\"\",\n",
+    "    api_key=\"\",\n",
+    "    api_version=\"\"\n",
+    ")"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
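+  {
+   "cell_type": "code",
+   "id": "8c5d1e2f9a3b4706",
+   "metadata": {},
+   "source": [
+    "# Optional sketch: read credentials from environment variables instead of\n",
+    "# hardcoding them above. The variable names here are assumptions, not part\n",
+    "# of this project; adjust them to your own setup before uncommenting.\n",
+    "# async_client = AsyncAzureOpenAI(\n",
+    "#     azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
+    "#     api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
+    "#     api_version=os.environ[\"AZURE_OPENAI_API_VERSION\"],\n",
+    "# )"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },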
+  {
+   "cell_type": "code",
+   "id": "6d3b1c82",
+   "metadata": {},
+   "source": [
+    "async def generate_chat_response(async_client, prompt):\n",
+    "    chat_params = {\n",
+    "        \"model\": \"gpt-3.5-turbo\",\n",
+    "        \"messages\": prompt,\n",
+    "        \"max_tokens\": 512,\n",
+    "        \"temperature\": 0.0,\n",
+    "    }\n",
+    "    try:\n",
+    "        response = await async_client.chat.completions.create(**chat_params)\n",
+    "    except Exception as e:\n",
+    "        print(f\"Error in call_async: {e}\")\n",
+    "        time.sleep(10)\n",
+    "        print(f\"Sleep for 10s...\")\n",
+    "        return -1\n",
+    "    return response.choices[0].message.content"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "b4ebea20",
+   "metadata": {},
+   "source": [
+    "import time\n",
+    "\n",
+    "\n",
+    "async def process_prompts(prompts):\n",
+    "    # Gather all the futures together and wait for them to complete\n",
+    "    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))\n",
+    "    return responses"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "aae93763",
+   "metadata": {},
+   "source": [
+    "def chunk_list(lst, chunk_size):\n",
+    "    \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
+    "    for i in range(0, len(lst), chunk_size):\n",
+    "        yield lst[i:i + chunk_size]"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
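+  {
+   "cell_type": "code",
+   "id": "6b9e3a7c2d5f1804",
+   "metadata": {},
+   "source": [
+    "# Quick check of the chunking helper on toy data (not part of the pipeline):\n",
+    "# five items with chunk_size=2 should yield [[1, 2], [3, 4], [5]].\n",
+    "assert list(chunk_list([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },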
+  {
+   "cell_type": "code",
+   "id": "d4ff8a13",
+   "metadata": {},
+   "source": [
+    "from tqdm.asyncio import tqdm\n",
+    "\n",
+    "\n",
+    "async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
+    "    all_responses = {}\n",
+    "\n",
+    "    for i in range(repeat):\n",
+    "\n",
+    "        print(f\"round {i}\")\n",
+    "        prev_n_responses = len(all_responses)\n",
+    "\n",
+    "        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
+    "\n",
+    "        # Chunk the prompts into batches\n",
+    "        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
+    "\n",
+    "        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
+    "            batch_v = [prompts[k] for k in batch_k]\n",
+    "            responses = await process_prompts(batch_v)\n",
+    "            all_responses |= {k: v for k, v in zip(batch_k, responses) if type(v) is str}\n",
+    "        print(f\"get {len(all_responses) - prev_n_responses} new responses\")\n",
+    "\n",
+    "    return all_responses"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "dfac8357",
+   "metadata": {},
+   "source": [
+    "# Choose an appropriate batch size\n",
+    "batch_size = 10  # Adjust based on your system and API limits\n",
+    "\n",
+    "# Assuming we are in an async environment\n",
+    "responses = await process_prompts_in_batches(prompts, batch_size)\n",
+    "print(f\"Processed {len(responses)} responses\")"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "36317063",
+   "metadata": {},
+   "source": [
+    "def split_responase(r, verbose=False):\n",
+    "    if verbose:\n",
+    "        print(r)\n",
+    "    split_text = r.split(\"\\n\", 1)\n",
+    "    scores = split_text[0].split(\" \")\n",
+    "    base_score = float(scores[0])\n",
+    "    score = float(scores[1])\n",
+    "    comment = split_text[1].strip() if len(split_text) > 1 else \"\"\n",
+    "    if verbose:\n",
+    "        print(\"scores:\", scores)\n",
+    "        print(\"comment:\", comment)\n",
+    "    return base_score, score, comment"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
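+  {
+   "cell_type": "code",
+   "id": "2a8f4c6e1b7d3905",
+   "metadata": {},
+   "source": [
+    "# Quick check on a made-up judge reply in the format the system prompt asks\n",
+    "# for (two scores on the first line, explanation on the next):\n",
+    "split_response(\"8 6\\nAssistant 1's answer is more detailed and accurate.\", verbose=True)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },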
+  {
+   "cell_type": "code",
+   "id": "f910eb36",
+   "metadata": {},
+   "source": [
+    "responses_split = {}\n",
+    "for k, r in responses.items():\n",
+    "    responses_split[k] = split_responase(r)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "7e65eb22",
+   "metadata": {},
+   "source": [
+    "import json\n",
+    "\n",
+    "with open(os.path.join(model_path, f\"{answer_filename}/qa_output/answer_eval.jsonl\"), \"w\") as file:\n",
+    "    c = 0\n",
+    "    for _, data in answer.iterrows():\n",
+    "        if (data.source, data.hadm_id) in responses_split:\n",
+    "            base_score, score, comment = responses_split[(data.source, data.hadm_id)]\n",
+    "            json_string = json.dumps({\n",
+    "                \"hadm_id\": data.hadm_id,\n",
+    "                \"q\": data.q,\n",
+    "                \"a\": data.a,\n",
+    "                \"a_hat\": data.a_hat_y,\n",
+    "                \"score\": score,\n",
+    "                \"base_a_hat\": data.a_hat_x,\n",
+    "                \"base_score\": base_score,\n",
+    "                \"comment\": comment,\n",
+    "                \"source\": data.source\n",
+    "            })\n",
+    "            file.write(json_string + '\\n')\n",
+    "            c += 1\n",
+    "c"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
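+  {
+   "cell_type": "code",
+   "id": "4d7c2e9b5a1f6308",
+   "metadata": {},
+   "source": [
+    "# Optional follow-up: read the eval file back and compare mean judge scores\n",
+    "# (base_score = GPT-4 baseline, score = evaluated model). Uses only the\n",
+    "# fields written above.\n",
+    "eval_df = pd.read_json(os.path.join(model_path, f\"{answer_filename}/qa_output/answer_eval.jsonl\"), lines=True)\n",
+    "eval_df[[\"base_score\", \"score\"]].mean()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },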
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "",
+   "id": "41cde3512e408fbc"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llm",
+   "language": "python",
+   "name": "llm"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}