[7e250a]: / src / gpt / gpt.ipynb

Download this file

801 lines (800 with data), 98.7 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from openai import OpenAI\n",
    "import tiktoken\n",
    "\n",
    "from pytrial.data.demo_data import load_trial_outcome_data\n",
    "from pytrial.data.trial_data import TrialOutcomeDatasetBase\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import *\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "import traceback\n",
    "from typing import Iterable\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "from datetime import datetime\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_length(model, text):\n",
    "    \"\"\"Return how many tokens `text` occupies under `model`'s tiktoken encoding.\n",
    "\n",
    "    Only the two models priced by get_pricing are accepted; any other name\n",
    "    fails the assertion.\n",
    "    \"\"\"\n",
    "    assert model in (\"gpt-3.5-turbo\", \"gpt-4-0125-preview\")\n",
    "    token_ids = tiktoken.encoding_for_model(model).encode(text)\n",
    "    return len(token_ids)\n",
    "\n",
    "def get_pricing(direction, model, text):\n",
    "    \"\"\"Estimate the cost of `text` as prompt input (direction 'in') or completion output ('out').\n",
    "\n",
    "    Rates are per 1K tokens (presumably USD, taken from the OpenAI price list at the time\n",
    "    of writing -- confirm before relying on them). Token counts come from get_length.\n",
    "    \"\"\"\n",
    "    assert (model in (\"gpt-3.5-turbo\", \"gpt-4-0125-preview\")) and (direction in (\"in\", \"out\"))\n",
    "    # (direction, is_gpt4) -> price per 1K tokens\n",
    "    per_thousand = {\n",
    "        (\"in\", True): 0.01,\n",
    "        (\"in\", False): 0.0005,\n",
    "        (\"out\", True): 0.03,\n",
    "        (\"out\", False): 0.0015,\n",
    "    }\n",
    "    rate = per_thousand[(direction, model == \"gpt-4-0125-preview\")]\n",
    "    return get_length(model, text) * rate / 1e3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load environment variables from a .env file if one exists (python-dotenv);\n",
    "# OPENAI_API_KEY is expected to come from there or from the shell.\n",
    "load_dotenv()\n",
    "# Explicit check instead of `assert`: it still runs under `python -O` and says what is missing.\n",
    "if os.getenv(\"OPENAI_API_KEY\") is None:\n",
    "    raise RuntimeError(\"OPENAI_API_KEY is not set; export it or add it to a .env file.\")\n",
    "\n",
    "client = OpenAI()\n",
    "\n",
    "def get_assistant_msg(content: str):\n",
    "    \"\"\"Wrap `content` in an assistant-role chat message dict.\"\"\"\n",
    "    return dict(role=\"assistant\", content=content)\n",
    "\n",
    "def pipe_qa_chat(trial: str, base_messages: list, summary_prompt: str, model: str, tool_data: dict, questions = None):\n",
    "    \"\"\"Run the multi-turn prompting pipeline for one serialized trial and return the final answer.\n",
    "\n",
    "    1. Ask the model to explain the trial (``base_messages`` + a trial message).\n",
    "    2. Ask each follow-up question separately against ``base_messages`` + the trial\n",
    "       message only (neither the step-1 answer nor earlier answers are included).\n",
    "    3. Add the step-1 exchange, every question/answer pair and a summary request to\n",
    "       the history, and ask for the prediction with the ``predict_trial_outcome`` tool.\n",
    "\n",
    "    Args:\n",
    "        trial: Serialized trial text embedded in the first user message.\n",
    "        base_messages: Leading chat messages; copied, never mutated.\n",
    "        summary_prompt: Text placed before the final prediction request.\n",
    "        model: Chat model name passed to ``get_response``.\n",
    "        tool_data: ``description``/``parameters`` for the prediction tool.\n",
    "        questions: Iterable of chat message dicts, or ``None`` for no follow-ups.\n",
    "\n",
    "    Returns:\n",
    "        The string returned by ``get_response`` for the final (tool) call.\n",
    "    \"\"\"\n",
    "    trial_message = {\"role\": \"user\",\n",
    "                     \"content\": f\"The trial data is the following.\\n{trial}\\nLet's explain this clinical trial for predicting its success. Let's examine, step by step.\"}\n",
    "\n",
    "    messages = base_messages.copy()\n",
    "    messages.append(trial_message)\n",
    "    history = messages.copy()\n",
    "    response = get_assistant_msg(get_response(history, model, None))\n",
    "    history.append(response)\n",
    "\n",
    "    # `questions or []`: iterating the default of None would raise TypeError.\n",
    "    question_responses = []\n",
    "    for question in questions or []:\n",
    "        question_responses.append(question)\n",
    "        response = get_assistant_msg(get_response(\n",
    "            messages + [question], model, None))\n",
    "        question_responses.append(response)\n",
    "\n",
    "    history += question_responses\n",
    "\n",
    "    summary_message = {\"role\": \"user\",\n",
    "                       \"content\": f\"{summary_prompt} With that in mind, predict the outcome of this clinical trial. Let's examine, step by step.\"}\n",
    "\n",
    "    history.append(summary_message)\n",
    "    return get_response(history, model, tool_data)\n",
    "\n",
    "\n",
    "def get_response(messages: list, model: str, tool_data: dict):\n",
    "    \"\"\"Send `messages` to the chat completions API and return the reply as a string.\n",
    "\n",
    "    With `tool_data`, the model is required to call the single ``predict_trial_outcome``\n",
    "    function, and that call's raw JSON arguments string is returned. Without it, the\n",
    "    assistant's text content is returned.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the completion has no choices or no usable reply.\n",
    "    \"\"\"\n",
    "    if tool_data is not None:\n",
    "        tools = [\n",
    "            {\n",
    "                \"type\": \"function\",\n",
    "                \"function\": {\n",
    "                    \"name\": \"predict_trial_outcome\",\n",
    "                    \"description\": tool_data.get(\"description\"),\n",
    "                    \"parameters\": tool_data.get(\"parameters\")\n",
    "                }\n",
    "            }\n",
    "        ]\n",
    "\n",
    "        completion = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=messages,\n",
    "            tools=tools,\n",
    "            # Force the tool call: with \"auto\" the model may answer in prose, which the\n",
    "            # caller then fails to json.loads().\n",
    "            tool_choice={\"type\": \"function\", \"function\": {\"name\": \"predict_trial_outcome\"}},\n",
    "            temperature=0,\n",
    "        )\n",
    "\n",
    "    else:\n",
    "        # NOTE(review): no temperature is passed here, so these intermediate turns use the\n",
    "        # API's default sampling temperature rather than the temperature=0 of the final call.\n",
    "        completion = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=messages\n",
    "        )\n",
    "\n",
    "    # Explicit check (not `assert`) so an empty choices list is also reported clearly.\n",
    "    if not completion.choices:\n",
    "        raise ValueError(f\"Invalid response structure: {completion}\")\n",
    "\n",
    "    message = completion.choices[0].message\n",
    "    # Treat an empty tool_calls list like None and fall back to the text content.\n",
    "    if message.tool_calls:\n",
    "        response = message.tool_calls[0].function.arguments\n",
    "    else:\n",
    "        response = message.content\n",
    "\n",
    "    if response is None:\n",
    "        raise ValueError(f\"Invalid response structure: {completion}\")\n",
    "\n",
    "    return response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and sample size for this run.\n",
    "model = \"gpt-4-0125-preview\"\n",
    "limit = 30  # trials to evaluate; the data cell samples limit//2 successes + the rest failures\n",
    "\n",
    "# Output location for predictions (errors go to error_log.json in the same folder).\n",
    "results_folder = \"../../results/gpt\"\n",
    "results_file = \"gpt4-phase3-subtasks+clinical_knowledge+augmented_data.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def serialize_row(row: pd.Series) -> str:\n",
    "    \"\"\"Render a row as a single sentence such as ``a is 1, b is unknown.``\n",
    "\n",
    "    Missing scalars (None, NaN, NaT, pd.NA) are written as 'unknown' so the prompt\n",
    "    never contains a bare 'nan' (e.g. for columns a left merge could not fill).\n",
    "    Other values, including lists, are formatted as-is.\n",
    "    \"\"\"\n",
    "    parts = []\n",
    "    for colname, value in row.items():\n",
    "        # pd.isna on a list returns an array, so only scalars are tested for missingness.\n",
    "        missing = pd.api.types.is_scalar(value) and pd.isna(value)\n",
    "        parts.append(f\"{colname} is {'unknown' if missing else value}\")\n",
    "\n",
    "    return \", \".join(parts) + \".\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_nct_data(nct_ids: list[str], timeout: float = 30.0, page_size: int = 200):\n",
    "    \"\"\"Fetch basic metadata for `nct_ids` from the ClinicalTrials.gov v2 studies API.\n",
    "\n",
    "    Args:\n",
    "        nct_ids: Trial identifiers such as 'NCT01308528'.\n",
    "        timeout: Seconds to wait per request, so a stalled connection cannot hang the run.\n",
    "        page_size: Studies requested per page; further pages are followed via\n",
    "            ``nextPageToken`` instead of being silently dropped.\n",
    "\n",
    "    Returns:\n",
    "        One dict per study (see ``_extract_trial``); ``[]`` for empty input;\n",
    "        ``None`` if a request fails or returns a non-200 status.\n",
    "    \"\"\"\n",
    "    if not nct_ids:\n",
    "        # An empty query.id would not restrict the search at all.\n",
    "        return []\n",
    "\n",
    "    url = \"https://clinicaltrials.gov/api/v2/studies\"\n",
    "    headers = {\"accept\": \"application/json\"}\n",
    "    # requests URL-encodes this the way the old hand-built string was encoded\n",
    "    # (quotes -> %22, spaces -> '+'), i.e. \"NCT1\" OR \"NCT2\" ...\n",
    "    params = {\n",
    "        \"query.id\": \" OR \".join(f'\"{nct_id}\"' for nct_id in nct_ids),\n",
    "        \"pageSize\": page_size,\n",
    "    }\n",
    "\n",
    "    trials = []\n",
    "    try:\n",
    "        while True:\n",
    "            response = requests.get(url, headers=headers, params=params, timeout=timeout)\n",
    "            if response.status_code != 200:\n",
    "                print(f\"Request failed with status code: {response.status_code}\")\n",
    "                return None\n",
    "\n",
    "            page = response.json()\n",
    "            trials.extend(_extract_trial(study) for study in page.get(\"studies\", []))\n",
    "\n",
    "            next_token = page.get(\"nextPageToken\")\n",
    "            if not next_token:\n",
    "                break\n",
    "            params[\"pageToken\"] = next_token\n",
    "    except requests.RequestException as e:\n",
    "        print(f\"Request exception occurred: {e}\")\n",
    "        return None\n",
    "\n",
    "    print(f\"Study count: {len(trials)}\")\n",
    "    return trials\n",
    "\n",
    "\n",
    "def _extract_trial(study: dict) -> dict:\n",
    "    \"\"\"Flatten the fields used for prompt augmentation out of one v2 API study record.\"\"\"\n",
    "    protocol = study.get(\"protocolSection\", {})\n",
    "    identification = protocol.get(\"identificationModule\", {})\n",
    "    sponsors = protocol.get(\"sponsorCollaboratorsModule\", {})\n",
    "    description = protocol.get(\"descriptionModule\", {})\n",
    "    design = protocol.get(\"designModule\", {})\n",
    "    contacts_locations = protocol.get(\"contactsLocationsModule\", {})\n",
    "\n",
    "    return {\n",
    "        \"nctId\": identification.get(\"nctId\"),\n",
    "        \"description\": description.get(\"briefSummary\"),\n",
    "        \"lead_sponsor\": sponsors.get(\"leadSponsor\", {}).get(\"name\"),\n",
    "        \"collaborators\": [c.get(\"name\") for c in sponsors.get(\"collaborators\", [])],\n",
    "        \"study_type\": design.get(\"studyType\"),\n",
    "        \"location_count\": len(contacts_locations.get(\"locations\", [])),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Study count: 30\n"
     ]
    }
   ],
   "source": [
    "df = load_trial_outcome_data(phase='III', split='test')['data']\n",
    "\n",
    "test_data = TrialOutcomeDatasetBase(df).data\n",
    "\n",
    "# Balanced sample: limit//2 successes and limit - limit//2 failures, then shuffled.\n",
    "successful_trials = test_data[test_data['label'] == 1].sample(n=limit // 2, random_state=42)\n",
    "failed_trials = test_data[test_data['label'] == 0].sample(n=limit - limit // 2, random_state=42)\n",
    "test_data = pd.concat([successful_trials, failed_trials]).sample(frac=1, random_state=42)\n",
    "\n",
    "# Augment each sampled trial with ClinicalTrials.gov metadata (sponsor, summary, ...).\n",
    "nct_ids = test_data['nctid'].unique().tolist()\n",
    "augmented_data = get_nct_data(nct_ids)\n",
    "if not augmented_data:\n",
    "    # None (request failed) or [] would give a frame without an 'nctId' column,\n",
    "    # and the merge below would fail with an obscure KeyError.\n",
    "    raise RuntimeError(\"Could not fetch trial metadata from ClinicalTrials.gov; see the message above.\")\n",
    "augmented_df = pd.DataFrame(augmented_data)\n",
    "test_data = test_data.merge(augmented_df, left_on='nctid', right_on='nctId', how='left')\n",
    "\n",
    "# Set the ground truth aside, then remove outcome fields and identifiers from the model input.\n",
    "# The merge adds a second ID column, 'nctId'; it must be dropped together with 'nctid',\n",
    "# otherwise the trial ID (presumably hidden so the model cannot look the trial up)\n",
    "# still ends up in the serialized prompt.\n",
    "outcomes = test_data[['nctid', 'label', 'why_stop', 'status']]\n",
    "test_data = test_data.drop(columns=['smiless', 'label', 'why_stop', 'study_first_submitted_date',\n",
    "                                    'nctid', 'nctId', 'status'])\n",
    "\n",
    "serialized_trials = test_data.apply(serialize_row, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>nctid</th>\n",
       "      <th>status</th>\n",
       "      <th>why_stop</th>\n",
       "      <th>label</th>\n",
       "      <th>phase</th>\n",
       "      <th>diseases</th>\n",
       "      <th>icdcodes</th>\n",
       "      <th>drugs</th>\n",
       "      <th>smiless</th>\n",
       "      <th>criteria</th>\n",
       "      <th>title</th>\n",
       "      <th>study_first_submitted_date</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>NCT01308528</td>\n",
       "      <td>completed</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['venous thromboembolism']</td>\n",
       "      <td>[\"['O88.22', 'O88.23', 'O88.211', 'O88.212', '...</td>\n",
       "      <td>['sodium enoxaparin', 'sodium enoxaparin clexa...</td>\n",
       "      <td>['[H][N]([H])([H])[Pt](Cl)(Cl)[N]([H])([H])[H]...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Prophylactic Use of Sodium Enoxaparin for Veno...</td>\n",
       "      <td>2011-03-01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NCT01670552</td>\n",
       "      <td>completed</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['acute and chronic inflammation', 'dyspepsia']</td>\n",
       "      <td>[\"['K30']\"]</td>\n",
       "      <td>['nimesulide + pantoprazole', 'naproxen + esom...</td>\n",
       "      <td>['COC1=C(OC)C(CS(=O)C2=NC3=C(N2)C=C(OC(F)F)C=C...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Evaluation of Two Therapies for the Treatment ...</td>\n",
       "      <td>2012-08-17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NCT01786824</td>\n",
       "      <td>terminated</td>\n",
       "      <td>\\n    patient pathway has become infeasible du...</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['acute kidney injury', 'renal insufficiency']</td>\n",
       "      <td>[\"['N26.2', 'Q63.0', 'Q63.2', 'Z52.4', 'I75.81...</td>\n",
       "      <td>['hydration strategy using saline', 'hydration...</td>\n",
       "      <td>['NS(=O)(=O)C1=C(Cl)C=C(NCC2=CC=CO2)C(=C1)C(O)...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Preventing Contrast-induced Nephropathy: Evalu...</td>\n",
       "      <td>2013-02-06</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NCT01877668</td>\n",
       "      <td>completed</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['psoriatic arthritis']</td>\n",
       "      <td>[\"['L40.52']\"]</td>\n",
       "      <td>['tofacitinib 5 mg bid', 'tofacitinib 10 mg bi...</td>\n",
       "      <td>['[H][C@@]1(C)CCN(C[C@]1([H])N(C)C1=NC=NC2=C1C...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Efficacy And Safety Of Tofacitinib In Psoriati...</td>\n",
       "      <td>2013-06-12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NCT01887717</td>\n",
       "      <td>terminated</td>\n",
       "      <td>\\n    poor accrual.\\n</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['hepatocellular carcinoma']</td>\n",
       "      <td>[\"['C22.0', 'C4A.9', 'C7B.1', 'C4A.0', 'C4A.31...</td>\n",
       "      <td>['sorafenib']</td>\n",
       "      <td>['CNC(=O)C1=NC=CC(OC2=CC=C(NC(=O)NC3=CC(=C(Cl)...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Efficacy Evaluation of TheraSphere to Treat In...</td>\n",
       "      <td>2013-06-03</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1141</th>\n",
       "      <td>NCT03716050</td>\n",
       "      <td>terminated</td>\n",
       "      <td>\\n    pi decision due to slow accrual\\n</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 2/phase 3</td>\n",
       "      <td>['perfusion; complications']</td>\n",
       "      <td>[\"['A36.89', 'B01.89', 'B02.9', 'B05.89', 'B06...</td>\n",
       "      <td>['nitroglycerin']</td>\n",
       "      <td>['COC(=O)C1=C(C)NC(C)=C(C1C1=CC(=CC=C1)[N+]([O...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>The Effect of Nitroglycerin Ointment, Fluoresc...</td>\n",
       "      <td>2018-10-19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1142</th>\n",
       "      <td>NCT03733301</td>\n",
       "      <td>completed</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['atopic dermatitis']</td>\n",
       "      <td>[\"['L20.89', 'L20.9']\"]</td>\n",
       "      <td>['baricitinib', 'topical corticosteroid', 'pla...</td>\n",
       "      <td>['CCS(=O)(=O)N1CC(CC#N)(C1)N1C=C(C=N1)C1=C2C=C...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>A Study of Baricitinib (LY3009104) in Combinat...</td>\n",
       "      <td>2018-10-30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1143</th>\n",
       "      <td>NCT03907072</td>\n",
       "      <td>terminated</td>\n",
       "      <td>\\n    lack of efficacy\\n</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 2/phase 3</td>\n",
       "      <td>['duchenne muscular dystrophy']</td>\n",
       "      <td>[\"['G71.01']\"]</td>\n",
       "      <td>['wve-210201 (suvodirsen)', 'placebo']</td>\n",
       "      <td>['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          1. ...</td>\n",
       "      <td>Efficacy and Safety Study of WVE-210201 (Suvod...</td>\n",
       "      <td>2019-04-05</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1144</th>\n",
       "      <td>NCT04060888</td>\n",
       "      <td>withdrawn</td>\n",
       "      <td>\\n    pre-planned ia (global study) showed lac...</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['lupus erythematosus, systemic']</td>\n",
       "      <td>[\"['M32.9', 'M32.0', 'M32.11', 'M32.12', 'M32....</td>\n",
       "      <td>['ustekinumab (approximately 6 mg/kg)', 'ustek...</td>\n",
       "      <td>['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>A Study of Ustekinumab in Chinese Participants...</td>\n",
       "      <td>2019-08-16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1145</th>\n",
       "      <td>NCT04341727</td>\n",
       "      <td>suspended</td>\n",
       "      <td>\\n    dsmb recommended study suspension slow a...</td>\n",
       "      <td>0</td>\n",
       "      <td>phase 3</td>\n",
       "      <td>['coronavirus infection']</td>\n",
       "      <td>[\"['B34.2']\"]</td>\n",
       "      <td>['hydroxychloroquine sulfate', 'azithromycin',...</td>\n",
       "      <td>['CCN(CCO)CCCC(C)NC1=C2C=CC(Cl)=CC2=NC=C1', '[...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\n\\n          -  ...</td>\n",
       "      <td>Hydroxychloroquine,Hydroxychloroquine,Azithrom...</td>\n",
       "      <td>2020-04-04</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1146 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "            nctid      status  \\\n",
       "0     NCT01308528   completed   \n",
       "1     NCT01670552   completed   \n",
       "2     NCT01786824  terminated   \n",
       "3     NCT01877668   completed   \n",
       "4     NCT01887717  terminated   \n",
       "...           ...         ...   \n",
       "1141  NCT03716050  terminated   \n",
       "1142  NCT03733301   completed   \n",
       "1143  NCT03907072  terminated   \n",
       "1144  NCT04060888   withdrawn   \n",
       "1145  NCT04341727   suspended   \n",
       "\n",
       "                                               why_stop  label  \\\n",
       "0                                                   NaN      1   \n",
       "1                                                   NaN      1   \n",
       "2     \\n    patient pathway has become infeasible du...      0   \n",
       "3                                                   NaN      1   \n",
       "4                               \\n    poor accrual.\\n        0   \n",
       "...                                                 ...    ...   \n",
       "1141          \\n    pi decision due to slow accrual\\n        0   \n",
       "1142                                                NaN      1   \n",
       "1143                         \\n    lack of efficacy\\n        0   \n",
       "1144  \\n    pre-planned ia (global study) showed lac...      0   \n",
       "1145  \\n    dsmb recommended study suspension slow a...      0   \n",
       "\n",
       "                phase                                         diseases  \\\n",
       "0             phase 3                       ['venous thromboembolism']   \n",
       "1             phase 3  ['acute and chronic inflammation', 'dyspepsia']   \n",
       "2             phase 3   ['acute kidney injury', 'renal insufficiency']   \n",
       "3             phase 3                          ['psoriatic arthritis']   \n",
       "4             phase 3                     ['hepatocellular carcinoma']   \n",
       "...               ...                                              ...   \n",
       "1141  phase 2/phase 3                     ['perfusion; complications']   \n",
       "1142          phase 3                            ['atopic dermatitis']   \n",
       "1143  phase 2/phase 3                  ['duchenne muscular dystrophy']   \n",
       "1144          phase 3                ['lupus erythematosus, systemic']   \n",
       "1145          phase 3                        ['coronavirus infection']   \n",
       "\n",
       "                                               icdcodes  \\\n",
       "0     [\"['O88.22', 'O88.23', 'O88.211', 'O88.212', '...   \n",
       "1                                           [\"['K30']\"]   \n",
       "2     [\"['N26.2', 'Q63.0', 'Q63.2', 'Z52.4', 'I75.81...   \n",
       "3                                        [\"['L40.52']\"]   \n",
       "4     [\"['C22.0', 'C4A.9', 'C7B.1', 'C4A.0', 'C4A.31...   \n",
       "...                                                 ...   \n",
       "1141  [\"['A36.89', 'B01.89', 'B02.9', 'B05.89', 'B06...   \n",
       "1142                            [\"['L20.89', 'L20.9']\"]   \n",
       "1143                                     [\"['G71.01']\"]   \n",
       "1144  [\"['M32.9', 'M32.0', 'M32.11', 'M32.12', 'M32....   \n",
       "1145                                      [\"['B34.2']\"]   \n",
       "\n",
       "                                                  drugs  \\\n",
       "0     ['sodium enoxaparin', 'sodium enoxaparin clexa...   \n",
       "1     ['nimesulide + pantoprazole', 'naproxen + esom...   \n",
       "2     ['hydration strategy using saline', 'hydration...   \n",
       "3     ['tofacitinib 5 mg bid', 'tofacitinib 10 mg bi...   \n",
       "4                                         ['sorafenib']   \n",
       "...                                                 ...   \n",
       "1141                                  ['nitroglycerin']   \n",
       "1142  ['baricitinib', 'topical corticosteroid', 'pla...   \n",
       "1143             ['wve-210201 (suvodirsen)', 'placebo']   \n",
       "1144  ['ustekinumab (approximately 6 mg/kg)', 'ustek...   \n",
       "1145  ['hydroxychloroquine sulfate', 'azithromycin',...   \n",
       "\n",
       "                                                smiless  \\\n",
       "0     ['[H][N]([H])([H])[Pt](Cl)(Cl)[N]([H])([H])[H]...   \n",
       "1     ['COC1=C(OC)C(CS(=O)C2=NC3=C(N2)C=C(OC(F)F)C=C...   \n",
       "2     ['NS(=O)(=O)C1=C(Cl)C=C(NCC2=CC=CO2)C(=C1)C(O)...   \n",
       "3     ['[H][C@@]1(C)CCN(C[C@]1([H])N(C)C1=NC=NC2=C1C...   \n",
       "4     ['CNC(=O)C1=NC=CC(OC2=CC=C(NC(=O)NC3=CC(=C(Cl)...   \n",
       "...                                                 ...   \n",
       "1141  ['COC(=O)C1=C(C)NC(C)=C(C1C1=CC(=CC=C1)[N+]([O...   \n",
       "1142  ['CCS(=O)(=O)N1CC(CC#N)(C1)N1C=C(C=N1)C1=C2C=C...   \n",
       "1143  ['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2...   \n",
       "1144  ['CN1C(=O)C=C(N2CCC[C@@H](N)C2)N(CC2=C(C=CC=C2...   \n",
       "1145  ['CCN(CCO)CCCC(C)NC1=C2C=CC(Cl)=CC2=NC=C1', '[...   \n",
       "\n",
       "                                               criteria  \\\n",
       "0     \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "1     \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "2     \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "3     \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "4     \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "...                                                 ...   \n",
       "1141  \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "1142  \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "1143  \\n        Inclusion Criteria:\\n\\n          1. ...   \n",
       "1144  \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "1145  \\n        Inclusion Criteria:\\n\\n          -  ...   \n",
       "\n",
       "                                                  title  \\\n",
       "0     Prophylactic Use of Sodium Enoxaparin for Veno...   \n",
       "1     Evaluation of Two Therapies for the Treatment ...   \n",
       "2     Preventing Contrast-induced Nephropathy: Evalu...   \n",
       "3     Efficacy And Safety Of Tofacitinib In Psoriati...   \n",
       "4     Efficacy Evaluation of TheraSphere to Treat In...   \n",
       "...                                                 ...   \n",
       "1141  The Effect of Nitroglycerin Ointment, Fluoresc...   \n",
       "1142  A Study of Baricitinib (LY3009104) in Combinat...   \n",
       "1143  Efficacy and Safety Study of WVE-210201 (Suvod...   \n",
       "1144  A Study of Ustekinumab in Chinese Participants...   \n",
       "1145  Hydroxychloroquine,Hydroxychloroquine,Azithrom...   \n",
       "\n",
       "     study_first_submitted_date  \n",
       "0                    2011-03-01  \n",
       "1                    2012-08-17  \n",
       "2                    2013-02-06  \n",
       "3                    2013-06-12  \n",
       "4                    2013-06-03  \n",
       "...                         ...  \n",
       "1141                 2018-10-19  \n",
       "1142                 2018-10-30  \n",
       "1143                 2019-04-05  \n",
       "1144                 2019-08-16  \n",
       "1145                 2020-04-04  \n",
       "\n",
       "[1146 rows x 12 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview of the raw test split loaded above.\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "30it [35:11, 70.37s/it]\n"
     ]
    }
   ],
   "source": [
    "with open(\"prompts.json\", \"r\") as file:\n",
    "    prompts = json.load(file)\n",
    "\n",
    "# Prompt template used for this run; the fields below feed pipe_qa_chat.\n",
    "prompt = prompts['templates']['subtasks+clinical_knowledge']\n",
    "\n",
    "base_messages = prompt['base_messages']\n",
    "summary_prompt = prompt['summary_prompt']\n",
    "tool_data = prompt['tool_data']\n",
    "questions = prompt['questions']\n",
    "\n",
    "# Reset the results file to an empty JSON list; process_trial appends to it.\n",
    "# Create the folder first so a fresh checkout does not fail with FileNotFoundError.\n",
    "os.makedirs(results_folder, exist_ok=True)\n",
    "with open(f\"{results_folder}/{results_file}\", \"w\") as file:\n",
    "    json.dump([], file)\n",
    "\n",
    "def process_trial(trial, outcome):\n",
    "    \"\"\"Run the pipeline on one trial and append the parsed prediction to the results file.\n",
    "\n",
    "    Args:\n",
    "        trial: Serialized trial text (one element of ``serialized_trials``).\n",
    "        outcome: An ``(index, row)`` pair from ``outcomes.iterrows()``; the row supplies\n",
    "            ``nctid``, ``label`` and ``why_stop``.\n",
    "\n",
    "    Any exception is printed and recorded with ``log_error`` so the loop keeps going.\n",
    "    \"\"\"\n",
    "    _, row = outcome\n",
    "\n",
    "    try:\n",
    "        res = pipe_qa_chat(\n",
    "            trial, base_messages, summary_prompt, model, tool_data, questions)\n",
    "\n",
    "        response = json.loads(res) if isinstance(res, str) else res\n",
    "        response['ground_truth'] = \"success\" if row.get('label') == 1 else \"failure\"\n",
    "        response['nctid'] = row.get('nctid')\n",
    "        # why_stop is NaN for trials that were not stopped. pd.notna also rejects None,\n",
    "        # which a str(...).lower() != \"nan\" test would let through as a null fail_reason.\n",
    "        if pd.notna(row.get('why_stop')):\n",
    "            response[\"fail_reason\"] = row[\"why_stop\"]\n",
    "\n",
    "        # Rewrite the whole list each time so the file stays valid JSON if the run stops midway.\n",
    "        with open(f\"{results_folder}/{results_file}\", \"r+\") as file:\n",
    "            file_data = json.load(file)\n",
    "            file_data.append(response)\n",
    "            file.seek(0)\n",
    "            json.dump(file_data, file, indent=4)\n",
    "            file.truncate()\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error processing trial {row.get('nctid', 'Unknown')}: {e}\")\n",
    "        log_error(e, row)\n",
    "\n",
    "def log_error(exception, outcome, path=None):\n",
    "    \"\"\"Append a record of `exception` for one trial to a JSON error log.\n",
    "\n",
    "    Args:\n",
    "        exception: The exception that was caught.\n",
    "        outcome: Mapping (e.g. a pandas row) whose optional 'nctid' identifies the trial.\n",
    "        path: Log file to update; defaults to ``{results_folder}/error_log.json``.\n",
    "\n",
    "    Call it from inside an ``except`` block so ``traceback.format_exc()`` records\n",
    "    the active stack trace.\n",
    "    \"\"\"\n",
    "    if path is None:\n",
    "        path = f\"{results_folder}/error_log.json\"\n",
    "\n",
    "    error_message = {\n",
    "        \"trial_id\": outcome.get('nctid', 'unknown'),\n",
    "        \"error\": str(exception),\n",
    "        \"stack_trace\": traceback.format_exc(),\n",
    "        \"time\": datetime.now().isoformat()\n",
    "    }\n",
    "\n",
    "    # Read, then rewrite. A single \"a+\" handle does not work: in append mode writes go to\n",
    "    # the end of the file regardless of seek(0) (O_APPEND), so each call tacked another\n",
    "    # JSON array onto the log, the next load failed, and the history was discarded.\n",
    "    try:\n",
    "        with open(path, \"r\") as error_file:\n",
    "            errors = json.load(error_file)\n",
    "    except (FileNotFoundError, json.JSONDecodeError):\n",
    "        errors = []\n",
    "\n",
    "    errors.append(error_message)\n",
    "    with open(path, \"w\") as error_file:\n",
    "        json.dump(errors, error_file, indent=4)\n",
    "\n",
    "# Evaluate the sampled trials; predictions and errors are written to disk as they finish.\n",
    "trials_to_run = serialized_trials.iloc[:limit]\n",
    "for trial, outcome in tqdm(zip(trials_to_run, outcomes.iterrows()),\n",
    "                           total=min(len(trials_to_run), len(outcomes))):\n",
    "    process_trial(trial, outcome)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Acc': 0.6, 'F1': 0.7142857142857143, 'Precision': 0.5555555555555556, 'Recall': 1.0, 'ROC-AUC': 0.6, 'PR-AUC': 0.5555555555555556, 'Specificity': 0.2}\n"
     ]
    }
   ],
   "source": [
    "preds = pd.read_json(f\"{results_folder}/{results_file}\")\n",
    "\n",
    "# Anything other than an explicit 'success' (including missing values) counts as 'failure'.\n",
    "preds['prediction'] = preds['prediction'].where(preds['prediction'] == 'success', 'failure')\n",
    "preds['ground_truth'] = preds['ground_truth'].where(preds['ground_truth'] == 'success', 'failure')\n",
    "\n",
    "# Fixed encoding: success = 1 (positive class), failure = 0. A LabelEncoder fitted on\n",
    "# ground_truth alone assigns codes alphabetically and raises on 'failure' predictions\n",
    "# whenever the evaluated sample contains only successes.\n",
    "y_true = (preds['ground_truth'] == 'success').astype(int).to_numpy()\n",
    "y_pred = (preds['prediction'] == 'success').astype(int).to_numpy()\n",
    "\n",
    "\n",
    "def calculate_metrics(y_true, y_pred):\n",
    "    \"\"\"Compute binary classification metrics for hard 0/1 predictions (1 = positive class).\n",
    "\n",
    "    Returns a dict with keys Acc, F1, Precision, Recall, ROC-AUC, PR-AUC and Specificity.\n",
    "\n",
    "    NOTE: with hard labels, ROC-AUC equals balanced accuracy ((recall + specificity) / 2)\n",
    "    and PR-AUC is a single-threshold value; scores/probabilities are needed for true curves.\n",
    "    \"\"\"\n",
    "    metrics = {}\n",
    "    metrics['Acc'] = accuracy_score(y_true, y_pred)\n",
    "    metrics['F1'] = f1_score(y_true, y_pred)\n",
    "    metrics['Precision'] = precision_score(y_true, y_pred)\n",
    "    metrics['Recall'] = recall_score(y_true, y_pred)\n",
    "    # roc_auc_score raises when y_true holds a single class; report NaN instead.\n",
    "    metrics['ROC-AUC'] = roc_auc_score(y_true, y_pred) if len(set(y_true)) > 1 else float('nan')\n",
    "    metrics['PR-AUC'] = average_precision_score(y_true, y_pred)\n",
    "    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent, so the unpacking works.\n",
    "    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()\n",
    "    # No actual negatives -> specificity is undefined.\n",
    "    metrics['Specificity'] = tn / (tn + fp) if (tn + fp) else float('nan')\n",
    "\n",
    "    return metrics\n",
    "\n",
    "\n",
    "# Print the evaluation summary for the sampled trials.\n",
    "metrics = calculate_metrics(y_true=y_true, y_pred=y_pred)\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Confusion matrix with a fixed class order: 0 = failure, 1 = success.\n",
    "cm = confusion_matrix(y_true, y_pred, labels=[0, 1])\n",
    "class_names = ['failure', 'success']\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',\n",
    "            xticklabels=class_names, yticklabels=class_names, ax=ax)\n",
    "ax.set_xlabel('Predicted')\n",
    "ax.set_ylabel('True')\n",
    "ax.set_title('Confusion Matrix')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): y_pred holds hard 0/1 labels, so this ROC has a single operating point;\n",
    "# model scores or probabilities would be needed for a meaningful curve.\n",
    "fpr, tpr, thresholds = roc_curve(y_true, y_pred)\n",
    "roc_auc = auc(fpr, tpr)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)\n",
    "ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
    "ax.set_xlim([0.0, 1.0])\n",
    "ax.set_ylim([0.0, 1.05])\n",
    "ax.set_xlabel('False Positive Rate')\n",
    "ax.set_ylabel('True Positive Rate')\n",
    "ax.set_title('Receiver Operating Characteristic')\n",
    "ax.legend(loc=\"lower right\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}