{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: HINT for approval prediction on a toy example\n",
    "\n",
    "In this tutorial, we show step by step how to predict clinical trial approval with HINT. \n",
    "\n",
    "Agenda:\n",
    "\n",
    "- Part 1: Import modules\n",
    "- Part 2: Specify task\n",
    "- Part 3: Pretrain\n",
    "- Part 4: Data loader\n",
    "- Part 5: Raw data embedding\n",
    "- Part 6: Learn and inference\n",
    "\n",
    "Let's start!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import modules\n",
    "\n",
    "This step imports: \n",
    "\n",
    "- Standard modules, e.g., ```PyTorch```. \n",
    "- Self-defined modules, e.g., ```dataloader```, ```model```. \n",
    "\n",
    "We also specify the device (**CPU or GPU**) using ```device = torch.device(\"cpu\")``` or ```device = torch.device(\"cuda:0\")```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)  # fix the random seed for reproducibility\n",
    "warnings.filterwarnings(\"ignore\")  # suppress library warnings\n",
    "\n",
    "from HINT.dataloader import csv_three_feature_2_dataloader, generate_admet_dataloader_lst\n",
    "from HINT.molecule_encode import MPNN, ADMET\n",
    "from HINT.icdcode_encode import GRAM, build_icdcode2ancestor_dict\n",
    "from HINT.protocol_encode import Protocol_Embedding\n",
    "from HINT.model import HINTModel\n",
    "\n",
    "device = torch.device(\"cpu\")  # or torch.device(\"cuda:0\") to use a GPU\n",
    "os.makedirs(\"figure\", exist_ok=True)  # create the figure folder if it is missing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Specify task\n",
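    "Choose the prediction task by setting ```base_name```. Besides the ```toy``` example used here, the available tasks are:  \n",
    "- Phase I prediction\n",
    "- Phase II prediction \n",
    "- Phase III prediction\n",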
    "- Indication prediction\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_name = 'toy'  # one of 'toy', 'phase_I', 'phase_II', 'phase_III', 'indication'\n",
    "datafolder = \"data\"\n",
    "train_file = os.path.join(datafolder, base_name + '_train.csv')\n",
    "valid_file = os.path.join(datafolder, base_name + '_valid.csv')\n",
    "test_file = os.path.join(datafolder, base_name + '_test.csv')"
   ]
  },
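  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The file-naming convention above can be wrapped in a small helper that also rejects unknown task names. This is a minimal sketch and not part of HINT: ```split_paths``` and ```VALID_TASKS``` are hypothetical names, and the task list is copied from the comment in the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Task names copied from the comment in the previous cell (hypothetical constant).\n",
    "VALID_TASKS = ('toy', 'phase_I', 'phase_II', 'phase_III', 'indication')\n",
    "\n",
    "def split_paths(base_name, datafolder='data'):\n",
    "    # Hypothetical helper: return the train/valid/test CSV paths for one task.\n",
    "    if base_name not in VALID_TASKS:\n",
    "        raise ValueError('unknown task: ' + repr(base_name))\n",
    "    return {split: os.path.join(datafolder, base_name + '_' + split + '.csv')\n",
    "            for split in ('train', 'valid', 'test')}\n",
    "\n",
    "paths = split_paths('toy')\n",
    "print(paths['train'])"
   ]
  },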
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Pretrain\n",
    "\n",
    "The molecule encoder is pretrained on ADMET prediction tasks. If a pretrained checkpoint already exists, it is loaded instead of being retrained. ADMET stands for \n",
    "- Absorption \n",
    "- Distribution\n",
    "- Metabolism\n",
    "- Excretion \n",
    "- Toxicity \n",
    "\n",
    "Predicting ADMET properties plays an important role in the drug design process because these properties account for the failure of about 60% of all drugs in the clinical phases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "mpnn_model = MPNN(mpnn_hidden_size = 50, mpnn_depth=3, device = device)\n",
    "admet_model_path = \"save_model/admet_model.ckpt\"\n",
    "if not os.path.exists(admet_model_path):\n",
    "\tadmet_dataloader_lst = generate_admet_dataloader_lst(batch_size=32)\n",
    "\tadmet_trainloader_lst = [i[0] for i in admet_dataloader_lst]\n",
    "\tadmet_testloader_lst = [i[1] for i in admet_dataloader_lst]\n",
    "\tadmet_model = ADMET(molecule_encoder = mpnn_model, \n",
    "\t\t\t\t\t\thighway_num=2, \n",
    "\t\t\t\t\t\tdevice = device, \n",
    "\t\t\t\t\t\tepoch=3, \n",
    "\t\t\t\t\t\tlr=5e-4, \n",
    "\t\t\t\t\t\tweight_decay=0, \n",
    "\t\t\t\t\t\tsave_name = 'admet_')\n",
    "\tadmet_model.train(admet_trainloader_lst, admet_testloader_lst)\n",
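    "\tos.makedirs(os.path.dirname(admet_model_path), exist_ok=True)  # make sure the checkpoint folder exists\n",
    "\ttorch.save(admet_model, admet_model_path)\n",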
    "else:\n",
    "\tadmet_model = torch.load(admet_model_path, map_location=device)  # remap tensors saved on another device\n",
    "\tadmet_model = admet_model.to(device)\n",
    "\tadmet_model.set_device(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Data loader\n",
    "\n",
    "Next, we define data loaders for the training, validation, and test sets. Only the training loader shuffles its data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = csv_three_feature_2_dataloader(train_file, shuffle=True, batch_size=32) \n",
    "valid_loader = csv_three_feature_2_dataloader(valid_file, shuffle=False, batch_size=32) \n",
    "test_loader = csv_three_feature_2_dataloader(test_file, shuffle=False, batch_size=32) \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Raw data embedding\n",
    "\n",
    "Next, we define the embedding modules for the raw inputs: a GRAM model for [ICD-10 codes](https://en.wikipedia.org/wiki/ICD-10), which represents the hierarchy of disease codes, and a protocol embedding for the trial protocol (eligibility criteria). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "icdcode2ancestor_dict = build_icdcode2ancestor_dict()\n",
    "gram_model = GRAM(embedding_dim = 50, icdcode2ancestor = icdcode2ancestor_dict, device = device)\n",
    "protocol_model = Protocol_Embedding(output_dim = 50, highway_num=3, device = device)\n"
   ]
  },
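  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the hierarchy of disease codes concrete: a code such as C34.10 refines C34.1, which in turn belongs to the category C34. The cell below recovers such ancestors by truncating the code. It is only an illustration under that assumption; ```icd10_ancestors``` is a hypothetical helper, and ```build_icdcode2ancestor_dict``` may construct its dictionary differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def icd10_ancestors(code):\n",
    "    # Hypothetical helper: list coarser codes, most specific first,\n",
    "    # stopping at the three-character category.\n",
    "    compact = code.replace('.', '')\n",
    "    ancestors = []\n",
    "    for end in range(len(compact) - 1, 2, -1):\n",
    "        prefix = compact[:end]\n",
    "        # Re-insert the dot after the three-character category.\n",
    "        ancestors.append(prefix if len(prefix) <= 3 else prefix[:3] + '.' + prefix[3:])\n",
    "    return ancestors\n",
    "\n",
    "print(icd10_ancestors('C34.10'))"
   ]
  },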
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Learn and Inference\n",
    "\n",
    "Finally, we train the model and run inference. The model that performs best on the validation set is saved. If a saved model already exists, it is loaded and inference is run directly. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hint_model_path = \"save_model/\" + base_name + \".ckpt\"\n",
    "if not os.path.exists(hint_model_path):\n",
    "\tmodel = HINTModel(molecule_encoder = mpnn_model, \n",
    "\t\t\t disease_encoder = gram_model, \n",
    "\t\t\t protocol_encoder = protocol_model,\n",
    "\t\t\t device = device, \n",
    "\t\t\t global_embed_size = 50, \n",
    "\t\t\t highway_num_layer = 2,\n",
    "\t\t\t prefix_name = base_name, \n",
    "\t\t\t gnn_hidden_size = 50,  \n",
    "\t\t\t epoch = 3,\n",
    "\t\t\t lr = 1e-3, \n",
    "\t\t\t weight_decay = 0, \n",
    "\t\t\t)\n",
    "\tmodel.init_pretrain(admet_model)\n",
    "\tmodel.learn(train_loader, valid_loader, test_loader)\n",
    "\tmodel.bootstrap_test(test_loader)\n",
    "\ttorch.save(model, hint_model_path)\n",
    "else:\n",
    "\tmodel = torch.load(hint_model_path, map_location=device)  # remap tensors saved on another device\n",
    "\tmodel.bootstrap_test(test_loader)"
   ]
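  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```bootstrap_test``` is called above without further explanation. A bootstrap evaluation generally resamples the test set with replacement several times and reports the mean and standard deviation of a metric across the resamples. The cell below is a generic, standard-library sketch of that idea. It is not HINT's implementation: ```bootstrap_metric```, ```accuracy```, ```n_rounds```, and the toy labels are all hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import statistics\n",
    "\n",
    "def bootstrap_metric(labels, predictions, metric, n_rounds=20, seed=0):\n",
    "    # Resample (label, prediction) pairs with replacement and\n",
    "    # recompute the metric on each resample.\n",
    "    rng = random.Random(seed)\n",
    "    pairs = list(zip(labels, predictions))\n",
    "    scores = []\n",
    "    for _ in range(n_rounds):\n",
    "        sample = [rng.choice(pairs) for _ in range(len(pairs))]\n",
    "        ys, ps = zip(*sample)\n",
    "        scores.append(metric(ys, ps))\n",
    "    return statistics.mean(scores), statistics.pstdev(scores)\n",
    "\n",
    "def accuracy(ys, ps):\n",
    "    # Fraction of predictions that match their labels.\n",
    "    return sum(int(y == p) for y, p in zip(ys, ps)) / len(ys)\n",
    "\n",
    "# Toy data, made up for illustration only.\n",
    "labels = [1, 0, 1, 1, 0, 1, 0, 0]\n",
    "predictions = [1, 0, 0, 1, 0, 1, 1, 0]\n",
    "mean_acc, std_acc = bootstrap_metric(labels, predictions, accuracy)\n",
    "print('accuracy: %.3f +/- %.3f' % (mean_acc, std_acc))"
   ]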
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}