Diff of /task/NextXVisit.ipynb [000000] .. [bad60c]

--- a
+++ b/task/NextXVisit.ipynb
@@ -0,0 +1,496 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys \n",
+    "sys.path.insert(0, '../')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader\n",
+    "import pandas as pd\n",
+    "from common.common import create_folder,H5Recorder\n",
+    "import numpy as np\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "import os\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "\n",
+    "from model import optimiser\n",
+    "import sklearn.metrics as skm\n",
+    "import math\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import time\n",
+    "from sklearn.metrics import roc_auc_score\n",
+    "from common.common import load_obj\n",
+    "from model.utils import age_vocab\n",
+    "from dataLoader.NextXVisit import NextVisit\n",
+    "from model.NextXVisit import BertForMultiLabelPrediction\n",
+    "import warnings\n",
+    "warnings.filterwarnings(action='ignore')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_config = {\n",
+    "    'vocab': '', # token2idx idx2token\n",
+    "    'train': '',\n",
+    "    'test': '',\n",
+    "}\n",
+    "\n",
+    "optim_config = {\n",
+    "    'lr': 3e-5,\n",
+    "    'warmup_proportion': 0.1,\n",
+    "    'weight_decay': 0.01\n",
+    "}\n",
+    "\n",
+    "global_params = {\n",
+    "    'batch_size': 256,\n",
+    "    'gradient_accumulation_steps': 1,\n",
+    "    'device': 'cuda:0',\n",
+    "    'output_dir': '', # output folder\n",
+    "    'best_name': '',  # output model name\n",
+    "    'max_len_seq': 100,\n",
+    "    'max_age': 110,\n",
+    "    'age_year': False,\n",
+    "    'age_symbol': None,\n",
+    "    'min_visit': 5\n",
+    "}\n",
+    "\n",
+    "pretrain_model_path = ''  # pretrained MLM path"
+   ]
+  },
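+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An optional early guard (added here as a convenience, not part of the original pipeline): fail fast if the file paths above are still empty."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional guard: every entry in file_config must be filled in before the\n",
+    "# data-loading cells below will work, so fail early with a clear message.\n",
+    "for _key, _path in file_config.items():\n",
+    "    assert _path, 'file_config[{!r}] must be set before running'.format(_key)"
+   ]
+  },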
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BertVocab = load_obj(file_config['vocab'])\n",
+    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], symbol=global_params['age_symbol'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# re-format label token\n",
+    "def format_label_vocab(token2idx):\n",
+    "    token2idx = token2idx.copy()\n",
+    "    del token2idx['PAD']\n",
+    "    del token2idx['SEP']\n",
+    "    del token2idx['CLS']\n",
+    "    del token2idx['MASK']\n",
+    "    token = list(token2idx.keys())\n",
+    "    labelVocab = {}\n",
+    "    for i,x in enumerate(token):\n",
+    "        labelVocab[x] = i\n",
+    "    return labelVocab\n",
+    "\n",
+    "labelVocab = format_label_vocab(BertVocab['token2idx'])"
+   ]
+  },
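+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check (added for illustration): the label vocabulary should be exactly the BERT vocabulary minus the four special tokens."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check: labelVocab drops PAD/SEP/CLS/MASK, so it should be\n",
+    "# four entries smaller than the full BERT token vocabulary.\n",
+    "print(len(BertVocab['token2idx']), len(labelVocab))\n",
+    "assert len(labelVocab) == len(BertVocab['token2idx']) - 4"
+   ]
+  },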
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_config = {\n",
+    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
+    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
+    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
+    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
+    "    'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
+    "    'hidden_dropout_prob': 0.1, # dropout rate\n",
+    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
+    "    'num_attention_heads': 12, # number of attention heads\n",
+    "    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n",
+    "    'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
+    "    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
+    "    'initializer_range': 0.02, # parameter weight initializer range\n",
+    "}\n",
+    "\n",
+    "feature_dict = {\n",
+    "    'word':True,\n",
+    "    'seg':True,\n",
+    "    'age':True,\n",
+    "    'position': True\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BertConfig(Bert.modeling.BertConfig):\n",
+    "    def __init__(self, config):\n",
+    "        super(BertConfig, self).__init__(\n",
+    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
+    "            hidden_size=config['hidden_size'],\n",
+    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
+    "            num_attention_heads=config.get('num_attention_heads'),\n",
+    "            intermediate_size=config.get('intermediate_size'),\n",
+    "            hidden_act=config.get('hidden_act'),\n",
+    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
+    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
+    "            max_position_embeddings = config.get('max_position_embedding'),\n",
+    "            initializer_range=config.get('initializer_range'),\n",
+    "        )\n",
+    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
+    "        self.age_vocab_size = config.get('age_vocab_size')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train = pd.read_parquet(file_config['train'])\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=train, max_len=global_params['max_len_seq'])\n",
+    "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
+   ]
+  },
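+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A hypothetical shape check (requires file_config['train'] to be set): pull one batch and confirm the token tensors come out as (batch_size, max_len_seq)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical sanity check: inspect one batch from the train loader.\n",
+    "_age, _inp, _pos, _seg, _mask, _tgt, _ = next(iter(trainload))\n",
+    "print(_inp.shape)  # expected: (batch_size, max_len_seq)"
+   ]
+  },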
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test = pd.read_parquet(file_config['test'])\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=test, max_len=global_params['max_len_seq'])\n",
+    "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# del model\n",
+    "conf = BertConfig(model_config)\n",
+    "model = BertForMultiLabelPrediction(conf, num_labels=len(labelVocab.keys()), feature_dict=feature_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def load_model(path, model):\n",
+    "    # load pretrained model and update weights\n",
+    "    pretrained_dict = torch.load(path)\n",
+    "    model_dict = model.state_dict()\n",
+    "    # 1. filter out unnecessary keys\n",
+    "    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
+    "    # 2. overwrite entries in the existing state dict\n",
+    "    model_dict.update(pretrained_dict)\n",
+    "    # 3. load the new state dict\n",
+    "    model.load_state_dict(model_dict)\n",
+    "    return model\n",
+    "\n",
+    "mode = load_model(pretrain_model_path, model)"
+   ]
+  },
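+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "An optional diagnostic (assumes pretrain_model_path points at a valid checkpoint): report how many parameter tensors the pretrained MLM shares with this fine-tuning model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional diagnostic: count how many checkpoint tensors match the model's\n",
+    "# state dict by name (the rest are dropped by load_model above).\n",
+    "if pretrain_model_path:\n",
+    "    _ckpt = torch.load(pretrain_model_path)\n",
+    "    _shared = [k for k in _ckpt if k in model.state_dict()]\n",
+    "    print('{} / {} tensors taken from the pretrained MLM'.format(len(_shared), len(model.state_dict())))"
+   ]
+  },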
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = model.to(global_params['device'])\n",
+    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sklearn\n",
+    "def precision(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output=sig(logits)\n",
+    "    label, output=label.cpu(), output.detach().cpu()\n",
+    "    tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
+    "    return tempprc, output, label\n",
+    "\n",
+    "def precision_test(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output=sig(logits)\n",
+    "    tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
+    "    roc = sklearn.metrics.roc_auc_score(label.numpy(),output.numpy(), average='samples')\n",
+    "    return tempprc, roc, output, label,"
+   ]
+  },
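+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A minimal sketch of the metric helpers on dummy tensors (hypothetical values, not real data): average_precision_score with average='samples' scores each row's ranking of its positive labels, then averages over rows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical example: 4 samples, 10 labels, 3 positives per sample so\n",
+    "# every row has at least one positive and per-sample AP is well defined.\n",
+    "_logits = torch.randn(4, 10)\n",
+    "_labels = torch.zeros(4, 10)\n",
+    "_labels[:, :3] = 1.0\n",
+    "_prc, _, _ = precision(_logits, _labels)\n",
+    "print(_prc)"
+   ]
+  },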
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.preprocessing import MultiLabelBinarizer\n",
+    "mlb = MultiLabelBinarizer(classes=list(labelVocab.values()))\n",
+    "mlb.fit([[each] for each in list(labelVocab.values())])"
+   ]
+  },
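+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A small illustration (hypothetical indices): mlb.transform turns each sample's list of label indices into a multi-hot row over the full label vocabulary, which is the target format the multi-label loss expects."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hypothetical example: two samples with label indices {0, 2} and {1}\n",
+    "# become multi-hot rows of length len(labelVocab).\n",
+    "_demo = mlb.transform([[0, 2], [1]])\n",
+    "print(_demo.shape, _demo[:, :5])"
+   ]
+  },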
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(e):\n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    cnt = 0\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "        cnt +=1\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "        \n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "\n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "        \n",
+    "        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
+    "        \n",
+    "        if global_params['gradient_accumulation_steps'] >1:\n",
+    "            loss = loss/global_params['gradient_accumulation_steps']\n",
+    "        loss.backward()\n",
+    "        \n",
+    "        temp_loss += loss.item()\n",
+    "        tr_loss += loss.item()\n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "        \n",
+    "        if step % 500==0:\n",
+    "            prec, a, b = precision(logits, targets)\n",
+    "            print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, cnt,temp_loss/500, prec))\n",
+    "            temp_loss = 0\n",
+    "        \n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "\n",
+    "def evaluation():\n",
+    "    model.eval()\n",
+    "    y = []\n",
+    "    y_label = []\n",
+    "    tr_loss = 0\n",
+    "    for step, batch in enumerate(testload):\n",
+    "        model.eval()\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "        \n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "        \n",
+    "        with torch.no_grad():\n",
+    "            loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
+    "        logits = logits.cpu()\n",
+    "        targets = targets.cpu()\n",
+    "        \n",
+    "        tr_loss += loss.item()\n",
+    "\n",
+    "        y_label.append(targets)\n",
+    "        y.append(logits)\n",
+    "\n",
+    "    y_label = torch.cat(y_label, dim=0)\n",
+    "    y = torch.cat(y, dim=0)\n",
+    "\n",
+    "    aps, roc, output, label = precision_test(y, y_label)\n",
+    "    return aps, roc, tr_loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "best_pre = 0.0\n",
+    "for e in range(50):\n",
+    "    train(e)\n",
+    "    aps, roc, test_loss = evaluation()\n",
+    "    if aps >best_pre:\n",
+    "        # Save a trained model\n",
+    "        print(\"** ** * Saving fine - tuned model ** ** * \")\n",
+    "        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
+    "        output_model_file = os.path.join(global_params['output_dir'],global_params['best_name'])\n",
+    "        create_folder(global_params['output_dir'])\n",
+    "\n",
+    "        torch.save(model_to_save.state_dict(), output_model_file)\n",
+    "        best_pre = aps\n",
+    "    print('aps : {}'.format(aps))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}