--- a
+++ b/src/eval/eval.ipynb
@@ -0,0 +1,376 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "id": "afdeba94",
+   "metadata": {},
+   "source": [
+    "import os\n",
+    "import sys\n",
+    "\n",
+    "# make the project root importable\n",
+    "src_path = os.path.abspath(\"../..\")\n",
+    "print(src_path)\n",
+    "sys.path.append(src_path)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "0d5f2e19",
+   "metadata": {},
+   "source": "from src.utils import processed_data_path, set_seed, remote_project_path",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "e00815d2",
+   "metadata": {},
+   "source": [
+    "set_seed(seed=42)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "import pandas as pd",
+   "id": "2267de3ef42c4424"
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "answer_filename = \"llemr_vicuna\"",
+   "id": "14e9cc774ec18d0f",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "source": "model_path = os.path.join(remote_project_path, \"output\")",
+   "id": "921de6717a04852e",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "ef32981d",
+   "metadata": {},
+   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "16f23fa5",
+   "metadata": {},
+   "source": [
+    "# GPT-4 baseline answers; drop rows with empty generations\n",
+    "b_answer = pd.read_json(os.path.join(model_path, \"gpt4/qa_output/answer.jsonl\"), lines=True)\n",
+    "b_answer[\"a_hat\"] = b_answer[\"a_hat\"].replace(\"\", float(\"nan\"))\n",
+    "b_answer = b_answer.dropna()\n",
+    "b_answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "539a6392",
+   "metadata": {},
+   "source": [
+    "answer = pd.read_json(os.path.join(model_path, f\"{answer_filename}/qa_output/answer.jsonl\"), lines=True)\n",
+    "answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "821203ad",
+   "metadata": {},
+   "source": [
+    "# inner join: a_hat_x is the GPT-4 baseline answer, a_hat_y the candidate model's\n",
+    "answer = b_answer.merge(answer, on=[\"hadm_id\", \"q\", \"a\", \"source\"])\n",
+    "answer"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
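+  {
+   "cell_type": "code",
+   "id": "merge-sanity-check",
+   "metadata": {},
+   "source": [
+    "# Optional sanity check (an added sketch, not part of the original pipeline):\n",
+    "# the inner merge silently drops any question missing from either file, so\n",
+    "# compare row counts and eyeball a few paired answers before spending API\n",
+    "# calls on judging them.\n",
+    "print(f\"baseline rows: {len(b_answer)}, merged rows: {len(answer)}\")\n",
+    "answer[[\"source\", \"hadm_id\", \"a_hat_x\", \"a_hat_y\"]].head()"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },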
+  {
+   "cell_type": "code",
+   "id": "53830063",
+   "metadata": {},
+   "source": [
+    "system_content = \"\"\"You are a helpful and precise assistant for evaluating the quality of responses.\n",
+    "\n",
+    "Please assess the performance of two clinical AI assistants based on the question and the ground-truth answer provided below.\n",
+    "\n",
+    "Your evaluation should consider helpfulness, relevance, accuracy, and level of detail.\n",
+    "\n",
+    "Rate each AI assistant's response with a single score on a scale of 1 to 10, where 10 represents excellent performance.\n",
+    "\n",
+    "Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\n",
+    "\n",
+    "In the subsequent line, provide a concise explanation of your evaluation.\n",
+    "\n",
+    "Avoid any potential bias and ensure that the order in which the responses were presented does not affect your judgment.\"\"\""
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "0a0d8d6c",
+   "metadata": {},
+   "source": [
+    "def generate_user_content(q, a, a_hat_1, a_hat_2):\n",
+    "    return f\"\"\"[Question]\n",
+    "{q}\n",
+    "[End of Question]\n",
+    "\n",
+    "[Ground-truth Answer]\n",
+    "{a}\n",
+    "[End of Ground-truth Answer]\n",
+    "\n",
+    "[Assistant 1 Answer]\n",
+    "{a_hat_1}\n",
+    "[End of Assistant 1 Answer]\n",
+    "\n",
+    "[Assistant 2 Answer]\n",
+    "{a_hat_2}\n",
+    "[End of Assistant 2 Answer]\"\"\""
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "f064720f",
+   "metadata": {},
+   "source": [
+    "prompts = {}\n",
+    "for _, data in answer.iterrows():\n",
+    "    messages = [{\"role\": \"system\", \"content\": system_content},\n",
+    "                {\"role\": \"user\", \"content\": generate_user_content(data.q, data.a, data.a_hat_x, data.a_hat_y)}]\n",
+    "    prompts[(data.source, data.hadm_id)] = messages\n",
+    "len(prompts)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "59cf2588",
+   "metadata": {},
+   "source": [
+    "import asyncio\n",
+    "from openai import AsyncAzureOpenAI\n",
+    "\n",
+    "# TODO: Enter your credentials\n",
+    "async_client = AsyncAzureOpenAI(\n",
+    "    azure_endpoint=\"\",\n",
+    "    api_key=\"\",\n",
+    "    api_version=\"\"\n",
+    ")"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "6d3b1c82",
+   "metadata": {},
+   "source": [
+    "async def generate_chat_response(async_client, prompt):\n",
+    "    chat_params = {\n",
+    "        \"model\": \"gpt-3.5-turbo\",  # for Azure, this is the deployment name\n",
+    "        \"messages\": prompt,\n",
+    "        \"max_tokens\": 512,\n",
+    "        \"temperature\": 0.0,\n",
+    "    }\n",
+    "    try:\n",
+    "        response = await async_client.chat.completions.create(**chat_params)\n",
+    "    except Exception as e:\n",
+    "        print(f\"Error in generate_chat_response: {e}\")\n",
+    "        print(\"Sleeping for 10s...\")\n",
+    "        await asyncio.sleep(10)  # non-blocking; time.sleep would stall the event loop\n",
+    "        return -1\n",
+    "    return response.choices[0].message.content"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "b4ebea20",
+   "metadata": {},
+   "source": [
+    "async def process_prompts(prompts):\n",
+    "    # Gather all the futures together and wait for them to complete\n",
+    "    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))\n",
+    "    return responses"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "aae93763",
+   "metadata": {},
+   "source": [
+    "def chunk_list(lst, chunk_size):\n",
+    "    \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
+    "    for i in range(0, len(lst), chunk_size):\n",
+    "        yield lst[i:i + chunk_size]"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "d4ff8a13",
+   "metadata": {},
+   "source": [
+    "from tqdm.asyncio import tqdm\n",
+    "\n",
+    "\n",
+    "async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
+    "    all_responses = {}\n",
+    "\n",
+    "    for i in range(repeat):\n",
+    "\n",
+    "        print(f\"round {i}\")\n",
+    "        prev_n_responses = len(all_responses)\n",
+    "\n",
+    "        # Retry only the prompts that have not been answered yet\n",
+    "        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
+    "\n",
+    "        # Chunk the prompts into batches\n",
+    "        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
+    "\n",
+    "        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
+    "            batch_v = [prompts[k] for k in batch_k]\n",
+    "            responses = await process_prompts(batch_v)\n",
+    "            # Keep only successful (string) responses; failures return -1\n",
+    "            all_responses |= {k: v for k, v in zip(batch_k, responses) if isinstance(v, str)}\n",
+    "        print(f\"got {len(all_responses) - prev_n_responses} new responses\")\n",
+    "\n",
+    "    return all_responses"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "dfac8357",
+   "metadata": {},
+   "source": [
+    "# Choose an appropriate batch size\n",
+    "batch_size = 10  # Adjust based on your system and API limits\n",
+    "\n",
+    "# Top-level await works because Jupyter runs cells in an async event loop\n",
+    "responses = await process_prompts_in_batches(prompts, batch_size)\n",
+    "print(f\"Processed {len(responses)} responses\")"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "36317063",
+   "metadata": {},
+   "source": [
+    "def split_response(r, verbose=False):\n",
+    "    \"\"\"Parse a judge reply: first line 'base_score score', remainder is the comment.\"\"\"\n",
+    "    if verbose:\n",
+    "        print(r)\n",
+    "    split_text = r.split(\"\\n\", 1)\n",
+    "    scores = split_text[0].split()\n",
+    "    base_score = float(scores[0])\n",
+    "    score = float(scores[1])\n",
+    "    comment = split_text[1].strip() if len(split_text) > 1 else \"\"\n",
+    "    if verbose:\n",
+    "        print(\"scores:\", scores)\n",
+    "        print(\"comment:\", comment)\n",
+    "    return base_score, score, comment"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
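+  {
+   "cell_type": "code",
+   "id": "split-response-smoke-test",
+   "metadata": {},
+   "source": [
+    "# Quick smoke test (an added sketch with a made-up judge reply): the system\n",
+    "# prompt asks for two space-separated scores on the first line and an\n",
+    "# explanation on the next, which is exactly what split_response expects.\n",
+    "split_response(\"8 6\\nAssistant 1 follows the ground truth more closely.\", verbose=True)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },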
+  {
+   "cell_type": "code",
+   "id": "f910eb36",
+   "metadata": {},
+   "source": [
+    "responses_split = {}\n",
+    "for k, r in responses.items():\n",
+    "    responses_split[k] = split_response(r)"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "cell_type": "code",
+   "id": "7e65eb22",
+   "metadata": {},
+   "source": [
+    "import json\n",
+    "\n",
+    "with open(os.path.join(model_path, f\"{answer_filename}/qa_output/answer_eval.jsonl\"), \"w\") as file:\n",
+    "    c = 0\n",
+    "    for _, data in answer.iterrows():\n",
+    "        if (data.source, data.hadm_id) in responses_split:\n",
+    "            base_score, score, comment = responses_split[(data.source, data.hadm_id)]\n",
+    "            json_string = json.dumps({\n",
+    "                \"hadm_id\": data.hadm_id,\n",
+    "                \"q\": data.q,\n",
+    "                \"a\": data.a,\n",
+    "                \"a_hat\": data.a_hat_y,\n",
+    "                \"score\": score,\n",
+    "                \"base_a_hat\": data.a_hat_x,\n",
+    "                \"base_score\": base_score,\n",
+    "                \"comment\": comment,\n",
+    "                \"source\": data.source\n",
+    "            })\n",
+    "            file.write(json_string + '\\n')\n",
+    "            c += 1\n",
+    "c"
+   ],
+   "outputs": [],
+   "execution_count": null
+  },
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "",
+   "id": "41cde3512e408fbc"
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "llm",
+   "language": "python",
+   "name": "llm"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.19"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}