src/preprocess/04_paraphrase_qa_event.ipynb

{
 "cells": [
  {
   "cell_type": "code",
   "id": "bf6469fe",
   "metadata": {},
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "src_path = os.path.abspath('../..')\n",
    "print(src_path)\n",
    "sys.path.append(src_path)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fcd74d2c",
   "metadata": {},
   "source": [
    "from src.utils import create_directory, raw_data_path, processed_data_path, set_seed"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4cae42b3",
   "metadata": {},
   "source": [
    "set_seed(seed=42)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "57e86fc8",
   "metadata": {},
   "source": [
    "import pandas as pd"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "d5a4a2f7",
   "metadata": {},
   "source": "output_path = os.path.join(processed_data_path, \"mimic4\")",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "93be6ec0",
   "metadata": {},
   "source": [
    "from pandarallel import pandarallel\n",
    "\n",
    "pandarallel.initialize()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "3b1486b0",
   "metadata": {},
   "source": [
    "## qa_event"
   ]
  },
  {
   "cell_type": "code",
   "id": "6017d60f",
   "metadata": {},
   "source": [
    "qa_event_selected = pd.read_json(os.path.join(output_path, \"qa_event_selected.jsonl\"), lines = True)\n",
    "qa_event_selected"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4b62656f",
   "metadata": {},
   "source": "qa_event = qa_event_selected",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "aa872069",
   "metadata": {},
   "source": [
    "qa_event.event_type.value_counts()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7e7874ff",
   "metadata": {},
   "source": [
    "qa_event.hadm_id.nunique()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b25b515b",
   "metadata": {},
   "source": [
    "qa_event['count'] = qa_event.groupby('hadm_id').cumcount()\n",
    "qa_event['hadm_id'] = qa_event['hadm_id'].astype(str) + '_' + qa_event['count'].astype(str)\n",
    "qa_event = qa_event.drop(columns=['count'])\n",
    "qa_event"
   ],
   "outputs": [],
   "execution_count": null
  },
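  {
   "cell_type": "markdown",
   "id": "c1d2e3f4",
   "metadata": {},
   "source": [
    "The cell above makes `hadm_id` unique by appending a per-admission running index from `groupby().cumcount()`. A minimal sketch on a toy DataFrame (the values are made up for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "id": "d2e3f4a5",
   "metadata": {},
   "source": [
    "# Illustration only: cumcount() numbers rows within each hadm_id group\n",
    "# 0, 1, 2, ..., so duplicates become \"100_0\", \"100_1\", etc.\n",
    "toy = pd.DataFrame({\"hadm_id\": [100, 100, 200]})\n",
    "toy[\"count\"] = toy.groupby(\"hadm_id\").cumcount()\n",
    "toy[\"hadm_id\"] = toy[\"hadm_id\"].astype(str) + \"_\" + toy[\"count\"].astype(str)\n",
    "toy.drop(columns=[\"count\"])"
   ],
   "outputs": [],
   "execution_count": null
  },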
  {
   "cell_type": "code",
   "id": "f0e368a0",
   "metadata": {},
   "source": [
    "qa_event.hadm_id.nunique()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "id": "5ae96309",
   "metadata": {},
   "source": [
    "### GPT"
   ]
  },
  {
   "cell_type": "code",
   "id": "a50e9b87",
   "metadata": {},
   "source": [
    "system_content = \"\"\"You are an AI assistant with expertise in medical knowledge. \n",
    "\n",
    "Your input consists of a question-answer pair created using predefined rules. \n",
    "\n",
    "Your primary task is to rephrase both the question and the answer to introduce variety in the wording while preserving their original meanings.\n",
    "\n",
    "Objective:\n",
    "1. Rewrite the question and answer using language that mimics how a physician might phrase them in a real-world setting.\n",
    "2. Ensure the paraphrased text is grammatically correct.\n",
    "3. Adjust capitalization as needed.\n",
    "4. Maintain the original intent and meaning of the question-answer pair.\n",
    "5. Format your response as follows:\n",
    "- Question: [Your paraphrased question]\n",
    "- Answer: [Your paraphrased answer]\n",
    "6. Aim for brevity in both the question and answer.\"\"\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b94afc0f",
   "metadata": {},
   "source": [
    "def wrap_user_content(row):\n",
    "    return f\"Question: {row.q}\\nAnswer: {row.a}\""
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a86422e3",
   "metadata": {},
   "source": [
    "print(wrap_user_content(qa_event.iloc[0]))"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "3be44396",
   "metadata": {},
   "source": [
    "prompts = {}\n",
    "for _, row in qa_event.iterrows():\n",
    "    messages = [{\"role\": \"system\", \"content\": system_content},\n",
    "                {\"role\": \"user\", \"content\": wrap_user_content(row)}]\n",
    "    prompts[row.hadm_id] = messages\n",
    "len(prompts)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b8c9366c",
   "metadata": {},
   "source": [
    "prompts[\"24903681_5\"]"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "488ed07b",
   "metadata": {},
   "source": [
    "max_response_tokens = 256"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "6a21023c",
   "metadata": {},
   "source": [
    "import asyncio\n",
    "from openai import AsyncAzureOpenAI\n",
    "\n",
    "\n",
    "# TODO: Enter your credentials\n",
    "async_client = AsyncAzureOpenAI(\n",
    "    azure_endpoint=\"\",\n",
    "    api_key=\"\",\n",
    "    api_version=\"\"\n",
    ")"
   ],
   "outputs": [],
   "execution_count": null
  },
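  {
   "cell_type": "markdown",
   "id": "e3f4a5b6",
   "metadata": {},
   "source": [
    "As an alternative to hard-coding credentials, they can be read from environment variables. A minimal sketch, assuming the variables below are set; the names are a common convention, not something this notebook requires:"
   ]
  },
  {
   "cell_type": "code",
   "id": "f4a5b6c7",
   "metadata": {},
   "source": [
    "# Hypothetical alternative: pull credentials from the environment\n",
    "# (the variable names here are illustrative assumptions).\n",
    "endpoint = os.environ.get(\"AZURE_OPENAI_ENDPOINT\", \"\")\n",
    "api_key = os.environ.get(\"AZURE_OPENAI_API_KEY\", \"\")\n",
    "api_version = os.environ.get(\"OPENAI_API_VERSION\", \"\")\n",
    "if endpoint and api_key and api_version:\n",
    "    async_client = AsyncAzureOpenAI(\n",
    "        azure_endpoint=endpoint,\n",
    "        api_key=api_key,\n",
    "        api_version=api_version,\n",
    "    )"
   ],
   "outputs": [],
   "execution_count": null
  },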
  {
   "cell_type": "code",
   "id": "5c77f0e6",
   "metadata": {},
   "source": [
    "async def generate_chat_response(async_client, prompt):\n",
    "    chat_params = {\n",
    "        \"model\": \"gpt-3.5-turbo\",\n",
    "        \"messages\": prompt,\n",
    "        \"max_tokens\": max_response_tokens,\n",
    "        \"temperature\": 0.0,\n",
    "    }\n",
    "    try:\n",
    "        response = await async_client.chat.completions.create(**chat_params)\n",
    "    except Exception as e:\n",
    "        print(f\"Error in call_async: {e}\")\n",
    "        time.sleep(10)\n",
    "        print(f\"Sleep for 10s...\")\n",
    "        return -1\n",
    "    return response.choices[0].message.content"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "4aee59b7",
   "metadata": {},
   "source": [
    "import time\n",
    "\n",
    "\n",
    "async def process_prompts(prompts):\n",
    "    # Gather all the futures together and wait for them to complete\n",
    "    responses = await asyncio.gather(*(generate_chat_response(async_client, prompt) for prompt in prompts))        \n",
    "    return responses"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "7ed1aa1f",
   "metadata": {},
   "source": [
    "prompts_subset = {k: prompts[k] for k in list(prompts.keys())[:10]}\n",
    "len(prompts_subset)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "e352d5df",
   "metadata": {},
   "source": [
    "def chunk_list(lst, chunk_size):\n",
    "    \"\"\"Yield successive chunk_size chunks from lst.\"\"\"\n",
    "    for i in range(0, len(lst), chunk_size):\n",
    "        yield lst[i:i + chunk_size]"
   ],
   "outputs": [],
   "execution_count": null
  },
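  {
   "cell_type": "markdown",
   "id": "a5b6c7d8",
   "metadata": {},
   "source": [
    "A quick illustration of `chunk_list` on a toy list: 7 items with `chunk_size=3` yield chunks of sizes 3, 3, and 1."
   ]
  },
  {
   "cell_type": "code",
   "id": "b6c7d8e9",
   "metadata": {},
   "source": [
    "# Toy example: expected output is [[0, 1, 2], [3, 4, 5], [6]]\n",
    "list(chunk_list(list(range(7)), 3))"
   ],
   "outputs": [],
   "execution_count": null
  },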
  {
   "cell_type": "code",
   "id": "c034d5e7",
   "metadata": {},
   "source": [
    "from tqdm.asyncio import tqdm\n",
    "\n",
    "\n",
    "async def process_prompts_in_batches(prompts, batch_size, repeat=3):\n",
    "    all_responses = {}\n",
    "    \n",
    "    for i in range(repeat):\n",
    "        \n",
    "        print(f\"round {i}\")\n",
    "        prev_n_responses = len(all_responses)\n",
    "        \n",
    "        prompts_k = [k for k in prompts.keys() if k not in all_responses]\n",
    "\n",
    "        # Chunk the prompts into batches\n",
    "        prompt_k_batches = list(chunk_list(prompts_k, batch_size))\n",
    "\n",
    "        for batch_k in tqdm(prompt_k_batches, desc=\"Processing Batches\"):\n",
    "            batch_v = [prompts[k] for k in batch_k]\n",
    "            responses = await process_prompts(batch_v)\n",
    "            all_responses |= {k: v for k, v in zip(batch_k, responses) if type(v) is str}\n",
    "        print(f\"get {len(all_responses) - prev_n_responses} new responses\")\n",
    "    \n",
    "    return all_responses"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "80eedf1e",
   "metadata": {},
   "source": [
    "# Choose an appropriate batch size\n",
    "batch_size = 10  # Adjust based on your system and API limits\n",
    "\n",
    "# Assuming we are in an async environment\n",
    "responses = await process_prompts_in_batches(prompts_subset, batch_size)\n",
    "print(f\"Processed {len(responses)} responses\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "ee178c00",
   "metadata": {},
   "source": [
    "responses"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "42ca7399",
   "metadata": {},
   "source": [
    "def split_dict_equally(input_dict, chunks=5):\n",
    "    # Calculate the size of each chunk\n",
    "    chunk_size = len(input_dict) // chunks\n",
    "    # Calculate how many items will be in the last chunk\n",
    "    last_chunk_size = chunk_size + (len(input_dict) % chunks)\n",
    "    \n",
    "    # An iterator over the items of the original dictionary\n",
    "    it = iter(input_dict)\n",
    "    \n",
    "    # This will store the list of smaller dictionaries\n",
    "    result = []\n",
    "\n",
    "    for i in range(chunks):\n",
    "        # Use a dictionary comprehension to create a smaller dictionary\n",
    "        # The last chunk will take the remaining items\n",
    "        if i < chunks - 1:\n",
    "            part_dict = {k: input_dict[k] for k in (next(it) for _ in range(chunk_size))}\n",
    "        else:\n",
    "            part_dict = {k: input_dict[k] for k in (next(it) for _ in range(last_chunk_size))}\n",
    "        result.append(part_dict)\n",
    "    \n",
    "    return result"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "909d4646",
   "metadata": {},
   "source": [
    "split_prompts = split_dict_equally(prompts, 10)\n",
    "len(split_prompts)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "b8537e61",
   "metadata": {},
   "source": [
    "sum([len(i) for i in split_prompts])"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a08d474e",
   "metadata": {},
   "source": [
    "import re\n",
    "\n",
    "\n",
    "def split_qa(qa, verbose=False):\n",
    "    if verbose:\n",
    "        print(qa)\n",
    "    pattern = r\"-?\\s*Question: (.*)\\s*-?\\s*Answer: (.*)\"\n",
    "    match = re.search(pattern, qa)\n",
    "    question = match.group(1)\n",
    "    answer = match.group(2)\n",
    "    if verbose:\n",
    "        print(\"Question:\", question)\n",
    "        print(\"Answer:\", answer)\n",
    "    return question, answer"
   ],
   "outputs": [],
   "execution_count": null
  },
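  {
   "cell_type": "markdown",
   "id": "c7d8e9f0",
   "metadata": {},
   "source": [
    "A sanity check of the parsing regex on a fabricated response in the \"- Question: ... - Answer: ...\" format the system prompt requests; this exercises only `split_qa`, not the API."
   ]
  },
  {
   "cell_type": "code",
   "id": "d8e9f0a1",
   "metadata": {},
   "source": [
    "# The QA text below is made up purely to test the regex.\n",
    "sample = \"- Question: What was the first medication administered?\\n- Answer: Aspirin.\"\n",
    "split_qa(sample, verbose=True)"
   ],
   "outputs": [],
   "execution_count": null
  },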
  {
   "cell_type": "code",
   "id": "dfdeb01a",
   "metadata": {},
   "source": [
    "hadm_id_to_event_type = qa_event.set_index(\"hadm_id\").to_dict()[\"event_type\"]\n",
    "len(hadm_id_to_event_type)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "1ddf3bfe",
   "metadata": {},
   "source": [
    "import pickle\n",
    "\n",
    "with open(os.path.join(output_path, \"split_prompts_cache.pkl\"), 'wb') as file:\n",
    "    pickle.dump(split_prompts, file)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "0f674207",
   "metadata": {},
   "source": [
    "import json\n",
    "\n",
    "for chunk_i, chunk_prompts in enumerate(split_prompts):\n",
    "    print(f\"Processing chunk {chunk_i} with {len(chunk_prompts)} prompts\")\n",
    "    \n",
    "    # Choose an appropriate batch size\n",
    "    batch_size = 10  # Adjust based on your system and API limits\n",
    "\n",
    "    # Assuming we are in an async environment\n",
    "    responses = await process_prompts_in_batches(chunk_prompts, batch_size)\n",
    "    print(f\"Processed {len(responses)} responses\")\n",
    "    \n",
    "    responses_split = {}\n",
    "    for hadm_id, qa in responses.items():\n",
    "        responses_split[hadm_id] = split_qa(qa)\n",
    "        \n",
    "    print(f\"After split: {len(responses_split)}\")\n",
    "    \n",
    "    filename = f\"qa_event_{chunk_i}.jsonl\"\n",
    "    print(f\"filename: {filename}\")\n",
    "    \n",
    "    with open(os.path.join(output_path, filename), \"w\") as file:\n",
    "        for hadm_id, (q, a) in responses_split.items():\n",
    "            # Convert the dictionary to a JSON string and write it to the file\n",
    "            actual_hadm_id = int(hadm_id.split(\"_\")[0])\n",
    "            json_string = json.dumps({\"hadm_id\": actual_hadm_id, \"q\": q, \"a\": a, \"event_type\": hadm_id_to_event_type[hadm_id]})\n",
    "            file.write(json_string + '\\n')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "fe98f785",
   "metadata": {},
   "source": [
    "print(\"done\")"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "a0dbfd5d",
   "metadata": {},
   "source": [
    "input_files = [f\"qa_event_{i}.jsonl\" for i in range(10)]\n",
    "output_file = f\"qa_event.jsonl\"\n",
    "\n",
    "# Open the output file in write mode\n",
    "with open(os.path.join(output_path, output_file), 'w') as outfile:\n",
    "    # Iterate over each input file\n",
    "    for input_file in input_files:\n",
    "        # Open the input file in read mode\n",
    "        with open(os.path.join(output_path, input_file), 'r') as infile:\n",
    "            # Read each line (each JSON object) in the input file\n",
    "            for line in infile:\n",
    "                # Write the JSON object to the output file\n",
    "                outfile.write(line)"
   ],
   "outputs": [],
   "execution_count": null
  },
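  {
   "cell_type": "markdown",
   "id": "e9f0a1b2",
   "metadata": {},
   "source": [
    "A lightweight verification sketch, assuming the merge above completed: read the combined file back and spot-check its size."
   ]
  },
  {
   "cell_type": "code",
   "id": "f0a1b2c3",
   "metadata": {},
   "source": [
    "# Read the merged JSONL back and inspect its shape and first rows.\n",
    "qa_event_merged = pd.read_json(os.path.join(output_path, output_file), lines=True)\n",
    "print(qa_event_merged.shape)\n",
    "qa_event_merged.head()"
   ],
   "outputs": [],
   "execution_count": null
  },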
  {
   "cell_type": "code",
   "id": "de4c79d0",
   "metadata": {},
   "source": [],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch20",
   "language": "python",
   "name": "pytorch20"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}