[bad60c]: / task / NextXVisit.ipynb

Download this file

388 lines (387 with data), 13.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.insert(0, '../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "import pandas as pd\n",
    "from common.common import create_folder,H5Recorder\n",
    "import numpy as np\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_pretrained_bert as Bert\n",
    "\n",
    "from model import optimiser\n",
    "import sklearn.metrics as skm\n",
    "import math\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from common.common import load_obj\n",
    "from model.utils import age_vocab\n",
    "from dataLoader.NextXVisit import NextVisit\n",
    "from model.NextXVisit import BertForMultiLabelPrediction\n",
    "import warnings\n",
    "warnings.filterwarnings(action='ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input artefacts. Every path is blank in this notebook and has to be set before running.\n",
    "file_config = {\n",
    "    'vocab': '', # token2idx idx2token\n",
    "    # read with pd.read_parquet to build the training DataLoader\n",
    "    'train': '',\n",
    "    # read with pd.read_parquet to build the evaluation DataLoader\n",
    "    'test': '',\n",
    "}\n",
    "\n",
    "# Passed as `config` to optimiser.adam when the optimiser is created.\n",
    "optim_config = {\n",
    "    'lr': 3e-5,\n",
    "    'warmup_proportion': 0.1,\n",
    "    'weight_decay': 0.01\n",
    "}\n",
    "\n",
    "# Run-wide settings shared by the data loaders, the model config and the training loop.\n",
    "# NOTE(review): 'age_year' and 'min_visit' are not read anywhere in this notebook - confirm whether they are still needed.\n",
    "global_params = {\n",
    "    'batch_size': 256,\n",
    "    'gradient_accumulation_steps': 1,\n",
    "    'device': 'cuda:0',\n",
    "    'output_dir': '', # output folder\n",
    "    'best_name': '',  # output model name\n",
    "    # used both as the dataset max_len and as the model's max_position_embedding\n",
    "    'max_len_seq': 100,\n",
    "    # 'max_age' and 'age_symbol' are forwarded to age_vocab\n",
    "    'max_age': 110,\n",
    "    'age_year': False,\n",
    "    'age_symbol': None,\n",
    "    'min_visit': 5\n",
    "}\n",
    "\n",
    "pretrain_model_path = ''  # pretrained MLM path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BertVocab = load_obj(file_config['vocab'])\n",
    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], symbol=global_params['age_symbol'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# re-format label token\n",
    "def format_label_vocab(token2idx):\n",
    "    token2idx = token2idx.copy()\n",
    "    del token2idx['PAD']\n",
    "    del token2idx['SEP']\n",
    "    del token2idx['CLS']\n",
    "    del token2idx['MASK']\n",
    "    token = list(token2idx.keys())\n",
    "    labelVocab = {}\n",
    "    for i,x in enumerate(token):\n",
    "        labelVocab[x] = i\n",
    "    return labelVocab\n",
    "\n",
    "labelVocab = format_label_vocab(BertVocab['token2idx'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = {\n",
    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
    "    'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
    "    'hidden_dropout_prob': 0.1, # dropout rate\n",
    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
    "    'num_attention_heads': 12, # number of attention heads\n",
    "    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n",
    "    'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
    "    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
    "    'initializer_range': 0.02, # parameter weight initializer range\n",
    "}\n",
    "\n",
    "feature_dict = {\n",
    "    'word':True,\n",
    "    'seg':True,\n",
    "    'age':True,\n",
    "    'position': True\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertConfig(Bert.modeling.BertConfig):\n",
    "    \"\"\"BertConfig built from a plain dict, extended with segment and age vocabulary sizes.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        \"\"\"Forward the dict entries to the parent constructor and store the extra sizes.\n",
    "\n",
    "        Args:\n",
    "            config: dict of hyper-parameters. 'hidden_size' is mandatory (KeyError if\n",
    "                absent); every other key is read with .get, so a missing key is\n",
    "                passed on as None.\n",
    "        \"\"\"\n",
    "        base_kwargs = {\n",
    "            'vocab_size_or_config_json_file': config.get('vocab_size'),\n",
    "            'hidden_size': config['hidden_size'],\n",
    "            'num_hidden_layers': config.get('num_hidden_layers'),\n",
    "            'num_attention_heads': config.get('num_attention_heads'),\n",
    "            'intermediate_size': config.get('intermediate_size'),\n",
    "            'hidden_act': config.get('hidden_act'),\n",
    "            'hidden_dropout_prob': config.get('hidden_dropout_prob'),\n",
    "            'attention_probs_dropout_prob': config.get('attention_probs_dropout_prob'),\n",
    "            # the notebook key is singular, the parent keyword is plural\n",
    "            'max_position_embeddings': config.get('max_position_embedding'),\n",
    "            'initializer_range': config.get('initializer_range'),\n",
    "        }\n",
    "        super().__init__(**base_kwargs)\n",
    "        # attributes that are set here rather than by the parent constructor\n",
    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
    "        self.age_vocab_size = config.get('age_vocab_size')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split: parquet file -> NextVisit dataset -> shuffled DataLoader.\n",
    "# NOTE(review): the name `train` is rebound to the training function by the later `def train(e)` cell,\n",
    "# and `Dset` is reassigned by the test-split cell; only `trainload` is used afterwards.\n",
    "train = pd.read_parquet(file_config['train'])\n",
    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=train, max_len=global_params['max_len_seq'])\n",
    "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation split: same construction as the training split, but the loader is not shuffled.\n",
    "# NOTE(review): this overwrites the `Dset` created for the training split.\n",
    "test = pd.read_parquet(file_config['test'])\n",
    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=test, max_len=global_params['max_len_seq'])\n",
    "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# del model\n",
    "conf = BertConfig(model_config)\n",
    "model = BertForMultiLabelPrediction(conf, num_labels=len(labelVocab.keys()), feature_dict=feature_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(path, model):\n",
    "    # load pretrained model and update weights\n",
    "    pretrained_dict = torch.load(path)\n",
    "    model_dict = model.state_dict()\n",
    "    # 1. filter out unnecessary keys\n",
    "    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
    "    # 2. overwrite entries in the existing state dict\n",
    "    model_dict.update(pretrained_dict)\n",
    "    # 3. load the new state dict\n",
    "    model.load_state_dict(model_dict)\n",
    "    return model\n",
    "\n",
    "# load the pretrained weights into the classifier and keep the result under `model`\n",
    "model = load_model(pretrain_model_path, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(global_params['device'])\n",
    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn\n",
    "def precision(logits, label):\n",
    "    \"\"\"Sample-averaged average precision of one batch.\n",
    "\n",
    "    Args:\n",
    "        logits: raw model scores; a sigmoid is applied here.\n",
    "        label: indicator targets aligned with ``logits``.\n",
    "\n",
    "    Returns:\n",
    "        (score, probabilities, labels) with both tensors detached/moved to the CPU.\n",
    "    \"\"\"\n",
    "    probs = torch.sigmoid(logits).detach().cpu()\n",
    "    label_cpu = label.cpu()\n",
    "    aps = sklearn.metrics.average_precision_score(label_cpu.numpy(), probs.numpy(), average='samples')\n",
    "    return aps, probs, label_cpu\n",
    "\n",
    "def precision_test(logits, label):\n",
    "    \"\"\"Sample-averaged average precision and ROC AUC for a set of predictions.\n",
    "\n",
    "    Args:\n",
    "        logits: raw model scores; a sigmoid is applied here.\n",
    "        label: indicator targets aligned with ``logits``.\n",
    "\n",
    "    Returns:\n",
    "        (average_precision, roc_auc, probabilities, labels); both tensors are\n",
    "        returned detached and on the CPU.\n",
    "    \"\"\"\n",
    "    # Detach and move to the CPU before .numpy(), as precision() does, so the function\n",
    "    # also accepts CUDA or grad-tracking tensors instead of raising.\n",
    "    output = torch.sigmoid(logits).detach().cpu()\n",
    "    label = label.detach().cpu()\n",
    "    y_true, y_score = label.numpy(), output.numpy()\n",
    "    tempprc = sklearn.metrics.average_precision_score(y_true, y_score, average='samples')\n",
    "    roc = sklearn.metrics.roc_auc_score(y_true, y_score, average='samples')\n",
    "    return tempprc, roc, output, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "# Binarizer with a fixed class list: the label indices of labelVocab, in that order.\n",
    "# train() and evaluation() call mlb.transform on each batch of targets and feed the result to the model as float labels.\n",
    "mlb = MultiLabelBinarizer(classes=list(labelVocab.values()))\n",
    "mlb.fit([[each] for each in list(labelVocab.values())])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(e):\n",
    "    \"\"\"Run one training epoch over ``trainload`` and print progress.\n",
    "\n",
    "    Uses the notebook globals ``model``, ``optim``, ``trainload``, ``mlb`` and\n",
    "    ``global_params``. Every 500 steps (step 0 included) it prints the mean loss of\n",
    "    the steps since the previous report and the sample-averaged precision of the\n",
    "    current batch.\n",
    "\n",
    "    Args:\n",
    "        e: epoch index; only used in the progress message.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    device = global_params['device']\n",
    "    accum_steps = global_params['gradient_accumulation_steps']\n",
    "    log_every = 500\n",
    "    running_loss = 0.0\n",
    "    steps_since_log = 0\n",
    "    for step, batch in enumerate(trainload):\n",
    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
    "\n",
    "        # binarize the batch of target ids into a float indicator matrix\n",
    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
    "\n",
    "        age_ids = age_ids.to(device)\n",
    "        input_ids = input_ids.to(device)\n",
    "        posi_ids = posi_ids.to(device)\n",
    "        segment_ids = segment_ids.to(device)\n",
    "        attMask = attMask.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
    "\n",
    "        # scale the loss so gradients summed over the accumulation window are averaged\n",
    "        if accum_steps > 1:\n",
    "            loss = loss / accum_steps\n",
    "        loss.backward()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        steps_since_log += 1\n",
    "\n",
    "        if step % log_every == 0:\n",
    "            prec, _, _ = precision(logits, targets)\n",
    "            # Average over the steps actually accumulated: the first report covers a\n",
    "            # single step, so dividing by a fixed 500 would understate it 500-fold.\n",
    "            print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, step + 1, running_loss / steps_since_log, prec))\n",
    "            running_loss = 0.0\n",
    "            steps_since_log = 0\n",
    "\n",
    "        # NOTE(review): gradients of a trailing, incomplete accumulation window are not\n",
    "        # stepped or cleared here, so they carry over into the next call.\n",
    "        if (step + 1) % accum_steps == 0:\n",
    "            optim.step()\n",
    "            optim.zero_grad()\n",
    "\n",
    "def evaluation():\n",
    "    \"\"\"Score the model on every batch of ``testload``.\n",
    "\n",
    "    Uses the notebook globals ``model``, ``testload``, ``mlb`` and ``global_params``.\n",
    "\n",
    "    Returns:\n",
    "        (aps, roc, tr_loss): sample-averaged average precision and ROC AUC computed\n",
    "        over all batches together, and the sum of the per-batch losses.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    device = global_params['device']\n",
    "    logit_chunks, label_chunks = [], []\n",
    "    tr_loss = 0\n",
    "    for batch in testload:\n",
    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32).to(device)\n",
    "\n",
    "        # forward pass only; no gradients are recorded\n",
    "        with torch.no_grad():\n",
    "            loss, logits = model(\n",
    "                input_ids.to(device),\n",
    "                age_ids.to(device),\n",
    "                segment_ids.to(device),\n",
    "                posi_ids.to(device),\n",
    "                attention_mask=attMask.to(device),\n",
    "                labels=targets,\n",
    "            )\n",
    "        tr_loss += loss.item()\n",
    "\n",
    "        # collect on the CPU so the metrics can be computed once over the whole split\n",
    "        label_chunks.append(targets.cpu())\n",
    "        logit_chunks.append(logits.cpu())\n",
    "\n",
    "    all_labels = torch.cat(label_chunks, dim=0)\n",
    "    all_logits = torch.cat(logit_chunks, dim=0)\n",
    "    aps, roc, _, _ = precision_test(all_logits, all_labels)\n",
    "    return aps, roc, tr_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Train for 50 epochs; after each one evaluate and checkpoint when the average precision improves.\n",
    "best_pre = 0.0\n",
    "for e in range(50):\n",
    "    train(e)\n",
    "    aps, roc, test_loss = evaluation()\n",
    "    if aps > best_pre:\n",
    "        print(\"** ** * Saving fine - tuned model ** ** * \")\n",
    "        create_folder(global_params['output_dir'])\n",
    "        output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])\n",
    "        # if the model is wrapped in an object exposing `.module`, save only the model itself\n",
    "        model_to_save = getattr(model, 'module', model)\n",
    "        torch.save(model_to_save.state_dict(), output_model_file)\n",
    "        best_pre = aps\n",
    "    print('aps : {}'.format(aps))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}