{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys \n",
"sys.path.insert(0, '../')\n",
"\n",
"from common.common import create_folder,load_obj\n",
"# from data import bert,dataframe,utils\n",
"from dataLoader.utils import seq_padding,code2index, position_idx, index_seg\n",
"from torch.utils.data import DataLoader\n",
"import pandas as pd\n",
"import numpy as np\n",
"from torch.utils.data.dataset import Dataset\n",
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_pretrained_bert as Bert\n",
"from model.utils import age_vocab\n",
"from model import optimiser\n",
"import sklearn.metrics as skm\n",
"import math\n",
"from torch.utils.data.dataset import Dataset\n",
"import random\n",
"import numpy as np\n",
"import torch\n",
"import time\n",
"\n",
"# from data.utils import seq_padding, index_seg, position_idx, age_vocab, random_mask, code2index\n",
"# from sklearn.metrics import roc_auc_score"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# File Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"file_config = {\n",
" 'vocab':'', # token2idx idx2token\n",
" 'train': '',\n",
" 'test': '',\n",
"}\n",
"\n",
"optim_config = {\n",
" 'lr': 3e-5,\n",
" 'warmup_proportion': 0.1,\n",
" 'weight_decay': 0.01\n",
"}\n",
"\n",
"global_params = {\n",
" 'batch_size': 128,\n",
" 'gradient_accumulation_steps': 1,\n",
" 'device': 'cuda:1',\n",
" 'output_dir': '', # output folder\n",
" 'best_name': '', # output model name\n",
" 'save_model': True,\n",
" 'max_len_seq': 100,\n",
" 'max_age': 110,\n",
" 'month': 1,\n",
" 'age_symbol': None,\n",
" 'min_visit': 5\n",
"}\n",
"\n",
"pretrainModel = '' # pretrained MLM path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"create_folder(global_params['output_dir'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"BertVocab = load_obj(file_config['vocab'])\n",
"ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def format_label_vocab(token2idx):\n",
" token2idx = token2idx.copy()\n",
" del token2idx['PAD']\n",
" del token2idx['SEP']\n",
" del token2idx['CLS']\n",
" del token2idx['MASK']\n",
" token = list(token2idx.keys())\n",
" labelVocab = {}\n",
" for i,x in enumerate(token):\n",
" labelVocab[x] = i\n",
" return labelVocab\n",
"\n",
"Vocab_diag = format_label_vocab(BertVocab['token2idx'])"
]
},
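{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (toy vocabulary, not the real BertVocab) of what\n",
"# format_label_vocab produces: disease tokens mapped to contiguous class ids\n",
"# with the special symbols removed.\n",
"toy_token2idx = {'PAD': 0, 'CLS': 1, 'SEP': 2, 'MASK': 3, 'D1': 4, 'D2': 5}\n",
"print(format_label_vocab(toy_token2idx))  # {'D1': 0, 'D2': 1}"
]
},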
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_config = {\n",
" 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
" 'hidden_size': 288, # word embedding and seg embedding hidden size\n",
" 'seg_vocab_size': 2, # number of vocab for seg embedding\n",
" 'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
" 'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
" 'hidden_dropout_prob': 0.2, # dropout rate\n",
" 'num_hidden_layers': 6, # number of multi-head attention layers required\n",
" 'num_attention_heads': 12, # number of attention heads\n",
" 'attention_probs_dropout_prob': 0.22, # multi-head attention dropout rate\n",
" 'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
" 'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
" 'initializer_range': 0.02, # parameter weight initializer range\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set Up Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NextVisit(Dataset):\n",
" def __init__(self, token2idx, diag2idx, age2idx,dataframe, max_len, max_age=110, min_visit=5):\n",
" # dataframe preproecssing\n",
" # filter out the patient with number of visits less than min_visit\n",
" self.vocab = token2idx\n",
" self.label_vocab = diag2idx\n",
" self.max_len = max_len\n",
" self.code = dataframe.code\n",
" self.age = dataframe.age\n",
" self.label = dataframe.label\n",
" self.patid = dataframe.patid\n",
"\n",
" self.age2idx = age2idx\n",
"\n",
" def __getitem__(self, index):\n",
" \"\"\"\n",
" return: age, code, position, segmentation, mask, label\n",
" \"\"\"\n",
" # cut data\n",
" age = self.age[index]\n",
" code = self.code[index]\n",
" label = self.label[index]\n",
" patid = self.patid[index]\n",
"\n",
" # extract data\n",
" age = age[(-self.max_len+1):]\n",
" code = code[(-self.max_len+1):]\n",
"\n",
" # avoid data cut with first element to be 'SEP'\n",
" if code[0] != 'SEP':\n",
" code = np.append(np.array(['CLS']), code)\n",
" age = np.append(np.array(age[0]), age)\n",
" else:\n",
" code[0] = 'CLS'\n",
"\n",
" # mask 0:len(code) to 1, padding to be 0\n",
" mask = np.ones(self.max_len)\n",
" mask[len(code):] = 0\n",
"\n",
" # pad age sequence and code sequence\n",
" age = seq_padding(age, self.max_len, token2idx=self.age2idx)\n",
"\n",
" tokens, code = code2index(code, self.vocab)\n",
" _, label = code2index(label, self.label_vocab)\n",
"\n",
" # get position code and segment code\n",
" tokens = seq_padding(tokens, self.max_len)\n",
" position = position_idx(tokens)\n",
" segment = index_seg(tokens)\n",
"\n",
" # pad code and label\n",
" code = seq_padding(code, self.max_len, symbol=self.vocab['PAD'])\n",
" label = seq_padding(label, self.max_len, symbol=-1)\n",
"\n",
" return torch.LongTensor(age), torch.LongTensor(code), torch.LongTensor(position), torch.LongTensor(segment), \\\n",
" torch.LongTensor(mask), torch.LongTensor(label), torch.LongTensor([int(patid)])\n",
"\n",
" def __len__(self):\n",
" return len(self.code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class BertConfig(Bert.modeling.BertConfig):\n",
" def __init__(self, config):\n",
" super(BertConfig, self).__init__(\n",
" vocab_size_or_config_json_file=config.get('vocab_size'),\n",
" hidden_size=config['hidden_size'],\n",
" num_hidden_layers=config.get('num_hidden_layers'),\n",
" num_attention_heads=config.get('num_attention_heads'),\n",
" intermediate_size=config.get('intermediate_size'),\n",
" hidden_act=config.get('hidden_act'),\n",
" hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
" attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
" max_position_embeddings = config.get('max_position_embedding'),\n",
" initializer_range=config.get('initializer_range'),\n",
" )\n",
" self.seg_vocab_size = config.get('seg_vocab_size')\n",
" self.age_vocab_size = config.get('age_vocab_size')\n",
"\n",
"class BertEmbeddings(nn.Module):\n",
" \"\"\"Construct the embeddings from word, segment, age\n",
" \"\"\"\n",
"\n",
" def __init__(self, config):\n",
" super(BertEmbeddings, self).__init__()\n",
" self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n",
" self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)\n",
" self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)\n",
" self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size).\\\n",
" from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))\n",
"\n",
" self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)\n",
" self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
"\n",
" def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):\n",
" if seg_ids is None:\n",
" seg_ids = torch.zeros_like(word_ids)\n",
" if age_ids is None:\n",
" age_ids = torch.zeros_like(word_ids)\n",
" if posi_ids is None:\n",
" posi_ids = torch.zeros_like(word_ids)\n",
"\n",
" word_embed = self.word_embeddings(word_ids)\n",
" segment_embed = self.segment_embeddings(seg_ids)\n",
" age_embed = self.age_embeddings(age_ids)\n",
" posi_embeddings = self.posi_embeddings(posi_ids)\n",
" \n",
" if age:\n",
" embeddings = word_embed + segment_embed + age_embed + posi_embeddings\n",
" else:\n",
" embeddings = word_embed + segment_embed + posi_embeddings\n",
" embeddings = self.LayerNorm(embeddings)\n",
" embeddings = self.dropout(embeddings)\n",
" return embeddings\n",
"\n",
" def _init_posi_embedding(self, max_position_embedding, hidden_size):\n",
" def even_code(pos, idx):\n",
" return np.sin(pos/(10000**(2*idx/hidden_size)))\n",
"\n",
" def odd_code(pos, idx):\n",
" return np.cos(pos/(10000**(2*idx/hidden_size)))\n",
"\n",
" # initialize position embedding table\n",
" lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)\n",
"\n",
" # reset table parameters with hard encoding\n",
" # set even dimension\n",
" for pos in range(max_position_embedding):\n",
" for idx in np.arange(0, hidden_size, step=2):\n",
" lookup_table[pos, idx] = even_code(pos, idx)\n",
" # set odd dimension\n",
" for pos in range(max_position_embedding):\n",
" for idx in np.arange(1, hidden_size, step=2):\n",
" lookup_table[pos, idx] = odd_code(pos, idx)\n",
"\n",
" return torch.tensor(lookup_table)\n",
"\n",
"class BertModel(Bert.modeling.BertPreTrainedModel):\n",
" def __init__(self, config):\n",
" super(BertModel, self).__init__(config)\n",
" self.embeddings = BertEmbeddings(config=config)\n",
" self.encoder = Bert.modeling.BertEncoder(config=config)\n",
" self.pooler = Bert.modeling.BertPooler(config)\n",
" self.apply(self.init_bert_weights)\n",
"\n",
" def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):\n",
" if attention_mask is None:\n",
" attention_mask = torch.ones_like(input_ids)\n",
" if age_ids is None:\n",
" age_ids = torch.zeros_like(input_ids)\n",
" if seg_ids is None:\n",
" seg_ids = torch.zeros_like(input_ids)\n",
" if posi_ids is None:\n",
" posi_ids = torch.zeros_like(input_ids)\n",
"\n",
" extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)\n",
" extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility\n",
" extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
"\n",
" embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)\n",
" encoded_layers = self.encoder(embedding_output,\n",
" extended_attention_mask,\n",
" output_all_encoded_layers=output_all_encoded_layers)\n",
" sequence_output = encoded_layers[-1]\n",
" pooled_output = self.pooler(sequence_output)\n",
" if not output_all_encoded_layers:\n",
" encoded_layers = encoded_layers[-1]\n",
" return encoded_layers, pooled_output\n",
"\n",
"class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):\n",
" def __init__(self, config, num_labels):\n",
" super(BertForMultiLabelPrediction, self).__init__(config)\n",
" self.num_labels = num_labels\n",
" self.bert = BertModel(config)\n",
" self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
" self.classifier = nn.Linear(config.hidden_size, num_labels)\n",
" self.apply(self.init_bert_weights)\n",
"\n",
" def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):\n",
" _, pooled_output = self.bert(input_ids, age_ids ,seg_ids, posi_ids, attention_mask,\n",
" output_all_encoded_layers=False)\n",
" pooled_output = self.dropout(pooled_output)\n",
" logits = self.classifier(pooled_output)\n",
"\n",
" if labels is not None:\n",
" loss_fct = nn.MultiLabelSoftMarginLoss()\n",
" loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))\n",
" return loss, logits\n",
" else:\n",
" return logits"
]
},
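{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, `_init_posi_embedding` above fills a fixed lookup table in the style of the Transformer's sinusoidal encoding. For position $pos$ and dimension index $j$ (with hidden size $d$):\n",
"\n",
"$$\\mathrm{PE}(pos, j) = \\sin\\left(pos / 10000^{2j/d}\\right) \\text{ for even } j, \\qquad \\mathrm{PE}(pos, j) = \\cos\\left(pos / 10000^{2j/d}\\right) \\text{ for odd } j$$\n",
"\n",
"Note the exponent uses the raw dimension index $j$ rather than the pair index $\\lfloor j/2 \\rfloor$ of the original formulation, and the table stays frozen (`nn.Embedding.from_pretrained` defaults to `freeze=True`)."
]
},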
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_parquet(file_config['train']).reset_index(drop=True)\n",
"data['label'] = data.label.apply(lambda x: list(set(x)))\n",
"Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])\n",
"trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
]
},
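{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (assumes file_config['train'] points at a valid parquet):\n",
"# each batch is (age, code, position, segment, mask, label, patid), where the\n",
"# first six tensors have shape (batch_size, max_len_seq) and patid is (batch_size, 1).\n",
"age, code, posi, seg, mask, label, pid = next(iter(trainload))\n",
"print(age.shape, code.shape, mask.shape, label.shape, pid.shape)"
]
},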
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_parquet(file_config['test']).reset_index(drop=True)\n",
"data['label'] = data.label.apply(lambda x: list(set(x)))\n",
"Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])\n",
"testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set Up Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# del model\n",
"conf = BertConfig(model_config)\n",
"model = BertForMultiLabelPrediction(conf, num_labels=len(Vocab_diag.keys()))"
]
},
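{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough size check of the configured model (6 layers, hidden size 288);\n",
"# the frozen positional table is excluded by the requires_grad filter.\n",
"print('trainable parameters: {:,}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))"
]
},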
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# load pretrained model and update weights\n",
"pretrained_dict = torch.load(pretrainModel)\n",
"model_dict = model.state_dict()\n",
"# 1. filter out unnecessary keys\n",
"pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
"# 2. overwrite entries in the existing state dict\n",
"model_dict.update(pretrained_dict) \n",
"# 3. load the new state dict\n",
"model.load_state_dict(model_dict)"
]
},
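{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: report how much of the model was initialised from the pretrained MLM;\n",
"# tensors missing from the checkpoint keep their random initialisation.\n",
"print('loaded {} of {} tensors from {}'.format(len(pretrained_dict), len(model_dict), pretrainModel))"
]
},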
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = model.to(global_params['device'])\n",
"optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluation Matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"def precision(logits, label):\n",
" sig = nn.Sigmoid()\n",
" output=sig(logits)\n",
" label, output=label.cpu(), output.detach().cpu()\n",
" tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
" return tempprc, output, label\n",
"\n",
"def precision_test(logits, label):\n",
" sig = nn.Sigmoid()\n",
" output=sig(logits)\n",
" tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
"# roc = sklearn.metrics.roc_auc_score()\n",
" return tempprc, output, label\n",
"\n",
"def auroc_test(logits, label):\n",
" sig = nn.Sigmoid()\n",
" output=sig(logits)\n",
" tempprc= sklearn.metrics.roc_auc_score(label.numpy(),output.numpy(), average='samples')\n",
"# roc = sklearn.metrics.roc_auc_score()\n",
" return tempprc"
]
},
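{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick check of the samples-averaged metrics on toy multi-hot data:\n",
"# each row (patient) is scored independently, then the scores are averaged.\n",
"y_true = np.array([[1, 0, 1], [0, 1, 0]])\n",
"y_score = np.array([[0.9, 0.2, 0.8], [0.1, 0.7, 0.3]])\n",
"print(skm.average_precision_score(y_true, y_score, average='samples'))  # 1.0: perfect per-sample ranking\n",
"print(skm.roc_auc_score(y_true, y_score, average='samples'))  # 1.0"
]
},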
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi-hot Label Encoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"mlb = MultiLabelBinarizer(classes=list(Vocab_diag.values()))\n",
"mlb.fit([[each] for each in list(Vocab_diag.values())])"
]
},
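{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the multi-hot encoding used below: transform() maps each\n",
"# list of label indices to a binary row over all diagnosis classes. Padded\n",
"# label positions hold -1 (seq_padding with symbol=-1), which is not in\n",
"# `classes`, so transform() drops it with a UserWarning -- likely why warnings\n",
"# are silenced before the training loop.\n",
"example = mlb.transform([[0, 2, -1, -1]])\n",
"print(example.shape, int(example.sum()))  # (1, num_classes), two active labels"
]
},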
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Train and Test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(e):\n",
" model.train()\n",
" tr_loss = 0\n",
" temp_loss = 0\n",
" nb_tr_examples, nb_tr_steps = 0, 0\n",
" cnt = 0\n",
" for step, batch in enumerate(trainload):\n",
" cnt +=1\n",
" age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
" targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
" \n",
" age_ids = age_ids.to(global_params['device'])\n",
" input_ids = input_ids.to(global_params['device'])\n",
" posi_ids = posi_ids.to(global_params['device'])\n",
" segment_ids = segment_ids.to(global_params['device'])\n",
" attMask = attMask.to(global_params['device'])\n",
" targets = targets.to(global_params['device'])\n",
" \n",
" loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
" \n",
" if global_params['gradient_accumulation_steps'] >1:\n",
" loss = loss/global_params['gradient_accumulation_steps']\n",
" loss.backward()\n",
" \n",
" temp_loss += loss.item()\n",
" tr_loss += loss.item()\n",
" nb_tr_examples += input_ids.size(0)\n",
" nb_tr_steps += 1\n",
" \n",
" if step % 2000==0:\n",
" prec, a, b = precision(logits, targets)\n",
" print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, cnt,temp_loss/2000, prec))\n",
" temp_loss = 0\n",
" \n",
" if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
" optim.step()\n",
" optim.zero_grad()\n",
" \n",
"\n",
"def evaluation():\n",
" model.eval()\n",
" y = []\n",
" y_label = []\n",
" for step, batch in enumerate(testload):\n",
" model.eval()\n",
" age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
" targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
" \n",
" age_ids = age_ids.to(global_params['device'])\n",
" input_ids = input_ids.to(global_params['device'])\n",
" posi_ids = posi_ids.to(global_params['device'])\n",
" segment_ids = segment_ids.to(global_params['device'])\n",
" attMask = attMask.to(global_params['device'])\n",
" targets = targets.to(global_params['device'])\n",
" \n",
" with torch.no_grad():\n",
" loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
" logits = logits.cpu()\n",
" targets = targets.cpu()\n",
"\n",
" y_label.append(targets)\n",
" y.append(logits)\n",
"\n",
" y_label = torch.cat(y_label, dim=0)\n",
" y = torch.cat(y, dim=0)\n",
"\n",
" tempprc, output, label = precision_test(y, y_label)\n",
" auroc = auroc_test(y, y_label)\n",
" return tempprc, auroc"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(action='ignore')\n",
"optim_config = {\n",
" 'lr': 9e-6,\n",
" 'warmup_proportion': 0.1\n",
"}\n",
"optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)\n",
"\n",
"best_pre = 0\n",
"for e in range(50):\n",
" train(e)\n",
" auc, roc= evaluation()\n",
" if auc >best_pre:\n",
" # Save a trained model\n",
" print(\"** ** * Saving fine - tuned model ** ** * \")\n",
" model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n",
" output_model_file = os.path.join(global_params['output_dir'],global_params['best_name'])\n",
" create_folder(global_params['output_dir'])\n",
" if global_params['save_model']:\n",
" torch.save(model_to_save.state_dict(), output_model_file)\n",
" best_pre = auc\n",
" print('precision : {}, auroc: {},'.format(auc, roc))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}