[bc9e98]: / UQ_on_HINT.ipynb

Download this file

913 lines (912 with data), 155.3 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, os\n",
    "torch.manual_seed(0) \n",
    "import warnings;warnings.filterwarnings(\"ignore\")\n",
    "from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst\n",
    "from HINT.molecule_encode import MPNN, ADMET \n",
    "from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict\n",
    "from HINT.protocol_encode import Protocol_Embedding\n",
    "from HINT.model import HINTModel \n",
    "# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so the notebook still runs.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Output folder for figures; exist_ok keeps this cell safe to re-run.\n",
    "os.makedirs(\"figure\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the clinical-trial phase to analyse.\n",
    "base_name = 'phase_I'  # one of 'phase_I', 'phase_II', 'phase_III'\n",
    "datafolder = \"data\"\n",
    "# Every split is stored as <datafolder>/<base_name>_<split>.csv\n",
    "train_file, valid_file, test_file = (\n",
    "    os.path.join(datafolder, f\"{base_name}_{split}.csv\")\n",
    "    for split in (\"train\", \"valid\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Molecule encoder shared by the ADMET pre-training model and the HINT model.\n",
    "mpnn_model = MPNN(mpnn_hidden_size = 50, mpnn_depth=3, device = device)\n",
    "admet_model_path = \"save_model/admet_model.ckpt\"\n",
    "if not os.path.exists(admet_model_path):\n",
    "\t# No cached checkpoint: train the ADMET model, then cache it for later runs.\n",
    "\tadmet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)\n",
    "\t# Each entry is indexed as (train loader, test loader).\n",
    "\tadmet_trainloader_lst = [i[0] for i in admet_dataloader_lst]\n",
    "\tadmet_testloader_lst = [i[1] for i in admet_dataloader_lst]\n",
    "\tadmet_model = ADMET(molecule_encoder = mpnn_model, \n",
    "\t\t\t\t\t\thighway_num=2, \n",
    "\t\t\t\t\t\tdevice = device, \n",
    "\t\t\t\t\t\tepoch=3, \n",
    "\t\t\t\t\t\tlr=5e-4, \n",
    "\t\t\t\t\t\tweight_decay=0, \n",
    "\t\t\t\t\t\tsave_name = 'admet_')\n",
    "\tadmet_model.train(admet_trainloader_lst, admet_testloader_lst)\n",
    "\t# torch.save does not create missing parent directories.\n",
    "\tos.makedirs(os.path.dirname(admet_model_path), exist_ok=True)\n",
    "\ttorch.save(admet_model, admet_model_path)\n",
    "else:\n",
    "\t# SECURITY: torch.load unpickles a whole model object; only load checkpoints you created yourself.\n",
    "\t# NOTE(review): recent PyTorch releases may require weights_only=False to load a pickled model - confirm for your version.\n",
    "\t# map_location remaps the stored tensors onto the device chosen for this session.\n",
    "\tadmet_model = torch.load(admet_model_path, map_location=device)\n",
    "\tadmet_model = admet_model.to(device)\n",
    "\tadmet_model.set_device(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the training split is shuffled; validation and test are loaded with shuffle=False.\n",
    "train_loader = csv_three_feature_2_dataloader(train_file, batch_size=64, shuffle=True)\n",
    "valid_loader = csv_three_feature_2_dataloader(valid_file, batch_size=32, shuffle=False)\n",
    "test_loader = csv_three_feature_2_dataloader(test_file, batch_size=32, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GRAM model (later passed to HINT as the disease encoder), built from the ICD-code -> ancestor mapping.\n",
    "icdcode2ancestor_dict = build_icdcode2ancestor_dict()\n",
    "gram_model = GRAM(\n",
    "    embedding_dim=50,\n",
    "    icdcode2ancestor=icdcode2ancestor_dict,\n",
    "    device=device,\n",
    ")\n",
    "# Protocol embedding model (later passed to HINT as the protocol encoder).\n",
    "protocol_model = Protocol_Embedding(\n",
    "    output_dim=50,\n",
    "    highway_num=3,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PR-AUC   mean: 0.5738 std: 0.0158\n",
      "F1       mean: 0.5923 std: 0.0184\n",
      "ROC-AUC  mean: 0.5729 std: 0.0262\n",
      "NCT01288573 1 0.3920\n",
      "NCT01558674 0 0.5373\n",
      "NCT01609816 0 0.5609\n",
      "NCT01836289 0 0.5322\n",
      "NCT01852890 1 0.2230\n",
      "NCT01889420 0 0.6710\n",
      "NCT01933594 1 0.4409\n",
      "NCT01965600 0 0.6439\n",
      "NCT01966731 1 0.3307\n",
      "NCT01981499 0 0.6297\n",
      "NCT01988493 1 0.3146\n",
      "NCT02023424 0 0.5840\n",
      "NCT02030483 0 0.6740\n",
      "NCT02034292 0 0.6772\n",
      "NCT02037230 1 0.3541\n",
      "NCT02040870 1 0.2827\n",
      "NCT02041481 1 0.4579\n",
      "NCT02043860 0 0.6732\n",
      "NCT02044796 1 0.4635\n",
      "NCT02054442 1 0.3160\n",
      "NCT02055690 0 0.5551\n",
      "NCT02055924 0 0.5995\n",
      "NCT02058407 0 0.6758\n",
      "NCT02061449 0 0.5994\n",
      "NCT02064608 1 0.3907\n",
      "NCT02071537 0 0.5654\n",
      "NCT02073123 1 0.3866\n",
      "NCT02074839 1 0.3136\n",
      "NCT02074878 0 0.5157\n",
      "NCT02078089 0 0.5596\n",
      "NCT02078648 1 0.4167\n",
      "NCT02078960 0 0.5228\n",
      "NCT02080416 0 0.5964\n",
      "NCT02083926 1 0.4553\n",
      "NCT02092363 0 0.5880\n",
      "NCT02098629 0 0.5295\n",
      "NCT02103335 1 0.4933\n",
      "NCT02106897 1 0.4673\n",
      "NCT02109224 0 0.5625\n",
      "NCT02109341 1 0.2147\n",
      "NCT02110992 0 0.5312\n",
      "NCT02116010 0 0.6355\n",
      "NCT02116959 0 0.6746\n",
      "NCT02120339 0 0.5137\n",
      "NCT02120911 1 0.3677\n",
      "NCT02123823 0 0.5576\n",
      "NCT02133742 1 0.3398\n",
      "NCT02135640 1 0.4724\n",
      "NCT02137642 0 0.5362\n",
      "NCT02138383 1 0.2183\n",
      "NCT02138526 0 0.5531\n",
      "NCT02141828 1 0.4563\n",
      "NCT02142803 1 0.4447\n",
      "NCT02146222 1 0.4911\n",
      "NCT02154646 1 0.2204\n",
      "NCT02154776 0 0.5154\n",
      "NCT02155465 1 0.2822\n",
      "NCT02160015 0 0.5374\n",
      "NCT02166229 0 0.5326\n",
      "NCT02168270 0 0.5192\n",
      "NCT02168907 0 0.5662\n",
      "NCT02174822 1 0.4412\n",
      "NCT02176486 0 0.6554\n",
      "NCT02181660 1 0.4325\n",
      "NCT02186834 0 0.6739\n",
      "NCT02187094 0 0.6374\n",
      "NCT02189109 1 0.3980\n",
      "NCT02192697 1 0.2834\n",
      "NCT02195869 1 0.4357\n",
      "NCT02195973 1 0.4339\n",
      "NCT02200848 0 0.5808\n",
      "NCT02202993 1 0.4000\n",
      "NCT02205398 0 0.5561\n",
      "NCT02205554 1 0.4739\n",
      "NCT02209506 0 0.5479\n",
      "NCT02210182 1 0.3149\n",
      "NCT02219516 1 0.4905\n",
      "NCT02219789 1 0.4702\n",
      "NCT02221765 0 0.6173\n",
      "NCT02222441 0 0.6469\n",
      "NCT02224599 0 0.5035\n",
      "NCT02227329 0 0.6376\n",
      "NCT02229981 0 0.6187\n",
      "NCT02231658 0 0.6587\n",
      "NCT02232152 1 0.4670\n",
      "NCT02240355 0 0.6446\n",
      "NCT02244502 1 0.3466\n",
      "NCT02250118 0 0.5376\n",
      "NCT02253212 1 0.3927\n",
      "NCT02253277 0 0.5207\n",
      "NCT02254161 0 0.6442\n",
      "NCT02257177 1 0.3751\n",
      "NCT02269085 0 0.5992\n",
      "NCT02275039 1 0.3876\n",
      "NCT02275416 1 0.4859\n",
      "NCT02277197 1 0.3166\n",
      "NCT02283372 1 0.2221\n",
      "NCT02288507 0 0.6084\n",
      "NCT02291133 1 0.4017\n",
      "NCT02292173 1 0.4380\n",
      "NCT02292550 1 0.2827\n",
      "NCT02296242 1 0.3964\n",
      "NCT02299505 1 0.2827\n",
      "NCT02299518 1 0.4575\n",
      "NCT02300298 1 0.4325\n",
      "NCT02300610 1 0.4452\n",
      "NCT02300727 0 0.5494\n",
      "NCT02301104 0 0.6698\n",
      "NCT02303912 1 0.4646\n",
      "NCT02306291 1 0.4357\n",
      "NCT02309580 0 0.5995\n",
      "NCT02312102 1 0.3848\n",
      "NCT02336048 0 0.5529\n",
      "NCT02337543 1 0.4264\n",
      "NCT02339168 1 0.3871\n",
      "NCT02339324 1 0.3781\n",
      "NCT02365532 1 0.4986\n",
      "NCT02372240 0 0.6728\n",
      "NCT02373072 1 0.4338\n",
      "NCT02374047 1 0.4401\n",
      "NCT02379195 1 0.3756\n",
      "NCT02379910 1 0.2426\n",
      "NCT02382666 1 0.2233\n",
      "NCT02383511 1 0.4660\n",
      "NCT02384746 0 0.5159\n",
      "NCT02392039 0 0.6005\n",
      "NCT02393755 1 0.4271\n",
      "NCT02403271 1 0.4006\n",
      "NCT02406521 1 0.3127\n",
      "NCT02407080 1 0.2624\n",
      "NCT02411565 0 0.5573\n",
      "NCT02414503 1 0.4289\n",
      "NCT02422381 1 0.4450\n",
      "NCT02439216 1 0.4676\n",
      "NCT02442960 0 0.5213\n",
      "NCT02446964 1 0.4472\n",
      "NCT02457286 0 0.5141\n",
      "NCT02472795 1 0.4670\n",
      "NCT02481180 0 0.6779\n",
      "NCT02483871 0 0.5161\n",
      "NCT02487459 0 0.6245\n",
      "NCT02493751 1 0.1558\n",
      "NCT02499770 1 0.2826\n",
      "NCT02508246 1 0.3228\n",
      "NCT02514447 1 0.2827\n",
      "NCT02515669 1 0.4214\n",
      "NCT02529072 0 0.5394\n",
      "NCT02530476 1 0.4326\n",
      "NCT02535104 1 0.4703\n",
      "NCT02538510 1 0.3262\n",
      "NCT02543931 0 0.6765\n",
      "NCT02548468 0 0.5206\n",
      "NCT02550743 0 0.5708\n",
      "NCT02555007 1 0.2839\n",
      "NCT02565628 0 0.6793\n",
      "NCT02568683 0 0.5993\n",
      "NCT02572687 1 0.3763\n",
      "NCT02582840 0 0.6469\n",
      "NCT02583373 1 0.4787\n",
      "NCT02589145 0 0.6021\n",
      "NCT02595437 1 0.4635\n",
      "NCT02596399 1 0.4573\n",
      "NCT02597400 0 0.6915\n",
      "NCT02603952 1 0.4799\n",
      "NCT02604719 1 0.4139\n",
      "NCT02608437 1 0.3792\n",
      "NCT02611908 0 0.5532\n",
      "NCT02613221 0 0.5572\n",
      "NCT02614586 0 0.5877\n",
      "NCT02620423 1 0.3524\n",
      "NCT02624089 0 0.5365\n",
      "NCT02630368 0 0.5141\n",
      "NCT02630823 1 0.4564\n",
      "NCT02636426 0 0.5552\n",
      "NCT02639117 1 0.4377\n",
      "NCT02642965 1 0.4758\n",
      "NCT02647086 1 0.2418\n",
      "NCT02647281 1 0.4823\n",
      "NCT02649686 0 0.5376\n",
      "NCT02655952 1 0.4523\n",
      "NCT02658084 0 0.5258\n",
      "NCT02678923 0 0.6776\n",
      "NCT02683148 0 0.6621\n",
      "NCT02691871 1 0.4450\n",
      "NCT02697383 0 0.6103\n",
      "NCT02697851 0 0.6370\n",
      "NCT02704702 1 0.4082\n",
      "NCT02708680 0 0.5161\n",
      "NCT02709850 1 0.4319\n",
      "NCT02711462 0 0.6443\n",
      "NCT02716012 1 0.3569\n",
      "NCT02716805 0 0.6729\n",
      "NCT02727283 1 0.4433\n",
      "NCT02727777 0 0.5986\n",
      "NCT02729194 1 0.3131\n",
      "NCT02734160 1 0.2204\n",
      "NCT02739009 1 0.4063\n",
      "NCT02740712 0 0.5567\n",
      "NCT02743546 0 0.5913\n",
      "NCT02743780 0 0.6156\n",
      "NCT02744456 1 0.4079\n",
      "NCT02757521 0 0.5546\n",
      "NCT02762474 1 0.4418\n",
      "NCT02762617 0 0.6337\n",
      "NCT02775812 1 0.3166\n",
      "NCT02780804 1 0.4885\n",
      "NCT02800889 0 0.5982\n",
      "NCT02806817 1 0.4151\n",
      "NCT02809183 1 0.4643\n",
      "NCT02815488 0 0.6420\n",
      "NCT02815540 0 0.5054\n",
      "NCT02824055 0 0.5880\n",
      "NCT02826798 1 0.4482\n",
      "NCT02830542 1 0.4992\n",
      "NCT02834247 0 0.6281\n",
      "NCT02835833 1 0.3489\n",
      "NCT02845050 0 0.5598\n",
      "NCT02856750 0 0.5486\n",
      "NCT02864264 0 0.5229\n",
      "NCT02875678 0 0.6009\n",
      "NCT02886559 1 0.4687\n",
      "NCT02888665 1 0.4136\n",
      "NCT02909452 1 0.4482\n",
      "NCT02911597 1 0.3027\n",
      "NCT02912234 1 0.4516\n",
      "NCT02914327 0 0.5563\n",
      "NCT02915523 0 0.5447\n",
      "NCT02936206 0 0.5158\n",
      "NCT02938273 1 0.3168\n",
      "NCT02957630 1 0.4479\n",
      "NCT02958982 0 0.6677\n",
      "NCT02959619 1 0.2832\n",
      "NCT02963376 1 0.4668\n",
      "NCT02964377 1 0.4683\n",
      "NCT02966730 0 0.5995\n",
      "NCT02967731 1 0.3315\n",
      "NCT02973399 0 0.5563\n",
      "NCT02975557 0 0.5362\n",
      "NCT02985554 0 0.6401\n",
      "NCT02993471 0 0.5331\n",
      "NCT02994953 0 0.6697\n",
      "NCT03004846 0 0.5334\n",
      "NCT03014219 0 0.6465\n",
      "NCT03023527 0 0.6729\n",
      "NCT03025009 0 0.6494\n",
      "NCT03031691 0 0.5565\n",
      "NCT03054207 0 0.6375\n",
      "NCT03057509 0 0.6196\n",
      "NCT03059693 1 0.2433\n",
      "NCT03092076 0 0.6439\n",
      "NCT03106610 0 0.5821\n",
      "NCT03121716 1 0.3677\n",
      "NCT03122678 0 0.5203\n",
      "NCT03132584 0 0.5958\n",
      "NCT03135028 0 0.5243\n",
      "NCT03135899 0 0.6677\n",
      "NCT03136627 1 0.3261\n",
      "NCT03139981 1 0.2433\n",
      "NCT03140072 0 0.6678\n",
      "NCT03145948 0 0.5333\n",
      "NCT03179501 0 0.6064\n",
      "NCT03192709 1 0.3677\n",
      "NCT03248713 0 0.5377\n",
      "NCT03263533 0 0.5847\n",
      "NCT03303911 0 0.5427\n",
      "NCT03349346 0 0.5973\n",
      "NCT03369964 0 0.5987\n",
      "NCT03372603 0 0.5504\n",
      "NCT03383692 0 0.5792\n",
      "NCT03411421 0 0.6443\n",
      "NCT03422874 0 0.5725\n",
      "NCT03431610 1 0.2433\n",
      "NCT03445013 0 0.5334\n",
      "NCT03464864 1 0.4729\n",
      "NCT03558750 0 0.5997\n",
      "NCT03576508 1 0.4255\n",
      "NCT03591146 1 0.4373\n",
      "NCT03601819 0 0.5904\n",
      "NCT03605212 0 0.6719\n",
      "NCT03606707 1 0.4654\n",
      "NCT03687125 0 0.6717\n",
      "NCT03888534 0 0.5974\n",
      "NCT04042051 0 0.5249\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "hint_model_path = \"save_model/\" + base_name + \".ckpt\"\n",
    "if not os.path.exists(hint_model_path):\n",
    "\t# No cached checkpoint for this phase: train the HINT model and cache it.\n",
    "\tmodel = HINTModel(molecule_encoder = mpnn_model, \n",
    "\t\t\tdisease_encoder = gram_model, \n",
    "\t\t\tprotocol_encoder = protocol_model,\n",
    "\t\t\tdevice = device, \n",
    "\t\t\tglobal_embed_size = 50, \n",
    "\t\t\thighway_num_layer = 2,\n",
    "\t\t\tprefix_name = base_name, \n",
    "\t\t\tgnn_hidden_size = 50,  \n",
    "\t\t\tepoch = 15,\n",
    "\t\t\tlr = 5e-4, \n",
    "\t\t\tweight_decay = 1e-3, \n",
    "\t\t\t)\n",
    "\tmodel.init_pretrain(admet_model)\n",
    "\ttrain_output, valid_output = model.learn(train_loader, valid_loader, test_loader)\n",
    "\tnctid_all, predict_all = model.bootstrap_test(test_loader, valid_loader=valid_loader)\n",
    "\t# torch.save does not create missing parent directories.\n",
    "\tos.makedirs(os.path.dirname(hint_model_path), exist_ok=True)\n",
    "\ttorch.save(model, hint_model_path)\n",
    "else:\n",
    "\t# SECURITY: torch.load unpickles a whole model object; only load checkpoints you created yourself.\n",
    "\t# NOTE(review): recent PyTorch releases may require weights_only=False to load a pickled model - confirm for your version.\n",
    "\tmodel = torch.load(hint_model_path, map_location=device)\n",
    "\tmodel.bootstrap_test(test_loader)\n",
    "# Every later cell reads these predictions, so compute them whether the model was just trained\n",
    "# or loaded from disk (previously they only existed in the load branch).\n",
    "train_loss, train_predict, train_label = model.test(train_loader, return_loss=True)\n",
    "test_loss, test_predict, test_label = model.test(test_loader, return_loss=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score\n",
    "from scipy.optimize import brentq\n",
    "from scipy.stats import binom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5486443381180224"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Threshold the model scores for the test split at 0.5 to obtain hard 0/1 predictions.\n",
    "n = len(test_predict)\n",
    "test_predict_label = [1 if test_predict[i] > 0.5 else 0 for i in range(n)]\n",
    "\n",
    "# Accuracy of the thresholded predictions against the test labels.\n",
    "accuracy_score(test_predict_label, test_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.657088122605364"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Threshold the model scores for the training split at 0.5 to obtain hard 0/1 predictions.\n",
    "n = len(train_predict)\n",
    "train_predict_label = [1 if train_predict[i] > 0.5 else 0 for i in range(n)]\n",
    "\n",
    "# Accuracy of the thresholded predictions against the training labels.\n",
    "accuracy_score(train_predict_label, train_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1044, 2)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Two-column score matrix [1 - p, p], where p is the model score that is thresholded at 0.5 for class 1.\n",
    "# Built from the TEST split: the calibration/validation analysis below is compared against test\n",
    "# accuracy and looked up by test-CSV row, and the training rows come from a shuffled loader and\n",
    "# were seen during fitting, so they are not a valid calibration set.\n",
    "pos_scores = np.array(test_predict)\n",
    "labels = np.array(test_label)\n",
    "smx = np.stack((1 - pos_scores, pos_scores), axis=1)\n",
    "smx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.554 0.3508607036019853 True\n",
      "0.538 0.30368109842562563 True\n",
      "0.538 0.35567519384109153 True\n",
      "0.534 0.4216732884759087 True\n",
      "0.534 0.3278802471240832 True\n",
      "0.534 0.35770280796317994 True\n",
      "0.534 0.2966973400692054 True\n",
      "0.534 0.3313242713914189 True\n",
      "0.534 0.3857771181375555 True\n",
      "0.534 0.3636793928357601 True\n",
      "0.534 0.337767638543699 True\n",
      "0.534 0.37325515096924666 True\n",
      "0.534 0.3725668806539195 True\n",
      "0.534 0.4408299805577162 True\n",
      "0.534 0.44835356121599995 True\n",
      "0.534 0.3215599003528717 True\n",
      "0.534 0.3698935924827509 True\n",
      "0.534 0.3265181312939036 True\n",
      "0.534 0.3892693365694524 True\n",
      "0.534 0.3410062299280009 True\n",
      "The original accuracy is:\t\t0.5486\n",
      "The empirical selective accuracy is:\tmean: 0.6848\tstd: 0.0083\n",
      "The empirical improved accuracy is:\tmean: 0.1361\tstd: 0.0083\n",
      "The fraction of data points kept is:\tmean: 0.8151\tstd: 0.0190\n",
      "PR_AUC\t mean: 0.7606\tstd: 0.0129\n",
      "F1\t mean: 0.7252\tstd: 0.0084\n",
      "ROC_AUC  mean: 0.7149\tstd: 0.0085\n"
     ]
    }
   ],
   "source": [
    "# Selective classification with an upper-bounded selective risk, repeated over random calibration/validation splits.\n",
    "n = 200                         # rows drawn for calibration in every split\n",
    "alpha = 0.1                     # risk level the upper bound is compared against\n",
    "delta = 0.1                     # level at which the binomial CDF is inverted\n",
    "lambdas = np.linspace(0,1,501)  # grid of candidate confidence thresholds\n",
    "n_splits = 20                   # number of random splits\n",
    "min_kept = 150                  # a threshold must leave at least this many calibration rows above it\n",
    "np.random.seed(0)               # reproducible splits (the torch seed does not cover numpy)\n",
    "\n",
    "prauc = []\n",
    "f1 = []\n",
    "rocauc = []\n",
    "selective_accuracy = []\n",
    "improved_accuracy = []\n",
    "points_kept = []\n",
    "\n",
    "# Accuracy without abstention; it does not depend on the split, so compute it once.\n",
    "# NOTE(review): this baseline is measured on the test split - confirm that smx/labels are built\n",
    "# from the same split, otherwise the \"improved accuracy\" compares two different datasets.\n",
    "test_accuracy = accuracy_score(test_predict_label, test_label)\n",
    "\n",
    "# The helpers below read cal_yhats / cal_phats / cal_labels from the notebook namespace;\n",
    "# the loop rebinds those names for every split before calling them.\n",
    "def selective_risk(lam): \n",
    "    \"\"\"Error rate among calibration rows whose confidence is >= lam.\"\"\"\n",
    "    kept = cal_phats >= lam\n",
    "    return (cal_yhats[kept] != cal_labels[kept]).sum()/kept.sum()\n",
    "def nlambda(lam): \n",
    "    \"\"\"Number of calibration rows whose confidence is strictly above lam.\"\"\"\n",
    "    return (cal_phats > lam).sum()\n",
    "def invert_for_ub(r,lam): \n",
    "    \"\"\"Binomial CDF of the observed error count at error rate r, minus delta (root in r = upper bound).\"\"\"\n",
    "    return binom.cdf(selective_risk(lam)*nlambda(lam),nlambda(lam),r)-delta\n",
    "def selective_risk_ub(lam): \n",
    "    \"\"\"Upper bound on the selective risk at threshold lam: the root of invert_for_ub in r.\"\"\"\n",
    "    # The bracket starts at 0 (where the CDF is 1). Starting at 0.1 (= alpha) made brentq raise\n",
    "    # for every bound <= alpha, i.e. exactly the thresholds the scan is looking for.\n",
    "    return brentq(invert_for_ub,0,0.9999,args=(lam,))\n",
    "\n",
    "for split_no in range(n_splits):\n",
    "    # Random split of the score matrix into calibration (n rows) and validation (the rest)\n",
    "    idx = np.array([1] * n + [0] * (smx.shape[0]-n)) > 0\n",
    "    np.random.shuffle(idx)\n",
    "    cal_smx, val_smx = smx[idx,:], smx[~idx,:]\n",
    "    cal_labels, val_labels = labels[idx], labels[~idx]\n",
    "    # Predicted class (argmax) and its confidence (max score)\n",
    "    cal_yhats = np.argmax(cal_smx, axis=1); val_yhats = np.argmax(val_smx, axis=1)\n",
    "    cal_phats = np.max(cal_smx, axis=1); val_phats = np.max(val_smx, axis=1)\n",
    "\n",
    "    # Filter into a per-split array: overwriting `lambdas` here would make every split\n",
    "    # inherit the (shrinking) grid left behind by the previous ones.\n",
    "    candidate_lambdas = np.array([lam for lam in lambdas if nlambda(lam) >= min_kept])\n",
    "    # Scan from the largest threshold downwards; stop at the first whose risk bound exceeds alpha\n",
    "    for lhat in np.flip(candidate_lambdas):\n",
    "        risk_ub = selective_risk_ub(lhat-1/candidate_lambdas.shape[0])\n",
    "        print(lhat, risk_ub, risk_ub > alpha)\n",
    "        if risk_ub > alpha: break\n",
    "    # Apply the chosen threshold to the validation rows\n",
    "    predictions_kept = val_phats >= lhat\n",
    "\n",
    "    val_label_kept = val_labels[predictions_kept]\n",
    "    val_pred_kept = val_yhats[predictions_kept]\n",
    "    val_score_kept = val_smx[:,1][predictions_kept]\n",
    "\n",
    "    # Accuracy on the kept rows, its gain over the baseline, and the share of rows kept\n",
    "    empirical_selective_accuracy = (val_pred_kept == val_label_kept).mean()\n",
    "    empirical_improved_accuracy = empirical_selective_accuracy - test_accuracy\n",
    "    fraction_kept = predictions_kept.mean()\n",
    "\n",
    "    prauc_score = average_precision_score(val_label_kept, val_score_kept)\n",
    "    f1score = f1_score(val_label_kept, val_pred_kept)\n",
    "    auc_score = roc_auc_score(val_label_kept, val_score_kept)\n",
    "\n",
    "    prauc.append(prauc_score)\n",
    "    f1.append(f1score)\n",
    "    rocauc.append(auc_score)\n",
    "    selective_accuracy.append(empirical_selective_accuracy)\n",
    "    improved_accuracy.append(empirical_improved_accuracy)\n",
    "    points_kept.append(fraction_kept)\n",
    "\n",
    "print(f\"The original accuracy is:\\t\\t{test_accuracy:.4F}\")\n",
    "print(f\"The empirical selective accuracy is:\\tmean: {np.mean(selective_accuracy):.4F}\\tstd: {np.std(selective_accuracy):.4f}\")\n",
    "print(f\"The empirical improved accuracy is:\\tmean: {np.mean(improved_accuracy):.4F}\\tstd: {np.std(improved_accuracy):.4f}\")\n",
    "print(f\"The fraction of data points kept is:\\tmean: {np.mean(points_kept):.4F}\\tstd: {np.std(points_kept):.4f}\")\n",
    "\n",
    "print(f\"PR_AUC\\t mean: {np.mean(prauc):.4F}\\tstd: {np.std(prauc):.4f}\")\n",
    "print(f\"F1\\t mean: {np.mean(f1):.4F}\\tstd: {np.std(f1):.4f}\")\n",
    "print(f\"ROC_AUC  mean: {np.mean(rocauc):.4F}\\tstd: {np.std(rocauc):.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 600x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "sns.set_theme(style='white')\n",
    "\n",
    "# Baseline (\"origin\"): metrics of the model on the whole test split, without abstention.\n",
    "# They are computed from the current predictions rather than hard-coded, so the plot stays\n",
    "# correct when the model is retrained or base_name is changed.\n",
    "test_scores = np.asarray(test_predict)\n",
    "origin = [\n",
    "    average_precision_score(test_label, test_scores),\n",
    "    f1_score(test_label, test_predict_label),\n",
    "    roc_auc_score(test_label, test_scores),\n",
    "    accuracy_score(test_predict_label, test_label),\n",
    "]\n",
    "# Selective classifier (\"cp\"): means over the random splits collected in the lists above.\n",
    "cp = [np.mean(prauc), np.mean(f1), np.mean(rocauc), np.mean(selective_accuracy)]\n",
    "\n",
    "# X axis\n",
    "X_labels = [\"PR_AUC\", \"F1\", \"ROC_AUC\", \"Accuracy\"]\n",
    "\n",
    "# Long-format frame: one row per (metric, variant) pair.\n",
    "data = {\n",
    "    'Metric': X_labels * 2,  \n",
    "    'Value': origin + cp,  \n",
    "    'Type': ['origin']*4 + ['cp']*4  \n",
    "}\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "ax = sns.barplot(x='Metric', y='Value', hue='Type', data=df)\n",
    "\n",
    "# Write each bar's value just above it.\n",
    "for p in ax.patches:\n",
    "    ax.annotate(f'{p.get_height():.4f}', \n",
    "                (p.get_x() + p.get_width() / 2., p.get_height()), \n",
    "                ha = 'center', \n",
    "                va = 'center', \n",
    "                xytext = (0, 9), \n",
    "                textcoords = 'offset points',\n",
    "                fontsize = 10)\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1, 1))\n",
    "# Derive the title from the configured phase, e.g. 'phase_I' -> 'Phase I'.\n",
    "plt.title(base_name.replace('phase_', 'Phase '))\n",
    "plt.ylabel('Value')\n",
    "plt.xlabel('Metric')\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "# Show the figure\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one random calibration/validation split of the score matrix (n rows go to calibration).\n",
    "idx = np.repeat([True, False], [n, max(smx.shape[0] - n, 0)])\n",
    "np.random.shuffle(idx)\n",
    "cal_smx, cal_labels = smx[idx], labels[idx]\n",
    "val_smx, val_labels = smx[~idx], labels[~idx]\n",
    "# Predicted class (argmax) and its confidence (max score) for every row.\n",
    "cal_yhats, cal_phats = cal_smx.argmax(axis=1), cal_smx.max(axis=1)\n",
    "val_yhats, val_phats = val_smx.argmax(axis=1), val_smx.max(axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.534"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Helpers for choosing the abstention threshold on the current calibration split.\n",
    "# They read cal_yhats / cal_phats / cal_labels, alpha and delta from the notebook namespace.\n",
    "def selective_risk(lam): \n",
    "    \"\"\"Error rate among calibration rows whose confidence is >= lam.\"\"\"\n",
    "    kept = cal_phats >= lam\n",
    "    return (cal_yhats[kept] != cal_labels[kept]).sum()/kept.sum()\n",
    "def nlambda(lam): \n",
    "    \"\"\"Number of calibration rows whose confidence is strictly above lam.\"\"\"\n",
    "    return (cal_phats > lam).sum()\n",
    "def invert_for_ub(r,lam): \n",
    "    \"\"\"Binomial CDF of the observed error count at error rate r, minus delta (root in r = upper bound).\"\"\"\n",
    "    return binom.cdf(selective_risk(lam)*nlambda(lam),nlambda(lam),r)-delta\n",
    "def selective_risk_ub(lam): \n",
    "    \"\"\"Upper bound on the selective risk at threshold lam: the root of invert_for_ub in r.\"\"\"\n",
    "    # The bracket starts at 0 (where the CDF is 1). Starting at 0.1 (= alpha) made brentq raise\n",
    "    # for every bound <= alpha, i.e. exactly the thresholds the scan is looking for.\n",
    "    return brentq(invert_for_ub,0,0.9999,args=(lam,))\n",
    "# Candidate thresholds that leave at least 50 calibration rows above them. The grid is built\n",
    "# fresh here because the global `lambdas` is rebound by other cells, which made this cell's\n",
    "# result depend on which cells had been run before it.\n",
    "candidate_lambdas = np.array([lam for lam in np.linspace(0,1,501) if nlambda(lam) >= 50]) \n",
    "# Scan from the largest threshold downwards; stop at the first whose risk bound exceeds alpha\n",
    "for lhat in np.flip(candidate_lambdas):\n",
    "    if selective_risk_ub(lhat-1/candidate_lambdas.shape[0]) > alpha: break\n",
    "# Validation rows on which the selective classifier does not abstain\n",
    "predictions_kept = val_phats >= lhat\n",
    "lhat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nctid: NCT02565628 | Probability: 0.5692 | Prediction: 1 | Label: 0 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.5064 | Prediction: 1 | Label: 0 | Abstained: True\n",
      "nctid: NCT02565628 | Probability: 0.3791 | Prediction: 0 | Label: 1 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.6734 | Prediction: 1 | Label: 1 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.3983 | Prediction: 0 | Label: 0 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.5559 | Prediction: 1 | Label: 1 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.4414 | Prediction: 0 | Label: 0 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.6434 | Prediction: 1 | Label: 0 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.3745 | Prediction: 0 | Label: 0 | Abstained: False\n",
      "nctid: NCT02565628 | Probability: 0.3876 | Prediction: 0 | Label: 0 | Abstained: False\n"
     ]
    }
   ],
   "source": [
    "# For ten randomly drawn test trials, print the model score, the hard prediction, the true label,\n",
    "# and whether the selective classifier abstains at threshold lhat.\n",
    "test_csv = pd.read_csv(test_file, delimiter=',')\n",
    "nctids = list(test_csv['nctid'])\n",
    "# Scores and labels are taken from the test predictions so that they belong to the same split as\n",
    "# the ids (the global smx/labels may have been built from another split).\n",
    "test_scores = np.asarray(test_predict)\n",
    "test_truth = np.asarray(test_label)\n",
    "# Rows are matched by position. test_loader uses shuffle=False; this assumes model.test returns\n",
    "# one prediction per CSV row in loader order - TODO confirm.\n",
    "assert len(nctids) == len(test_scores), \"test CSV rows and test predictions are not aligned\"\n",
    "\n",
    "# Draw row positions rather than ids: no list search, and the global split mask `idx` is not overwritten.\n",
    "for row in np.random.choice(len(nctids), size=10):\n",
    "    score = test_scores[row]\n",
    "    confidence = max(score, 1 - score)  # score of the predicted class\n",
    "    prediction = int(score > 0.5)\n",
    "    abstained = bool(confidence < lhat)\n",
    "    # Print the id of THIS row (the old code printed the last sampled id on every line).\n",
    "    print(f\"nctid: {nctids[row]} | Probability: {score:.4f} | Prediction: {prediction} | Label: {test_truth[row]} | Abstained: {abstained}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def selective_risk2(lam): \n",
    "    \"\"\"Error rate among validation rows whose confidence is >= lam.\"\"\"\n",
    "    kept = val_phats >= lam\n",
    "    n_wrong = (val_yhats[kept] != val_labels[kept]).sum()\n",
    "    return n_wrong/kept.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 500x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Sweep the confidence threshold over the validation split and plot, for every threshold,\n",
    "# the accuracy on the kept rows and the fraction of rows kept.\n",
    "lambdas = np.linspace(0.3,0.7,1001)\n",
    "sns.set_theme(style='white')\n",
    "selective_accuracy_curve = [1-selective_risk2(lam) for lam in lambdas]\n",
    "fraction_kept_curve = [(val_phats >= lam).mean() for lam in lambdas]\n",
    "\n",
    "fig, axs = plt.subplots(1,1,figsize=(5,3))\n",
    "curves = (\n",
    "    ('accuracy', selective_accuracy_curve),\n",
    "    ('fraction kept', fraction_kept_curve),\n",
    ")\n",
    "for curve_label, curve in curves:\n",
    "    axs.plot(lambdas, curve, label=curve_label, linewidth=2)\n",
    "# Reference lines: the selected threshold (vertical) and the 1 - alpha level (horizontal)\n",
    "axs.axvline(x=lhat, color='gray', linestyle='--', linewidth=1.5, label=r'$\\hat{\\lambda}$')\n",
    "axs.axhline(y=1-alpha, color='gray', linestyle='dotted', linewidth=1.5, label=r'$1-\\alpha$')\n",
    "sns.despine(ax=axs,top=True,right=True)\n",
    "axs.legend(loc='lower left')\n",
    "axs.set_xlabel(r'$\\lambda$')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 600x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare the model without abstention (\"origin\") with the selective classifier (\"cp\") on the\n",
    "# single calibration/validation split and the threshold lhat selected above.\n",
    "\n",
    "# Baseline metrics are computed from the current test predictions rather than hard-coded,\n",
    "# so they stay correct when the model is retrained or base_name is changed.\n",
    "test_scores = np.asarray(test_predict)\n",
    "origin = [\n",
    "    average_precision_score(test_label, test_scores),\n",
    "    f1_score(test_label, test_predict_label),\n",
    "    roc_auc_score(test_label, test_scores),\n",
    "    accuracy_score(test_predict_label, test_label),\n",
    "]\n",
    "\n",
    "# Metrics on the validation rows that are kept at threshold lhat. The old code reused variables\n",
    "# left over from the last iteration of the repeated-splits loop, which belong to another split.\n",
    "kept_labels = val_labels[predictions_kept]\n",
    "kept_preds = val_yhats[predictions_kept]\n",
    "kept_scores = val_smx[predictions_kept, 1]\n",
    "cp = [\n",
    "    average_precision_score(kept_labels, kept_scores),\n",
    "    f1_score(kept_labels, kept_preds),\n",
    "    roc_auc_score(kept_labels, kept_scores),\n",
    "    accuracy_score(kept_labels, kept_preds),\n",
    "]\n",
    "\n",
    "# X axis\n",
    "X_labels = [\"PR_AUC\", \"F1\", \"ROC_AUC\", \"Accuracy\"]\n",
    "\n",
    "# Long-format frame: one row per (metric, variant) pair.\n",
    "data = {\n",
    "    'Metric': X_labels * 2,  \n",
    "    'Value': origin + cp,  \n",
    "    'Type': ['origin']*4 + ['cp']*4  \n",
    "}\n",
    "\n",
    "df = pd.DataFrame(data)\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "ax = sns.barplot(x='Metric', y='Value', hue='Type', data=df)\n",
    "\n",
    "# Write each bar's value just above it.\n",
    "for p in ax.patches:\n",
    "    ax.annotate(f'{p.get_height():.4f}', \n",
    "                (p.get_x() + p.get_width() / 2., p.get_height()), \n",
    "                ha = 'center', \n",
    "                va = 'center', \n",
    "                xytext = (0, 9), \n",
    "                textcoords = 'offset points',\n",
    "                fontsize = 10)\n",
    "ax.legend(loc='upper left', bbox_to_anchor=(1, 1))\n",
    "# Derive the title from the configured phase (it used to say 'Phase III' regardless of base_name).\n",
    "plt.title(base_name.replace('phase_', 'Phase '))\n",
    "plt.ylabel('Value')\n",
    "plt.xlabel('Metric')\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "# Show the figure\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "d2l",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}