--- /dev/null
+++ b/task/NextVIsit-6month.ipynb
@@ -0,0 +1,627 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.insert(0, '../')\n",
+ "\n",
+ "import os\n",
+ "import math\n",
+ "import random\n",
+ "import time\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import sklearn.metrics as skm\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torch.utils.data.dataset import Dataset\n",
+ "import pytorch_pretrained_bert as Bert\n",
+ "\n",
+ "from common.common import create_folder, load_obj\n",
+ "from dataLoader.utils import seq_padding, code2index, position_idx, index_seg\n",
+ "from model.utils import age_vocab\n",
+ "from model import optimiser"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# File Parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_config = {\n",
+ " 'vocab': '', # vocabulary file containing token2idx and idx2token\n",
+ " 'train': '', # training data (parquet)\n",
+ " 'test': '', # test data (parquet)\n",
+ "}\n",
+ "\n",
+ "optim_config = {\n",
+ " 'lr': 3e-5,\n",
+ " 'warmup_proportion': 0.1,\n",
+ " 'weight_decay': 0.01\n",
+ "}\n",
+ "\n",
+ "global_params = {\n",
+ " 'batch_size': 64,\n",
+ " 'gradient_accumulation_steps': 1,\n",
+ " 'device': 'cuda:0',\n",
+ " 'output_dir': '', # output dir\n",
+ " 'best_name': '', # output model name\n",
+ " 'save_model': True,\n",
+ " 'max_len_seq': 100,\n",
+ " 'max_age': 110,\n",
+ " 'month': 1, # age resolution in months (passed to age_vocab)\n",
+ " 'age_symbol': None,\n",
+ " 'min_visit': 5 # minimum number of visits per patient\n",
+ "}\n",
+ "\n",
+ "pretrainModel = '' # MLM pretrained model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "create_folder(global_params['output_dir'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "BertVocab = load_obj(file_config['vocab'])\n",
+ "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def format_label_vocab(token2idx):\n",
+ " # build the label vocabulary from the BERT vocab, excluding special tokens\n",
+ " token2idx = token2idx.copy()\n",
+ " del token2idx['PAD']\n",
+ " del token2idx['SEP']\n",
+ " del token2idx['CLS']\n",
+ " del token2idx['MASK']\n",
+ " token = list(token2idx.keys())\n",
+ " labelVocab = {}\n",
+ " for i, x in enumerate(token):\n",
+ " labelVocab[x] = i\n",
+ " return labelVocab\n",
+ "\n",
+ "Vocab_diag = format_label_vocab(BertVocab['token2idx'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_config = {\n",
+ " 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease codes + special tokens for word embedding\n",
+ " 'hidden_size': 288, # word embedding and seg embedding hidden size\n",
+ " 'seg_vocab_size': 2, # vocab size for the seg embedding\n",
+ " 'age_vocab_size': len(ageVocab.keys()), # vocab size for the age embedding\n",
+ " 'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
+ " 'hidden_dropout_prob': 0.2, # dropout rate\n",
+ " 'num_hidden_layers': 6, # number of multi-head attention layers\n",
+ " 'num_attention_heads': 12, # number of attention heads\n",
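+ " # note: hidden_size must be divisible by num_attention_heads (here 288 / 12 = 24 dims per head)\n",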
+ " 'attention_probs_dropout_prob': 0.22, # multi-head attention dropout rate\n",
+ " 'intermediate_size': 512, # size of the 'intermediate' (feed-forward) layer in the transformer encoder\n",
+ " 'hidden_act': 'gelu', # non-linear activation in the encoder and pooler; 'gelu', 'relu' and 'swish' are supported\n",
+ " 'initializer_range': 0.02, # parameter weight initializer range\n",
+ "}\n",
+ "\n",
+ "feature_dict = {\n",
+ " 'age': True,\n",
+ " 'seg': True,\n",
+ " 'posi': True\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Dataset and Model Definition"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class NextVisit(Dataset):\n",
+ " def __init__(self, token2idx, diag2idx, age2idx, dataframe, max_len, max_age=110, min_visit=5):\n",
+ " # the dataframe is expected to be pre-filtered so that every patient\n",
+ " # has at least min_visit visits\n",
+ " self.vocab = token2idx\n",
+ " self.label_vocab = diag2idx\n",
+ " self.max_len = max_len\n",
+ " self.code = dataframe.code\n",
+ " self.age = dataframe.age\n",
+ " self.label = dataframe.label\n",
+ " self.patid = dataframe.patid\n",
+ "\n",
+ " self.age2idx = age2idx\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " \"\"\"\n",
+ " return: age, code, position, segmentation, mask, label\n",
+ " \"\"\"\n",
+ " age = self.age[index]\n",
+ " code = self.code[index]\n",
+ " label = self.label[index]\n",
+ " patid = self.patid[index]\n",
+ "\n",
+ " # keep only the most recent max_len-1 elements\n",
+ " age = age[(-self.max_len+1):]\n",
+ " code = code[(-self.max_len+1):]\n",
+ "\n",
+ " # avoid a cut sequence starting with 'SEP'\n",
+ " if code[0] != 'SEP':\n",
+ " code = np.append(np.array(['CLS']), code)\n",
+ " age = np.append(np.array(age[0]), age)\n",
+ " else:\n",
+ " code[0] = 'CLS'\n",
+ "\n",
+ " # mask: 1 for real tokens, 0 for padding\n",
+ " mask = np.ones(self.max_len)\n",
+ " mask[len(code):] = 0\n",
+ "\n",
+ " # pad age sequence and code sequence\n",
+ " age = seq_padding(age, self.max_len, token2idx=self.age2idx)\n",
+ "\n",
+ " tokens, code = code2index(code, self.vocab)\n",
+ " _, label = code2index(label, self.label_vocab)\n",
+ "\n",
+ " # get position code and segment code\n",
+ " tokens = seq_padding(tokens, self.max_len)\n",
+ " position = position_idx(tokens)\n",
+ " segment = index_seg(tokens)\n",
+ "\n",
+ " # pad code and label\n",
+ " code = seq_padding(code, self.max_len, symbol=self.vocab['PAD'])\n",
+ " label = seq_padding(label, self.max_len, symbol=-1)\n",
+ "\n",
+ " return torch.LongTensor(age), torch.LongTensor(code), torch.LongTensor(position), torch.LongTensor(segment), \\\n",
+ " torch.LongTensor(mask), torch.LongTensor(label), torch.LongTensor([int(patid)])\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.code)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class BertConfig(Bert.modeling.BertConfig):\n",
+ " def __init__(self, config):\n",
+ " super(BertConfig, self).__init__(\n",
+ " vocab_size_or_config_json_file=config.get('vocab_size'),\n",
+ " hidden_size=config.get('hidden_size'),\n",
+ " num_hidden_layers=config.get('num_hidden_layers'),\n",
+ " num_attention_heads=config.get('num_attention_heads'),\n",
+ " intermediate_size=config.get('intermediate_size'),\n",
+ " hidden_act=config.get('hidden_act'),\n",
+ " hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
+ " attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
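+ " # note: the model_config key is 'max_position_embedding' (singular),\n",
+ " # while the BertConfig kwarg is 'max_position_embeddings' (plural)\n",
+ " max_position_embeddings 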
= config.get('max_position_embedding'),\n", + " initializer_range=config.get('initializer_range'),\n", + " )\n", + " self.seg_vocab_size = config.get('seg_vocab_size')\n", + " self.age_vocab_size = config.get('age_vocab_size')\n", + "\n", + "class BertEmbeddings(nn.Module):\n", + " \"\"\"Construct the embeddings from word, segment, age\n", + " \"\"\"\n", + "\n", + " def __init__(self, config, feature_dict):\n", + " super(BertEmbeddings, self).__init__()\n", + " self.feature_dict = feature_dict\n", + " \n", + " self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n", + " self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)\n", + " self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)\n", + " self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size).\\\n", + " from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))\n", + "\n", + " self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)\n", + " self.dropout = nn.Dropout(config.hidden_dropout_prob)\n", + "\n", + " def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):\n", + " if seg_ids is None:\n", + " seg_ids = torch.zeros_like(word_ids)\n", + " if age_ids is None:\n", + " age_ids = torch.zeros_like(word_ids)\n", + " if posi_ids is None:\n", + " posi_ids = torch.zeros_like(word_ids)\n", + "\n", + " word_embed = self.word_embeddings(word_ids)\n", + " segment_embed = self.segment_embeddings(seg_ids)\n", + " age_embed = self.age_embeddings(age_ids)\n", + " posi_embeddings = self.posi_embeddings(posi_ids)\n", + " \n", + " embeddings = word_embed\n", + " \n", + " if self.feature_dict['age']:\n", + " embeddings = embeddings + age_embed\n", + " if self.feature_dict['seg']:\n", + " embeddings = embeddings + segment_embed\n", + " if self.feature_dict['posi']:\n", + " embeddings = embeddings + posi_embeddings\n", + " \n", + " embeddings = self.LayerNorm(embeddings)\n", + " embeddings = self.dropout(embeddings)\n", + " return embeddings\n", + "\n", + " def _init_posi_embedding(self, max_position_embedding, hidden_size):\n", + " def even_code(pos, idx):\n", + " return np.sin(pos/(10000**(2*idx/hidden_size)))\n", + "\n", + " def odd_code(pos, idx):\n", + " return np.cos(pos/(10000**(2*idx/hidden_size)))\n", + "\n", + " # initialize position embedding table\n", + " lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)\n", + "\n", + " # reset table parameters with hard encoding\n", + " # set even dimension\n", + " for pos in range(max_position_embedding):\n", + " for idx in np.arange(0, hidden_size, step=2):\n", + " lookup_table[pos, idx] = even_code(pos, idx)\n", + " # set odd dimension\n", + " for pos in range(max_position_embedding):\n", + " for idx in np.arange(1, hidden_size, step=2):\n", + " lookup_table[pos, idx] = odd_code(pos, idx)\n", + "\n", + " return torch.tensor(lookup_table)\n", + "\n", + "class BertModel(Bert.modeling.BertPreTrainedModel):\n", + " def __init__(self, config, feature_dict):\n", + " super(BertModel, self).__init__(config)\n", + " self.embeddings = BertEmbeddings(config, feature_dict)\n", + " self.encoder = Bert.modeling.BertEncoder(config=config)\n", + " self.pooler = Bert.modeling.BertPooler(config)\n", + " self.apply(self.init_bert_weights)\n", + "\n", + " def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):\n", + " if 
attention_mask is None:\n",
+ " attention_mask = torch.ones_like(input_ids)\n",
+ " if age_ids is None:\n",
+ " age_ids = torch.zeros_like(input_ids)\n",
+ " if seg_ids is None:\n",
+ " seg_ids = torch.zeros_like(input_ids)\n",
+ " if posi_ids is None:\n",
+ " posi_ids = torch.zeros_like(input_ids)\n",
+ "\n",
+ " # convert the 0/1 mask to an additive mask: kept positions -> 0.0,\n",
+ " # padded positions -> -10000.0, so the attention softmax ignores padding\n",
+ " extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)\n",
+ " extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility\n",
+ " extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
+ "\n",
+ " embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)\n",
+ " encoded_layers = self.encoder(embedding_output,\n",
+ " extended_attention_mask,\n",
+ " output_all_encoded_layers=output_all_encoded_layers)\n",
+ " sequence_output = encoded_layers[-1]\n",
+ " pooled_output = self.pooler(sequence_output)\n",
+ " if not output_all_encoded_layers:\n",
+ " encoded_layers = encoded_layers[-1]\n",
+ " return encoded_layers, pooled_output\n",
+ "\n",
+ "class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):\n",
+ " def __init__(self, config, num_labels, feature_dict):\n",
+ " super(BertForMultiLabelPrediction, self).__init__(config)\n",
+ " self.num_labels = num_labels\n",
+ " self.bert = BertModel(config, feature_dict)\n",
+ " self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
+ " self.classifier = nn.Linear(config.hidden_size, num_labels)\n",
+ " self.apply(self.init_bert_weights)\n",
+ "\n",
+ " def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):\n",
+ " _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,\n",
+ " output_all_encoded_layers=False)\n",
+ " pooled_output = self.dropout(pooled_output)\n",
+ " logits = self.classifier(pooled_output)\n",
+ "\n",
+ " if labels is not None:\n",
+ " loss_fct = nn.MultiLabelSoftMarginLoss()\n",
+ " loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))\n",
+ " return loss, logits\n",
+ " else:\n",
+ " return logits"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Load Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = pd.read_parquet(file_config['train']).reset_index(drop=True)\n",
+ "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
+ "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab, dataframe=data, max_len=global_params['max_len_seq'])\n",
+ "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = pd.read_parquet(file_config['test']).reset_index(drop=True)\n",
+ "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
+ "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab, dataframe=data, max_len=global_params['max_len_seq'])\n",
+ "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Set Up Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "conf = BertConfig(model_config)\n",
+ "model = BertForMultiLabelPrediction(conf, len(Vocab_diag.keys()), feature_dict)"
+ ]
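+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Optional sanity check (not part of the original pipeline): run one synthetic batch through the freshly initialised model to confirm the output shape before loading pretrained weights. Index 0 is a valid row in every embedding table, so all-zero ids are safe dummy inputs; the model is still on CPU at this point."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# hypothetical shape check with synthetic ids (batch of 2, full sequence length)\n",
+ "dummy = torch.zeros(2, global_params['max_len_seq'], dtype=torch.long)\n",
+ "with torch.no_grad():\n",
+ " logits = model(dummy, dummy, dummy, dummy, attention_mask=torch.ones_like(dummy))\n",
+ "print(logits.shape) # expected: torch.Size([2, len(Vocab_diag)])\n",
+ "del dummy, logits"
+ ]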
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# load the MLM-pretrained model and transfer the matching weights\n",
+ "pretrained_dict = torch.load(pretrainModel, map_location='cpu') # model is moved to the GPU later\n",
+ "model_dict = model.state_dict()\n",
+ "# 1. filter out keys that do not exist in the fine-tuning model (e.g. the MLM head)\n",
+ "pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
+ "# 2. overwrite entries in the existing state dict\n",
+ "model_dict.update(pretrained_dict)\n",
+ "# 3. load the new state dict\n",
+ "model.load_state_dict(model_dict)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = model.to(global_params['device'])\n",
+ "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Evaluation Metrics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sklearn.metrics\n",
+ "\n",
+ "def precision(logits, label):\n",
+ " # sample-averaged precision on a single (GPU) batch, used for progress logging\n",
+ " sig = nn.Sigmoid()\n",
+ " output = sig(logits)\n",
+ " label, output = label.cpu(), output.detach().cpu()\n",
+ " tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
+ " return tempprc, output, label\n",
+ "\n",
+ "def precision_test(logits, label):\n",
+ " # sample-averaged precision over the full test set; expects CPU tensors\n",
+ " sig = nn.Sigmoid()\n",
+ " output = sig(logits)\n",
+ " tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
+ " return tempprc, output, label\n",
+ "\n",
+ "def auroc_test(logits, label):\n",
+ " # sample-averaged AUROC over the full test set; expects CPU tensors\n",
+ " sig = nn.Sigmoid()\n",
+ " output = sig(logits)\n",
+ " tempprc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy(), average='samples')\n",
+ " return tempprc"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Multi-hot Label Encoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.preprocessing import MultiLabelBinarizer\n",
+ "mlb = MultiLabelBinarizer(classes=list(Vocab_diag.values()))\n",
+ "mlb.fit([[each] for each in list(Vocab_diag.values())])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Train and Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(e):\n",
+ " model.train()\n",
+ " tr_loss = 0\n",
+ " temp_loss = 0\n",
+ " nb_tr_examples, nb_tr_steps = 0, 0\n",
+ " cnt = 0\n",
+ " for step, batch in enumerate(trainload):\n",
+ " cnt += 1\n",
+ " age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+ " # rows of padded label indices -> multi-hot vectors; the -1 padding is not\n",
+ " # a known class, so the binarizer ignores it (hence the warnings filter below)\n",
+ " targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+ "\n",
+ " age_ids = age_ids.to(global_params['device'])\n",
+ " input_ids = input_ids.to(global_params['device'])\n",
+ " posi_ids = posi_ids.to(global_params['device'])\n",
+ " segment_ids = segment_ids.to(global_params['device'])\n",
+ " attMask = attMask.to(global_params['device'])\n",
+ " targets = targets.to(global_params['device'])\n",
+ "\n",
+ " loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
+ "\n",
+ " if global_params['gradient_accumulation_steps'] > 1:\n",
+ " loss = loss / global_params['gradient_accumulation_steps']\n",
+ " loss.backward()\n",
+ "\n",
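+ " # temp_loss drives the periodic progress printout and is reset every 2000 steps;\n",
+ " # tr_loss accumulates over the whole epoch\n",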
+ " temp_loss += loss.item()\n",
+ " tr_loss += loss.item()\n",
+ " nb_tr_examples += input_ids.size(0)\n",
+ " nb_tr_steps += 1\n",
+ "\n",
+ " if (step + 1) % 2000 == 0:\n",
+ " prec, _, _ = precision(logits, targets)\n",
+ " print('epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}'.format(e, cnt, temp_loss/2000, prec))\n",
+ " temp_loss = 0\n",
+ "\n",
+ " if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+ " optim.step()\n",
+ " optim.zero_grad()\n",
+ "\n",
+ "def evaluation():\n",
+ " model.eval()\n",
+ " y = []\n",
+ " y_label = []\n",
+ " for step, batch in enumerate(testload):\n",
+ " age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+ " targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+ "\n",
+ " age_ids = age_ids.to(global_params['device'])\n",
+ " input_ids = input_ids.to(global_params['device'])\n",
+ " posi_ids = posi_ids.to(global_params['device'])\n",
+ " segment_ids = segment_ids.to(global_params['device'])\n",
+ " attMask = attMask.to(global_params['device'])\n",
+ " targets = targets.to(global_params['device'])\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
+ " logits = logits.cpu()\n",
+ " targets = targets.cpu()\n",
+ "\n",
+ " y_label.append(targets)\n",
+ " y.append(logits)\n",
+ "\n",
+ " y_label = torch.cat(y_label, dim=0)\n",
+ " y = torch.cat(y, dim=0)\n",
+ "\n",
+ " tempprc, output, label = precision_test(y, y_label)\n",
+ " auroc = auroc_test(y, y_label)\n",
+ " return tempprc, auroc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "import warnings\n",
+ "warnings.filterwarnings(action='ignore')\n",
+ "\n",
+ "# fine-tune with a lower learning rate; this overrides the earlier optim_config\n",
+ "optim_config = {\n",
+ " 'lr': 3e-6,\n",
+ " 'warmup_proportion': 0.1\n",
+ "}\n",
+ "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)\n",
+ "\n",
+ "best_pre = 0.512 # only save checkpoints that beat this average precision\n",
+ "for e in range(50):\n",
+ " train(e)\n",
+ " aps, roc = evaluation() # evaluation() returns (average precision, auroc)\n",
+ " if aps > best_pre:\n",
+ " # save the fine-tuned model\n",
+ " print('** ** * Saving fine-tuned model ** ** *')\n",
+ " model_to_save = model.module if hasattr(model, 'module') else model # only save the model itself\n",
+ " output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])\n",
+ " if global_params['save_model']:\n",
+ " torch.save(model_to_save.state_dict(), output_model_file)\n",
+ " best_pre = aps\n",
+ " print('average precision: {}, auroc: {}'.format(aps, roc))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}