[fefe56]: / predict_clinical_trial_outcome_using_XGBoost.ipynb

Download this file

533 lines (532 with data), 64.2 kB

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0df3ee9",
   "metadata": {},
   "source": [
    "# Import libraries and define helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "625ef3b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Directory that holds the precomputed embedding dictionaries.\n",
    "DATA_DIR = 'data'\n",
    "\n",
    "def _load_pickle(filename):\n",
    "    \"\"\"Unpickle and return the object stored in DATA_DIR/filename.\n",
    "\n",
    "    NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.\n",
    "    \"\"\"\n",
    "    with open(os.path.join(DATA_DIR, filename), 'rb') as pickle_file:\n",
    "        return pickle.load(pickle_file)\n",
    "\n",
    "def load_nctid2molecule_embedding_dict():\n",
    "    \"\"\"Return the dict mapping trial NCT IDs to drug-molecule embeddings.\"\"\"\n",
    "    return _load_pickle('nctid2molecule_embedding_dict.pkl')\n",
    "\n",
    "def load_nctid2disease_embedding_dict():\n",
    "    \"\"\"Return the dict mapping trial NCT IDs to disease-indication embeddings.\"\"\"\n",
    "    return _load_pickle('nctid2disease_embedding_dict.pkl')\n",
    "\n",
    "def load_sponsor2embedding_dict():\n",
    "    \"\"\"Return the dict mapping lead-sponsor names to sponsor embeddings.\"\"\"\n",
    "    return _load_pickle('sponsor2embedding_dict.pkl')\n",
    "\n",
    "def load_nctid2protocol_embedding_dict():\n",
    "    \"\"\"Return the dict mapping trial NCT IDs to protocol embeddings.\"\"\"\n",
    "    return _load_pickle('nctid_2_protocol_embedding_dict.pkl')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63da8c43",
   "metadata": {},
   "source": [
    "# Import toy dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5b7ed7c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1028, 14) (146, 14) (295, 14)\n",
      "(1028,) (146,) (295,)\n"
     ]
    }
   ],
   "source": [
    "# Load the toy dataset and partition it by its 'split' column\n",
    "toy_df = pd.read_pickle('data/toy_df_full.pkl')\n",
    "\n",
    "SPLIT_NAMES = ('train', 'valid', 'test')\n",
    "train_df, val_df, test_df = (toy_df[toy_df['split'] == name] for name in SPLIT_NAMES)\n",
    "y_train, y_val, y_test = (frame['label'] for frame in (train_df, val_df, test_df))\n",
    "\n",
    "print(train_df.shape, val_df.shape, test_df.shape)\n",
    "print(y_train.shape, y_val.shape, y_test.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fbc0292d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>nctid</th>\n",
       "      <th>phase</th>\n",
       "      <th>indications</th>\n",
       "      <th>drug_interventions</th>\n",
       "      <th>smiless</th>\n",
       "      <th>criteria</th>\n",
       "      <th>enrollment</th>\n",
       "      <th>lead_sponsor</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>NCT00000378</td>\n",
       "      <td>Phase 4</td>\n",
       "      <td>[Depression, Melancholia]</td>\n",
       "      <td>[Sertraline, Nortriptyline]</td>\n",
       "      <td>['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=C...</td>\n",
       "      <td>\\n        Inclusion Criteria:\\r\\n\\r\\n        -...</td>\n",
       "      <td>110</td>\n",
       "      <td>New York State Psychiatric Institute</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>NCT00001656</td>\n",
       "      <td>Phase 4</td>\n",
       "      <td>[Childhood Schizophrenia, Psychotic Disorder, ...</td>\n",
       "      <td>[Olanzapine, Clozapine]</td>\n",
       "      <td>['[H][C@]12[C@H](OC(=O)C3=CC=CC=C3)[C@]3(O)C[C...</td>\n",
       "      <td>\\n        -  INCLUSION CRITERIA:\\r\\n\\r\\n      ...</td>\n",
       "      <td>25</td>\n",
       "      <td>National Institute of Mental Health (NIMH)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NCT00002863</td>\n",
       "      <td>Phase 1</td>\n",
       "      <td>[Sarcoma]</td>\n",
       "      <td>[chemotherapy]</td>\n",
       "      <td>['NC1=NC(=O)N(C=C1)[C@@H]1O[C@H](CO)[C@@H](O)C...</td>\n",
       "      <td>\\n        DISEASE CHARACTERISTICS: Biopsy-prov...</td>\n",
       "      <td>19</td>\n",
       "      <td>University of Southern California</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NCT00003060</td>\n",
       "      <td>Phase 1</td>\n",
       "      <td>[Melanoma (Skin)]</td>\n",
       "      <td>[busulfan, cyclophosphamide, cyclosporine, met...</td>\n",
       "      <td>['N[C@@H](CCCNC(N)=N)C(O)=O', '[H][C@@]12C[C@H...</td>\n",
       "      <td>\\n        DISEASE CHARACTERISTICS: Biopsy prov...</td>\n",
       "      <td>6</td>\n",
       "      <td>Louisiana State University Health Sciences Cen...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NCT00003567</td>\n",
       "      <td>Phase 1</td>\n",
       "      <td>[Brain and Central Nervous System Tumors, Lymp...</td>\n",
       "      <td>[O6-benzylguanine, carmustine, temozolomide]</td>\n",
       "      <td>['N=C1NC2=C(N=CN2)C(OCC2=CC=CC=C2)=N1', 'ClCCN...</td>\n",
       "      <td>\\n        DISEASE CHARACTERISTICS:\\r\\n\\r\\n    ...</td>\n",
       "      <td>8</td>\n",
       "      <td>Case Comprehensive Cancer Center</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         nctid    phase                                        indications  \\\n",
       "0  NCT00000378  Phase 4                          [Depression, Melancholia]   \n",
       "1  NCT00001656  Phase 4  [Childhood Schizophrenia, Psychotic Disorder, ...   \n",
       "2  NCT00002863  Phase 1                                          [Sarcoma]   \n",
       "3  NCT00003060  Phase 1                                  [Melanoma (Skin)]   \n",
       "4  NCT00003567  Phase 1  [Brain and Central Nervous System Tumors, Lymp...   \n",
       "\n",
       "                                  drug_interventions  \\\n",
       "0                        [Sertraline, Nortriptyline]   \n",
       "1                            [Olanzapine, Clozapine]   \n",
       "2                                     [chemotherapy]   \n",
       "3  [busulfan, cyclophosphamide, cyclosporine, met...   \n",
       "4       [O6-benzylguanine, carmustine, temozolomide]   \n",
       "\n",
       "                                             smiless  \\\n",
       "0  ['CN[C@H]1CC[C@@H](C2=CC(Cl)=C(Cl)C=C2)C2=CC=C...   \n",
       "1  ['[H][C@]12[C@H](OC(=O)C3=CC=CC=C3)[C@]3(O)C[C...   \n",
       "2  ['NC1=NC(=O)N(C=C1)[C@@H]1O[C@H](CO)[C@@H](O)C...   \n",
       "3  ['N[C@@H](CCCNC(N)=N)C(O)=O', '[H][C@@]12C[C@H...   \n",
       "4  ['N=C1NC2=C(N=CN2)C(OCC2=CC=CC=C2)=N1', 'ClCCN...   \n",
       "\n",
       "                                            criteria enrollment  \\\n",
       "0  \\n        Inclusion Criteria:\\r\\n\\r\\n        -...        110   \n",
       "1  \\n        -  INCLUSION CRITERIA:\\r\\n\\r\\n      ...         25   \n",
       "2  \\n        DISEASE CHARACTERISTICS: Biopsy-prov...         19   \n",
       "3  \\n        DISEASE CHARACTERISTICS: Biopsy prov...          6   \n",
       "4  \\n        DISEASE CHARACTERISTICS:\\r\\n\\r\\n    ...          8   \n",
       "\n",
       "                                        lead_sponsor  \n",
       "0               New York State Psychiatric Institute  \n",
       "1         National Institute of Mental Health (NIMH)  \n",
       "2                  University of Southern California  \n",
       "3  Louisiana State University Health Sciences Cen...  \n",
       "4                   Case Comprehensive Cancer Center  "
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the raw trial fields for the first few training rows\n",
    "preview_columns = ['nctid', 'phase', 'indications', 'drug_interventions', 'smiless', 'criteria', 'enrollment', 'lead_sponsor']\n",
    "train_df[preview_columns].head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b07b3d7c",
   "metadata": {},
   "source": [
    "# Transform data into embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "efc290d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input shape:  (1028, 14)\n",
      "embedding drug molecules..\n",
      "drug molecules successfully embedded into (1028, 1024) dimensions\n",
      "embedding protocols..\n",
      "protocols successfully embedded into (1028, 624) dimensions\n",
      "embedding disease indications..\n",
      "disease indications successfully embedded into (1028, 312) dimensions\n",
      "embedding sponsors..\n",
      "sponsors successfully embedded into (1028, 384) dimensions\n",
      "normalizing enrollment numbers..\n",
      "enrollment successfully embedded into (1028, 1) dimensions\n",
      "output shape:  (1028, 2345)\n",
      "input shape:  (146, 14)\n",
      "embedding drug molecules..\n",
      "drug molecules successfully embedded into (146, 1024) dimensions\n",
      "embedding protocols..\n",
      "protocols successfully embedded into (146, 624) dimensions\n",
      "embedding disease indications..\n",
      "disease indications successfully embedded into (146, 312) dimensions\n",
      "embedding sponsors..\n",
      "sponsors successfully embedded into (146, 384) dimensions\n",
      "normalizing enrollment numbers..\n",
      "enrollment successfully embedded into (146, 1) dimensions\n",
      "output shape:  (146, 2345)\n",
      "input shape:  (295, 14)\n",
      "embedding drug molecules..\n",
      "drug molecules successfully embedded into (295, 1024) dimensions\n",
      "embedding protocols..\n",
      "protocols successfully embedded into (295, 624) dimensions\n",
      "embedding disease indications..\n",
      "disease indications successfully embedded into (295, 312) dimensions\n",
      "embedding sponsors..\n",
      "sponsors successfully embedded into (295, 384) dimensions\n",
      "normalizing enrollment numbers..\n",
      "filling 1 NaNs with median value\n",
      "succesfully filled NaNs with median value: 0 NaNs left\n",
      "enrollment successfully embedded into (295, 1) dimensions\n",
      "output shape:  (295, 2345)\n"
     ]
    }
   ],
   "source": [
    "def compute_enrollment_stats(df):\n",
    "    \"\"\"Return (median, mean, std) of the numeric 'enrollment' column of df.\n",
    "\n",
    "    Non-numeric entries are coerced to NaN and filled with the int-truncated median;\n",
    "    mean and std are then taken over the filled, int-cast values.\n",
    "    \"\"\"\n",
    "    enrollment = pd.to_numeric(df['enrollment'], errors='coerce')\n",
    "    median = int(enrollment.median())\n",
    "    filled = enrollment.fillna(median).astype(int)\n",
    "    return median, filled.mean(), filled.std()\n",
    "\n",
    "def _stack_embeddings(keys, lookup, name):\n",
    "    \"\"\"Look up every key in `lookup` and stack the vectors row-wise with np.stack.\n",
    "\n",
    "    Raises KeyError listing (up to 5) keys that have no entry, instead of letting\n",
    "    np.stack fail on the NaNs a plain Series.map would produce for them.\n",
    "    \"\"\"\n",
    "    missing = [key for key in keys if key not in lookup]\n",
    "    if missing:\n",
    "        raise KeyError(f'no {name} embedding for {len(missing)} key(s), e.g. {missing[:5]}')\n",
    "    return np.stack([lookup[key] for key in keys])\n",
    "\n",
    "def embed_all(df, enrollment_stats=None):\n",
    "    \"\"\"Turn a trial DataFrame into the numeric feature matrix fed to XGBoost.\n",
    "\n",
    "    Columns are, in order: drug-molecule, protocol and disease embeddings (looked up by\n",
    "    'nctid'), sponsor embeddings (looked up by 'lead_sponsor') and one standardised\n",
    "    enrollment column.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : DataFrame with 'nctid', 'lead_sponsor' and 'enrollment' columns.\n",
    "    enrollment_stats : optional (median, mean, std) tuple from compute_enrollment_stats.\n",
    "        Pass the training split's stats so validation/test enrollment is filled and\n",
    "        scaled like the training data. If None, the stats are computed from df itself.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DataFrame with one row per row of df, a fresh 0..n-1 index and integer column labels.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    KeyError if an 'nctid' or 'lead_sponsor' value has no precomputed embedding.\n",
    "    \"\"\"\n",
    "    print('input shape: ', df.shape)\n",
    "    # (label used in the log messages, lookup column, loader of the key -> vector dict)\n",
    "    embedding_specs = [\n",
    "        ('drug molecules', 'nctid', load_nctid2molecule_embedding_dict),\n",
    "        ('protocols', 'nctid', load_nctid2protocol_embedding_dict),\n",
    "        ('disease indications', 'nctid', load_nctid2disease_embedding_dict),\n",
    "        ('sponsors', 'lead_sponsor', load_sponsor2embedding_dict),\n",
    "    ]\n",
    "    feature_blocks = []\n",
    "    for label, column, load_lookup in embedding_specs:\n",
    "        print(f'embedding {label}..')\n",
    "        features = _stack_embeddings(df[column], load_lookup(), label)\n",
    "        print(f'{label} successfully embedded into {features.shape} dimensions')\n",
    "        feature_blocks.append(features)\n",
    "    ### EMBEDDING ENROLLMENT ###\n",
    "    print('normalizing enrollment numbers..')\n",
    "    if enrollment_stats is None:\n",
    "        enrollment_stats = compute_enrollment_stats(df)\n",
    "    median, mean, std = enrollment_stats\n",
    "    enrollment = pd.to_numeric(df['enrollment'], errors='coerce')\n",
    "    n_missing = enrollment.isna().sum()\n",
    "    if n_missing != 0:\n",
    "        print(f'filling {n_missing} NaNs with median value')\n",
    "        enrollment = enrollment.fillna(median)\n",
    "        print(f'succesfully filled NaNs with median value: {enrollment.isna().sum()} NaNs left')\n",
    "    enrollment = enrollment.astype(int)\n",
    "    h_e = ((enrollment - mean) / std).to_numpy().reshape(len(df), -1)\n",
    "    print(f'enrollment successfully embedded into {h_e.shape} dimensions')\n",
    "    feature_blocks.append(h_e)\n",
    "    ### COMBINE ALL EMBEDDINGS ###\n",
    "    embedded_df = pd.DataFrame(data=np.column_stack(feature_blocks))\n",
    "    print('output shape: ', embedded_df.shape)\n",
    "    return embedded_df\n",
    "\n",
    "# Embed data. Enrollment is filled and standardised with the training split's statistics\n",
    "# only, so validation/test features share the training scale (no per-split re-scaling).\n",
    "enrollment_stats = compute_enrollment_stats(train_df)\n",
    "X_train = embed_all(train_df, enrollment_stats)\n",
    "X_val = embed_all(val_df, enrollment_stats)\n",
    "X_test = embed_all(test_df, enrollment_stats)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59830525",
   "metadata": {},
   "source": [
    "# Define evaluation metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "294d4c9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# evaluation functions adapted from https://github.com/futianfan/clinical-trial-outcome-prediction\n",
    "from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, \\\n",
    "accuracy_score, roc_curve, precision_recall_curve\n",
    "\n",
    "def evaluation(predict_all, label_all, threshold = 0.5):\n",
    "    \"\"\"Compute binary-classification metrics for one set of predictions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    predict_all : iterable of positive-class scores (e.g. predicted probabilities).\n",
    "        Hard 0/1 labels are accepted too, but then ROC AUC and PR-AUC collapse to\n",
    "        single-threshold values.\n",
    "    label_all : iterable of true 0/1 labels.\n",
    "    threshold : scores >= threshold count as positive for the thresholded metrics.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (roc_auc, f1, pr_auc, precision, recall, accuracy, predict_1_ratio, label_1_ratio).\n",
    "    ROC AUC and PR-AUC (average precision) use the raw scores; the rest use the\n",
    "    thresholded predictions.\n",
    "    \"\"\"\n",
    "    scores = [float(score) for score in predict_all]\n",
    "    labels = [int(label) for label in label_all]\n",
    "    # Ranking metrics must see the continuous scores, not the thresholded labels.\n",
    "    auc_score = roc_auc_score(labels, scores)\n",
    "    prauc_score = average_precision_score(labels, scores)\n",
    "    binary_preds = [0 if score < threshold else 1 for score in scores]\n",
    "    f1score = f1_score(labels, binary_preds)\n",
    "    precision = precision_score(labels, binary_preds)\n",
    "    recall = recall_score(labels, binary_preds)\n",
    "    accuracy = accuracy_score(labels, binary_preds)\n",
    "    predict_1_ratio = sum(binary_preds) / len(binary_preds)\n",
    "    label_1_ratio = sum(labels) / len(labels)\n",
    "    return auc_score, f1score, prauc_score, precision, recall, accuracy, predict_1_ratio, label_1_ratio\n",
    "\n",
    "def print_results(predict_all, label_all, threshold = 0.5):\n",
    "    \"\"\"Print every metric from `evaluation`, one per line, rounded to 3 decimals.\n",
    "\n",
    "    predict_all should hold positive-class scores (see `evaluation`); threshold is\n",
    "    forwarded to `evaluation` for the thresholded metrics.\n",
    "    \"\"\"\n",
    "    metric_names = ['ROC AUC', 'F1', 'PR-AUC', 'Precision', 'recall', 'accuracy',\n",
    "                    'predict 1 ratio', 'label 1 ratio']\n",
    "    metric_values = evaluation(predict_all, label_all, threshold = threshold)\n",
    "    for name, value in zip(metric_names, metric_values):\n",
    "        print(f'{name}: {value:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c822eb15",
   "metadata": {},
   "source": [
    "# Train XGBoost and print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8377f8ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------Results on training data:-----------\n",
      "ROC AUC: 1.0\n",
      "F1: 1.0\n",
      "PR-AUC: 1.0\n",
      "Precision: 1.0\n",
      "recall: 1.0\n",
      "accuracy: 1.0\n",
      "predict 1 ratio: 0.661\n",
      "label 1 ratio: 0.661\n",
      "-----------Results on validation data:-----------\n",
      "ROC AUC: 0.765\n",
      "F1: 0.817\n",
      "PR-AUC: 0.799\n",
      "Precision: 0.840\n",
      "recall: 0.795\n",
      "accuracy: 0.773\n",
      "predict 1 ratio: 0.602\n",
      "label 1 ratio: 0.636\n",
      "-----------Results on test data:-----------\n",
      "ROC AUC: 0.742\n",
      "F1: 0.805\n",
      "PR-AUC: 0.757\n",
      "Precision: 0.790\n",
      "recall: 0.821\n",
      "accuracy: 0.759\n",
      "predict 1 ratio: 0.630\n",
      "label 1 ratio: 0.606\n"
     ]
    }
   ],
   "source": [
    "import xgboost as xgb\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
    "\n",
    "# Create an XGBoost classifier with specified hyperparameters\n",
    "xgb_classifier = xgb.XGBClassifier(\n",
    "    learning_rate=0.1,\n",
    "    max_depth=3,\n",
    "    n_estimators=200,\n",
    "    objective='binary:logistic',  # for binary classification\n",
    "    random_state=42\n",
    ")\n",
    "\n",
    "# Train the XGBoost model\n",
    "xgb_classifier.fit(X_train, y_train)\n",
    "\n",
    "# Score every split with the positive-class probability instead of predict()'s hard 0/1\n",
    "# labels: ROC AUC and PR-AUC rank the scores, so they are degenerate on hard labels.\n",
    "# evaluation() thresholds the scores at 0.5 for F1/precision/recall/accuracy, which\n",
    "# matches predict()'s labels apart from scores of exactly 0.5.\n",
    "# The *_pred names are kept because the plotting cell reads y_test_pred.\n",
    "y_train_pred = xgb_classifier.predict_proba(X_train)[:, 1]\n",
    "y_val_pred = xgb_classifier.predict_proba(X_val)[:, 1]\n",
    "y_test_pred = xgb_classifier.predict_proba(X_test)[:, 1]\n",
    "print('-----------Results on training data:-----------')\n",
    "print_results(y_train_pred, y_train)\n",
    "print('-----------Results on validation data:-----------')\n",
    "print_results(y_val_pred, y_val)\n",
    "print('-----------Results on test data:-----------')\n",
    "print_results(y_test_pred, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caa8d394",
   "metadata": {},
   "source": [
    "# Plot results and compare with HINT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "869d94d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1200x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Reference HINT scores; the keys double as x-axis labels and follow the order of the\n",
    "# first six values returned by evaluation() (ROC AUC, F1, PR-AUC, precision, recall, accuracy).\n",
    "HINT_performance = {'ROC AUC': 0.800, 'F1': 0.798, 'PR-AUC': 0.735, 'Precision': 0.758, 'recall': 0.843, 'accuracy': 0.742}\n",
    "\n",
    "metrics = list(HINT_performance)\n",
    "XGB_performance = evaluation(y_test_pred, y_test, threshold = 0.5)[:len(metrics)]\n",
    "HINT_values = list(HINT_performance.values())\n",
    "\n",
    "# Width of each bar; the two series sit either side of each metric's x position\n",
    "bar_width = 0.35\n",
    "indices = np.arange(len(metrics))\n",
    "\n",
    "# Side-by-side bar chart on an explicit Figure/Axes\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.bar(indices - bar_width/2, XGB_performance, label='XGBoost Performance', width=bar_width, color='b', alpha=0.7)\n",
    "ax.bar(indices + bar_width/2, HINT_values, label='HINT Performance', width=bar_width, color='orange', alpha=0.7)\n",
    "ax.set_xlabel('Metrics')\n",
    "ax.set_ylabel('Performance')\n",
    "ax.set_title('Comparison of XGBoost performance vs HINT performance')\n",
    "ax.set_xticks(indices)\n",
    "ax.set_xticklabels(metrics)\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig does not create missing directories, so make sure fig/ exists first\n",
    "os.makedirs('fig', exist_ok=True)\n",
    "fig.savefig('fig/performance.png')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}