--- /dev/null
+++ b/task/NextXVisit.ipynb
@@ -0,0 +1,387 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.insert(0, '../')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.utils.data import DataLoader\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import os\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "\n",
+    "from common.common import create_folder, load_obj\n",
+    "from model import optimiser\n",
+    "from model.utils import age_vocab\n",
+    "from dataLoader.NextXVisit import NextVisit\n",
+    "from model.NextXVisit import BertForMultiLabelPrediction\n",
+    "import warnings\n",
+    "warnings.filterwarnings(action='ignore')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_config = {\n",
+    "    'vocab': '', # token2idx idx2token\n",
+    "    'train': '',\n",
+    "    'test': '',\n",
+    "}\n",
+    "\n",
+    "optim_config = {\n",
+    "    'lr': 3e-5,\n",
+    "    'warmup_proportion': 0.1,\n",
+    "    'weight_decay': 0.01\n",
+    "}\n",
+    "\n",
+    "global_params = {\n",
+    "    'batch_size': 256,\n",
+    "    'gradient_accumulation_steps': 1,\n",
+    "    'device': 'cuda:0',\n",
+    "    'output_dir': '', # output folder\n",
+    "    'best_name': '', # output model name\n",
+    "    'max_len_seq': 100,\n",
+    "    'max_age': 110,\n",
+    "    'age_year': False,\n",
+    "    'age_symbol': None,\n",
+    "    'min_visit': 5\n",
+    "}\n",
+    "\n",
+    "pretrain_model_path = '' # pretrained MLM path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BertVocab = load_obj(file_config['vocab'])\n",
+    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], symbol=global_params['age_symbol'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# re-format the label vocabulary: drop the special tokens, then re-index\n",
+    "# the remaining disease tokens densely from 0\n",
+    "def format_label_vocab(token2idx):\n",
+    "    token2idx = token2idx.copy()\n",
+    "    del token2idx['PAD']\n",
+    "    del token2idx['SEP']\n",
+    "    del token2idx['CLS']\n",
+    "    del token2idx['MASK']\n",
+    "    token = list(token2idx.keys())\n",
+    "    labelVocab = {}\n",
+    "    for i, x in enumerate(token):\n",
+    "        labelVocab[x] = i\n",
+    "    return labelVocab\n",
+    "\n",
+    "labelVocab = format_label_vocab(BertVocab['token2idx'])"
+   ]
+  },
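+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sanity check of format_label_vocab on a toy token2idx (hypothetical\n",
+    "# tokens, not the real vocabulary): the special tokens disappear and the\n",
+    "# disease tokens are re-indexed densely from 0, which is the layout the\n",
+    "# MultiLabelBinarizer below relies on.\n",
+    "toy_token2idx = {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3, 'D1': 4, 'D2': 5}\n",
+    "assert format_label_vocab(toy_token2idx) == {'D1': 0, 'D2': 1}"
+   ]
+  },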
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_config = {\n",
+    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbol tokens for word embedding\n",
+    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
+    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
+    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
+    "    'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
+    "    'hidden_dropout_prob': 0.1, # dropout rate\n",
+    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
+    "    'num_attention_heads': 12, # number of attention heads\n",
+    "    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n",
+    "    'intermediate_size': 512, # the size of the \"intermediate\" (feed-forward) layer in the transformer encoder\n",
+    "    'hidden_act': 'gelu', # non-linear activation in the encoder and pooler; 'gelu', 'relu' and 'swish' are supported\n",
+    "    'initializer_range': 0.02, # parameter weight initializer range\n",
+    "}\n",
+    "\n",
+    "feature_dict = {\n",
+    "    'word': True,\n",
+    "    'seg': True,\n",
+    "    'age': True,\n",
+    "    'position': True\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BertConfig(Bert.modeling.BertConfig):\n",
+    "    def __init__(self, config):\n",
+    "        super(BertConfig, self).__init__(\n",
+    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
+    "            hidden_size=config.get('hidden_size'),\n",
+    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
+    "            num_attention_heads=config.get('num_attention_heads'),\n",
+    "            intermediate_size=config.get('intermediate_size'),\n",
+    "            hidden_act=config.get('hidden_act'),\n",
+    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
+    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
+    "            max_position_embeddings=config.get('max_position_embedding'),\n",
+    "            initializer_range=config.get('initializer_range'),\n",
+    "        )\n",
+    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
+    "        self.age_vocab_size = config.get('age_vocab_size')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train = pd.read_parquet(file_config['train'])\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=train, max_len=global_params['max_len_seq'])\n",
+    "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "test = pd.read_parquet(file_config['test'])\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], label2idx=labelVocab, age2idx=ageVocab, dataframe=test, max_len=global_params['max_len_seq'])\n",
+    "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "conf = BertConfig(model_config)\n",
+    "model = BertForMultiLabelPrediction(conf, num_labels=len(labelVocab.keys()), feature_dict=feature_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def load_model(path, model):\n",
+    "    # load pretrained model and update weights\n",
+    "    pretrained_dict = torch.load(path)\n",
+    "    model_dict = model.state_dict()\n",
+    "    # 1. filter out unnecessary keys\n",
+    "    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
+    "    # 2. overwrite entries in the existing state dict\n",
+    "    model_dict.update(pretrained_dict)\n",
+    "    # 3. load the new state dict\n",
+    "    model.load_state_dict(model_dict)\n",
+    "    return model\n",
+    "\n",
+    "model = load_model(pretrain_model_path, model)"
+   ]
+  },
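+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch of an alternative (an assumption, not part of the original pipeline):\n",
+    "# load_state_dict(strict=False) also tolerates missing/unexpected keys, so the\n",
+    "# manual key filtering above could be collapsed into one line. Shape mismatches\n",
+    "# between checkpoint and model would still raise either way.\n",
+    "# model.load_state_dict(torch.load(pretrain_model_path), strict=False)"
+   ]
+  },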
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = model.to(global_params['device'])\n",
+    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sklearn.metrics\n",
+    "\n",
+    "def precision(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output = sig(logits)\n",
+    "    label, output = label.cpu(), output.detach().cpu()\n",
+    "    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
+    "    return tempprc, output, label\n",
+    "\n",
+    "def precision_test(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output = sig(logits)\n",
+    "    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
+    "    roc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy(), average='samples')\n",
+    "    return tempprc, roc, output, label"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.preprocessing import MultiLabelBinarizer\n",
+    "mlb = MultiLabelBinarizer(classes=list(labelVocab.values()))\n",
+    "mlb.fit([[each] for each in list(labelVocab.values())])"
+   ]
+  },
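+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick illustration with toy label indices (assumes the real label vocabulary\n",
+    "# has at least three entries): mlb turns lists of label indices into fixed-width\n",
+    "# multi-hot rows of length len(labelVocab), the target format expected by the\n",
+    "# multi-label loss in train()/evaluation() below.\n",
+    "demo = mlb.transform([[0, 2], [1]])\n",
+    "print(demo.shape, demo[0][:5])"
+   ]
+  },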
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(e):\n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    cnt = 0\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "        cnt += 1\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "\n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "\n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "\n",
+    "        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
+    "\n",
+    "        if global_params['gradient_accumulation_steps'] > 1:\n",
+    "            loss = loss / global_params['gradient_accumulation_steps']\n",
+    "        loss.backward()\n",
+    "\n",
+    "        temp_loss += loss.item()\n",
+    "        tr_loss += loss.item()\n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "\n",
+    "        if step % 500 == 0:\n",
+    "            prec, _, _ = precision(logits, targets)\n",
+    "            print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, cnt, temp_loss / 500, prec))\n",
+    "            temp_loss = 0\n",
+    "\n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "\n",
+    "def evaluation():\n",
+    "    model.eval()\n",
+    "    y = []\n",
+    "    y_label = []\n",
+    "    tr_loss = 0\n",
+    "    for step, batch in enumerate(testload):\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "\n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
+    "            logits = logits.cpu()\n",
+    "            targets = targets.cpu()\n",
+    "\n",
+    "        tr_loss += loss.item()\n",
+    "\n",
+    "        y_label.append(targets)\n",
+    "        y.append(logits)\n",
+    "\n",
+    "    y_label = torch.cat(y_label, dim=0)\n",
+    "    y = torch.cat(y, dim=0)\n",
+    "\n",
+    "    aps, roc, output, label = precision_test(y, y_label)\n",
+    "    return aps, roc, tr_loss"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "best_pre = 0.0\n",
+    "for e in range(50):\n",
+    "    train(e)\n",
+    "    aps, roc, test_loss = evaluation()\n",
+    "    if aps > best_pre:\n",
+    "        # save the trained model\n",
+    "        print(\"** ** * Saving fine-tuned model ** ** * \")\n",
+    "        model_to_save = model.module if hasattr(model, 'module') else model # only save the model itself\n",
+    "        output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])\n",
+    "        create_folder(global_params['output_dir'])\n",
+    "\n",
+    "        torch.save(model_to_save.state_dict(), output_model_file)\n",
+    "        best_pre = aps\n",
+    "    print('aps : {}'.format(aps))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}