{
"cells": [
{
"cell_type": "code",
"id": "afdeba94",
"metadata": {},
"source": [
"import os\n",
"import sys\n",
"\n",
"src_path = os.path.abspath(\"../..\")\n",
"print(src_path)\n",
"sys.path.append(src_path)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "0d5f2e19",
"metadata": {},
"source": "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed, remote_project_path",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e00815d2",
"metadata": {},
"source": [
"set_seed(seed=42)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "fd92d900",
"metadata": {},
"source": [
"import pandas as pd"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "model_path = os.path.join(remote_project_path, \"output\")",
"id": "4e04fa3e4c08145",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "ef32981d",
"metadata": {},
"source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "539a6392",
"metadata": {},
"source": [
"cohort = pd.read_csv(os.path.join(output_path, \"cohort_test_subset.csv\"))\n",
"print(cohort.shape)\n",
"cohort.head()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "6659ecf8",
"metadata": {},
"source": [
"hadm_ids = set(cohort.hadm_id.unique().tolist())\n",
"len(hadm_ids)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "4090a70d",
"metadata": {},
"source": [
"qa = pd.read_csv(os.path.join(output_path, \"qa_test_subset.csv\"))\n",
"# QA pairs with no event_type were generated from notes; the rest from events\n",
"qa[\"source\"] = qa.event_type.apply(lambda x: \"note\" if pd.isna(x) else \"event\")\n",
"qa"
],
"outputs": [],
"execution_count": null
},
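{
"cell_type": "markdown",
"id": "9f01aa10",
"metadata": {},
"source": [
"A quick sanity check (added here; not part of the original flow): how the QA pairs split between note- and event-derived questions."
]
},
{
"cell_type": "code",
"id": "9f01aa11",
"metadata": {},
"source": [
"# Rows without an event_type were labeled \"note\" above\n",
"qa.source.value_counts()"
],
"outputs": [],
"execution_count": null
},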
{
"cell_type": "code",
"id": "dc67f44e",
"metadata": {},
"source": [
"def get_events(hadm_id):\n",
"    \"\"\"Load the selected events for one admission and format them as text.\"\"\"\n",
"    df = pd.read_csv(os.path.join(output_path, f\"event_selected/event_{hadm_id}.csv\"))\n",
"    text = []\n",
"    for _, row in df.iterrows():\n",
"        # One line per event: {hours since admission}, {event type}, {event value}\n",
"        text.append(f\"{row.timestamp:.2f} hour, {row.event_type}, {row.event_value}\")\n",
"    return \"\\n\".join(text)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "1f080932",
"metadata": {},
"source": [
"print(get_events(qa.iloc[2].hadm_id))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "53830063",
"metadata": {},
"source": [
"system_content = \"\"\"You are an AI assistant specialized in analyzing ICU patient data.\n",
"You are given a sequence of clinical events from an ICU patient's hospital admission.\n",
"Each event is formatted as follows: {time elapsed after admission (in hours)}, {event type}, {event value}.\n",
"Based on this sequence of events, provide a concise and accurate answer to the question below.\n",
"Keep your response within 256 tokens.\"\"\""
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3fd4ea61",
"metadata": {},
"source": [
"messages = [{\"role\": \"system\", \"content\": system_content},\n",
" {\"role\": \"user\", \"content\": f\"{qa.iloc[0].q}\\n\\n\" + get_events(qa.iloc[0].hadm_id)}]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "dfd9801f",
"metadata": {},
"source": [
"print(messages[0][\"content\"])"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "248cd830",
"metadata": {},
"source": [
"print(messages[1][\"content\"])"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "f064720f",
"metadata": {},
"source": [
"prompts = {}\n",
"for _, data in qa.iterrows():\n",
" messages = [{\"role\": \"system\", \"content\": system_content},\n",
" {\"role\": \"user\", \"content\": f\"{data.q}\\n\\n\" + get_events(data.hadm_id)}]\n",
" prompts[(data.source, data.hadm_id)] = messages\n",
"len(prompts)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "3c1734d5",
"metadata": {},
"source": [
"prompts[(\"note\", qa.iloc[0].hadm_id)]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "52c4c16f",
"metadata": {},
"source": [
"import tiktoken\n",
"\n",
"\n",
"def num_tokens_from_message(message):\n",
"    \"\"\"Count tokens in a [system, user] message pair as the API sees it.\"\"\"\n",
"    encoding = tiktoken.encoding_for_model(\"gpt-4\")\n",
"    # The constant 11 approximates the chat-format overhead tokens the API\n",
"    # wraps around the two messages and the assistant reply priming\n",
"    return len(encoding.encode(message[0][\"content\"])) + len(encoding.encode(message[1][\"content\"])) + 11"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "e90c113b",
"metadata": {},
"source": [
"num_tokens_from_message(messages)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "b42067ab",
"metadata": {},
"source": [
"prompts_num_tokens = {}\n",
"for k, v in prompts.items():\n",
" prompts_num_tokens[k] = num_tokens_from_message(v)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "d80f78b7",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"\n",
"print(\"mean: \", np.mean(list(prompts_num_tokens.values())))\n",
"print(\"std: \", np.std(list(prompts_num_tokens.values())))\n",
"print(\"min: \", np.min(list(prompts_num_tokens.values())))\n",
"print(\"max: \", np.max(list(prompts_num_tokens.values())))\n",
"print(\"25th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 25))\n",
"print(\"50th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 50))\n",
"print(\"75th Quantile: \", np.percentile(list(prompts_num_tokens.values()), 75))"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "40ea7c46",
"metadata": {},
"source": [
"max_response_tokens = 256  # cap on the generated answer, matching the system prompt\n",
"token_limit = 128000  # assumed context window of the gpt-4 deployment (128k)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "b99ad70e",
"metadata": {},
"source": [
"import copy\n",
"\n",
"\n",
"def trim_message(message):\n",
" trimmed_message = copy.deepcopy(message)\n",
" encoding = tiktoken.encoding_for_model(\"gpt-4\")\n",
" system_tokens = len(encoding.encode(message[0][\"content\"]))\n",
" user_tokens = len(encoding.encode(message[1][\"content\"]))\n",
" \n",
" # If the total tokens are within the limit, no trimming is needed\n",
" if system_tokens + user_tokens + 11 + max_response_tokens <= token_limit:\n",
" return trimmed_message\n",
" \n",
" # Otherwise, trim the user message content\n",
" available_tokens = token_limit - system_tokens - 11 - max_response_tokens\n",
" trimmed_user_content = encoding.decode(encoding.encode(message[1][\"content\"])[:available_tokens])\n",
" \n",
" # Update the message with the trimmed content\n",
" trimmed_message[1][\"content\"] = trimmed_user_content\n",
" return trimmed_message"
],
"outputs": [],
"execution_count": null
},
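{
"cell_type": "markdown",
"id": "9f01aa20",
"metadata": {},
"source": [
"A minimal sanity check (hypothetical; not in the original notebook): show the token count of the longest prompt before and after trimming. Re-encoding a decoded prefix is assumed to be stable, so treat the numbers as approximate; the count is unchanged if the prompt already fits."
]
},
{
"cell_type": "code",
"id": "9f01aa21",
"metadata": {},
"source": [
"# `prompts_num_tokens` was computed above; trimming should leave room for\n",
"# `max_response_tokens` within `token_limit`.\n",
"longest_key = max(prompts_num_tokens, key=prompts_num_tokens.get)\n",
"trimmed = trim_message(prompts[longest_key])\n",
"print(num_tokens_from_message(prompts[longest_key]), \"->\", num_tokens_from_message(trimmed))"
],
"outputs": [],
"execution_count": null
},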
{
"cell_type": "code",
"id": "c836f4cf",
"metadata": {},
"source": [
"trimmed_prompts = {}\n",
"for k, v in prompts.items():\n",
"    trimmed_v = trim_message(v)\n",
"    if trimmed_v != v:\n",
"        print(f\"{k} is trimmed\")\n",
"    # Reuse the already-computed result instead of trimming twice\n",
"    trimmed_prompts[k] = trimmed_v\n",
"len(trimmed_prompts)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "59cf2588",
"metadata": {},
"source": [
"import asyncio\n",
"from openai import AsyncAzureOpenAI\n",
"\n",
"\n",
"# TODO: Enter your credentials\n",
"async_client = AsyncAzureOpenAI(\n",
" azure_endpoint=\"\",\n",
" api_key=\"\",\n",
" api_version=\"\"\n",
")"
],
"outputs": [],
"execution_count": null
},
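{
"cell_type": "markdown",
"id": "9f01aa30",
"metadata": {},
"source": [
"An alternative sketch (an assumption, not from the original notebook): the Azure OpenAI SDK conventionally reads the environment variables below, which avoids committing keys to the notebook."
]
},
{
"cell_type": "code",
"id": "9f01aa31",
"metadata": {},
"source": [
"# Hypothetical alternative to hard-coding credentials: read them from the\n",
"# environment. Only rebuilds the client if all three variables are set.\n",
"env_keys = (\"AZURE_OPENAI_ENDPOINT\", \"AZURE_OPENAI_API_KEY\", \"OPENAI_API_VERSION\")\n",
"if all(k in os.environ for k in env_keys):\n",
"    async_client = AsyncAzureOpenAI(\n",
"        azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
"        api_key=os.environ[\"AZURE_OPENAI_API_KEY\"],\n",
"        api_version=os.environ[\"OPENAI_API_VERSION\"],\n",
"    )"
],
"outputs": [],
"execution_count": null
},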
{
"cell_type": "code",
"id": "6d3b1c82",
"metadata": {},
"source": [
"async def generate_chat_response(async_client, prompt):\n",
"    chat_params = {\n",
"        \"model\": \"gpt-4\",\n",
"        \"messages\": prompt,\n",
"        \"max_tokens\": max_response_tokens,\n",
"        \"temperature\": 0.0,\n",
"    }\n",
"    try:\n",
"        response = await async_client.chat.completions.create(**chat_params)\n",
"    except Exception as e:\n",
"        print(f\"Error in generate_chat_response: {e}\")\n",
"        print(\"Sleeping for 10s...\")\n",
"        # Use asyncio.sleep, not time.sleep: a blocking sleep would stall\n",
"        # every concurrent request on the event loop\n",
"        await asyncio.sleep(10)\n",
"        return -1\n",
"    return response.choices[0].message.content"
],
"outputs": [],
"execution_count": null
},
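{
"cell_type": "markdown",
"id": "9f01aa40",
"metadata": {},
"source": [
"Since `generate_chat_response` returns -1 on failure, callers can retry. A minimal per-call retry sketch with exponential backoff (an addition; the original pipeline instead retries failed prompts in later rounds, see `process_prompts_in_batches` below):"
]
},
{
"cell_type": "code",
"id": "9f01aa41",
"metadata": {},
"source": [
"# Hypothetical retry wrapper: retry up to `max_retries` times, backing off\n",
"# 1s, 2s, 4s, ... on top of the 10s sleep inside the failed call.\n",
"async def generate_with_retries(async_client, prompt, max_retries=3):\n",
"    for attempt in range(max_retries):\n",
"        response = await generate_chat_response(async_client, prompt)\n",
"        if response != -1:\n",
"            return response\n",
"        await asyncio.sleep(2 ** attempt)\n",
"    return -1"
],
"outputs": [],
"execution_count": null
},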
{
"cell_type": "code",
"id": "b4ebea20",
"metadata": {},
"source": [
"async def process_prompts(prompts):\n",
"    # Gather all the futures together and wait for them to complete\n",
"    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))\n",
"    return responses"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "aae93763",
"metadata": {},
"source": [
"def chunk_list(lst, chunk_size):\n",
" \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
" for i in range(0, len(lst), chunk_size):\n",
" yield lst[i:i + chunk_size]"
],
"outputs": [],
"execution_count": null
},
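{
"cell_type": "markdown",
"id": "9f01aa50",
"metadata": {},
"source": [
"A toy illustration of `chunk_list` (added for clarity): 7 items in chunks of 3 yield chunks of sizes 3, 3, and 1."
]
},
{
"cell_type": "code",
"id": "9f01aa51",
"metadata": {},
"source": [
"# Expected output: [[0, 1, 2], [3, 4, 5], [6]]\n",
"list(chunk_list(list(range(7)), 3))"
],
"outputs": [],
"execution_count": null
},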
{
"cell_type": "code",
"id": "d4ff8a13",
"metadata": {},
"source": [
"from tqdm.asyncio import tqdm\n",
"\n",
"\n",
"async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
"    \"\"\"Run prompts in batches; prompts that fail are retried in later rounds.\"\"\"\n",
"    all_responses = {}\n",
"\n",
"    for i in range(repeat):\n",
"        print(f\"round {i}\")\n",
"        prev_n_responses = len(all_responses)\n",
"\n",
"        # Only prompts that have not been answered in an earlier round\n",
"        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
"\n",
"        # Chunk the prompts into batches\n",
"        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
"\n",
"        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
"            batch_v = [prompts[k] for k in batch_k]\n",
"            responses = await process_prompts(batch_v)\n",
"            # Failed calls return -1; keep only string responses\n",
"            all_responses |= {k: v for k, v in zip(batch_k, responses) if isinstance(v, str)}\n",
"        print(f\"get {len(all_responses) - prev_n_responses} new responses\")\n",
"\n",
"    return all_responses"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "code",
"id": "dfac8357",
"metadata": {},
"source": [
"# Choose an appropriate batch size\n",
"batch_size = 10  # Adjust based on your system and API rate limits\n",
"\n",
"# Jupyter runs an asyncio event loop, so top-level await works here\n",
"responses = await process_prompts_in_batches(trimmed_prompts, batch_size)\n",
"print(f\"Processed {len(responses)} responses\")"
],
"outputs": [],
"execution_count": null
},
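{
"cell_type": "markdown",
"id": "9f01aa60",
"metadata": {},
"source": [
"An optional check (not in the original flow): list the prompts that still have no response after all rounds; these receive an empty answer when the results are written out below."
]
},
{
"cell_type": "code",
"id": "9f01aa61",
"metadata": {},
"source": [
"# Keys present in the input but absent from the responses dict\n",
"missing = set(trimmed_prompts) - set(responses)\n",
"print(f\"{len(missing)} prompts without a response\")\n",
"sorted(missing)[:5]"
],
"outputs": [],
"execution_count": null
},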
{
"cell_type": "code",
"id": "7e65eb22",
"metadata": {},
"source": [
"import json\n",
"\n",
"\n",
"# Ensure the output directory exists before writing (create_directory was\n",
"# imported from src.utils at the top of the notebook)\n",
"create_directory(os.path.join(model_path, \"gpt4/qa_output\"))\n",
"\n",
"with open(os.path.join(model_path, \"gpt4/qa_output/answer.jsonl\"), \"w\") as file:\n",
"    for _, data in qa.iterrows():\n",
"        # Prompts that never got a response are written with an empty answer\n",
"        a_hat = responses.get((data.source, data.hadm_id), \"\")\n",
"        json_string = json.dumps({\"hadm_id\": data.hadm_id, \"q\": data.q, \"a\": data.a, \"a_hat\": a_hat, \"source\": data.source})\n",
"        file.write(json_string + \"\\n\")"
],
"outputs": [],
"execution_count": null
},
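{
"cell_type": "markdown",
"id": "9f01aa70",
"metadata": {},
"source": [
"An optional read-back check (an addition): parse the JSONL file again and count how many records carry a non-empty predicted answer."
]
},
{
"cell_type": "code",
"id": "9f01aa71",
"metadata": {},
"source": [
"# Re-read the JSONL output; empty a_hat strings mark unanswered prompts\n",
"with open(os.path.join(model_path, \"gpt4/qa_output/answer.jsonl\")) as file:\n",
"    records = [json.loads(line) for line in file]\n",
"print(len(records), \"records,\", sum(1 for r in records if r[\"a_hat\"]), \"with answers\")"
],
"outputs": [],
"execution_count": null
},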
{
"cell_type": "code",
"id": "e4424b6a",
"metadata": {},
"source": [],
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "llm",
"language": "python",
"name": "llm"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}