[7e250a]: / src / hint / notebooks / ablations.ipynb

Download this file

289 lines (288 with data), 44.9 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# %reload_ext loads the extension when it is absent and reloads it otherwise,\n",
    "# so re-running this cell does not print the 'already loaded' notice.\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ../"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "\n",
    "# Seed and silence warnings before the HINT imports (same order as before).\n",
    "torch.manual_seed(0)\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst\n",
    "from HINT.molecule_encode import MPNN, ADMET\n",
    "from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict\n",
    "from HINT.protocol.model import ProtocolEmbedding\n",
    "from HINT.model import HINTModel\n",
    "\n",
    "# Use the first GPU when one is visible; otherwise fall back to the CPU so the\n",
    "# notebook still runs on machines without CUDA.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# exist_ok makes this a no-op when the directory is already there.\n",
    "os.makedirs(\"figure\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per ablation run: \"name\" labels the run and \"config\" holds the\n",
    "# single flag that run switches on.\n",
    "ablations = [\n",
    "    {\"name\": run_name, \"config\": {flag: True}}\n",
    "    for run_name, flag in (\n",
    "        (\"disease_pred\", \"disease_embedding\"),\n",
    "        (\"molecule_pred\", \"molecule_embedding\"),\n",
    "        (\"protocol_pred\", \"protocol_embedding\"),\n",
    "        (\"interaction_pred\", \"interaction_embedding\"),\n",
    "        (\"disease_risk_pred\", \"disease_risk_embedding\"),\n",
    "        (\"augmented_interaction_pred\", \"augmented_interaction_embedding\"),\n",
    "        (\"pharmacokinetics_pred\", \"pk_embedding\"),\n",
    "        (\"trial_pred_nn\", \"trial_embedding\"),\n",
    "        (\"vanilla\", \"base_model\"),\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebinds `ablations` to a single run (disease_pred), replacing any list\n",
    "# assigned to that name earlier.\n",
    "ablations = [{\"name\": \"disease_pred\", \"config\": {\"disease_embedding\": True}}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "phase_name = 'phase_III'\n",
    "model_name = 'icd_protocol'\n",
    "datafolder = \"data\"\n",
    "\n",
    "# Split files are named <phase_name>_<split>.csv inside `datafolder`.\n",
    "train_file, valid_file, test_file = (\n",
    "    os.path.join(datafolder, f\"{phase_name}_{split}.csv\")\n",
    "    for split in (\"train\", \"valid\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpnn_model = MPNN(mpnn_hidden_size=50, mpnn_depth=3, device=device)\n",
    "admet_model_path = \"save_model/admet_model.ckpt\"\n",
    "\n",
    "if os.path.exists(admet_model_path):\n",
    "    # Reuse the cached model. map_location remaps tensors that were saved on a\n",
    "    # different device onto `device` while loading.\n",
    "    admet_model = torch.load(admet_model_path, map_location=device)\n",
    "    admet_model = admet_model.to(device)\n",
    "    admet_model.set_device(device)\n",
    "else:\n",
    "    # torch.save does not create missing parent directories, so make sure the\n",
    "    # checkpoint folder exists before training and saving.\n",
    "    os.makedirs(os.path.dirname(admet_model_path), exist_ok=True)\n",
    "\n",
    "    admet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)\n",
    "    # Each element is indexed as [0] -> train loader, [1] -> test loader.\n",
    "    admet_trainloader_lst = [loaders[0] for loaders in admet_dataloader_lst]\n",
    "    admet_testloader_lst = [loaders[1] for loaders in admet_dataloader_lst]\n",
    "\n",
    "    admet_model = ADMET(\n",
    "        molecule_encoder=mpnn_model,\n",
    "        highway_num=2,\n",
    "        device=device,\n",
    "        epoch=3,\n",
    "        lr=5e-4,\n",
    "        weight_decay=0,\n",
    "        save_name='admet_',\n",
    "    )\n",
    "    admet_model.train(admet_trainloader_lst, admet_testloader_lst)\n",
    "    torch.save(admet_model, admet_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the three loaders in train/valid/test order; only the training split\n",
    "# is shuffled.\n",
    "train_loader, valid_loader, test_loader = (\n",
    "    csv_three_feature_2_dataloader(csv_path, shuffle=do_shuffle, batch_size=32)\n",
    "    for csv_path, do_shuffle in (\n",
    "        (train_file, True),\n",
    "        (valid_file, False),\n",
    "        (test_file, False),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup handed to GRAM through its `icdcode2ancestor` argument.\n",
    "icdcode2ancestor_dict = build_icdcode2ancestor_dict()\n",
    "\n",
    "gram_model = GRAM(\n",
    "    embedding_dim=50,\n",
    "    icdcode2ancestor=icdcode2ancestor_dict,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "protocol_model = ProtocolEmbedding(\n",
    "    output_dim=50,\n",
    "    highway_num=3,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running ablation disease_pred...\n",
      "PR-AUC   mean: 0.7714 std: 0.0118\n",
      "F1       mean: 0.8320 std: 0.0081\n",
      "ROC-AUC  mean: 0.6719 std: 0.0158\n",
      "Accuracy mean: 0.7740 std: 0.0120\n"
     ]
    },
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Train (or load from cache) one HINT model per ablation and evaluate it on the\n",
    "# test split. `nctid_all` / `predict_all` are rebound on every iteration, so\n",
    "# after the loop they hold the results of the last ablation only.\n",
    "for ablation in ablations:\n",
    "    hint_model_path = f\"save_model/{ablation['name']}.ckpt\"\n",
    "    print(f\"Running ablation {ablation['name']}...\")\n",
    "\n",
    "    if os.path.exists(hint_model_path):\n",
    "        # map_location remaps tensors saved on a different device onto `device`.\n",
    "        model = torch.load(hint_model_path, map_location=device)\n",
    "        nctid_all, predict_all = model.bootstrap_test(test_loader)\n",
    "    else:\n",
    "        model = HINTModel(\n",
    "            molecule_encoder=mpnn_model,\n",
    "            disease_encoder=gram_model,\n",
    "            protocol_encoder=protocol_model,\n",
    "            device=device,\n",
    "            global_embed_size=50,\n",
    "            highway_num_layer=2,\n",
    "            prefix_name=model_name,\n",
    "            gnn_hidden_size=50,\n",
    "            epoch=3,\n",
    "            lr=1e-3,\n",
    "            weight_decay=0,\n",
    "            ablations=ablation,\n",
    "        )\n",
    "        model.init_pretrain(admet_model)\n",
    "        model.learn(train_loader, valid_loader, test_loader)\n",
    "        # Keep the evaluation result here as well: previously it was discarded in\n",
    "        # this branch, leaving `predict_all` undefined for later cells on a fresh\n",
    "        # run with no cached checkpoint.\n",
    "        nctid_all, predict_all = model.bootstrap_test(test_loader, sample_num=50)\n",
    "        # torch.save does not create missing parent directories.\n",
    "        os.makedirs(os.path.dirname(hint_model_path), exist_ok=True)\n",
    "        torch.save(model, hint_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Counter({1: 996, 0: 150})\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Binarise at 0.5: scores strictly below the threshold map to 0, everything\n",
    "# else to 1.\n",
    "predictions = [int(not pred < 0.5) for pred in predict_all]\n",
    "prediction_counts = Counter(predictions)\n",
    "\n",
    "print(prediction_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}