--- a +++ b/task/MLM.ipynb @@ -0,0 +1,297 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys \n", + "sys.path.insert(0, '../')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from common.common import create_folder\n", + "from common.pytorch import load_model\n", + "import pytorch_pretrained_bert as Bert\n", + "from model.utils import age_vocab\n", + "from common.common import load_obj\n", + "from dataLoader.MLM import MLMLoader\n", + "from torch.utils.data import DataLoader\n", + "import pandas as pd\n", + "from model.MLM import BertForMaskedLM\n", + "from model.optimiser import adam\n", + "import sklearn.metrics as skm\n", + "import numpy as np\n", + "import torch\n", + "import time\n", + "import torch.nn as nn\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class BertConfig(Bert.modeling.BertConfig):\n", + " def __init__(self, config):\n", + " super(BertConfig, self).__init__(\n", + " vocab_size_or_config_json_file=config.get('vocab_size'),\n", + " hidden_size=config['hidden_size'],\n", + " num_hidden_layers=config.get('num_hidden_layers'),\n", + " num_attention_heads=config.get('num_attention_heads'),\n", + " intermediate_size=config.get('intermediate_size'),\n", + " hidden_act=config.get('hidden_act'),\n", + " hidden_dropout_prob=config.get('hidden_dropout_prob'),\n", + " attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n", + " max_position_embeddings = config.get('max_position_embedding'),\n", + " initializer_range=config.get('initializer_range'),\n", + " )\n", + " self.seg_vocab_size = config.get('seg_vocab_size')\n", + " self.age_vocab_size = config.get('age_vocab_size')\n", + " \n", + "class TrainConfig(object):\n", + " def __init__(self, config):\n", + " self.batch_size = config.get('batch_size')\n", + " self.use_cuda = config.get('use_cuda')\n", + " self.max_len_seq = config.get('max_len_seq')\n", + " self.train_loader_workers = config.get('train_loader_workers')\n", + " self.test_loader_workers = config.get('test_loader_workers')\n", + " self.device = config.get('device')\n", + " self.output_dir = config.get('output_dir')\n", + " self.output_name = config.get('output_name')\n", + " self.best_name = config.get('best_name')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "file_config = {\n", + " 'vocab':'', # vocabulary idx2token, token2idx\n", + " 'data': '', # formated data \n", + " 'model_path': '', # where to save model\n", + " 'model_name': '', # model name\n", + " 'file_name': '', # log path\n", + "}\n", + "create_folder(file_config['model_path'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "global_params = {\n", + " 'max_seq_len': 64,\n", + " 'max_age': 110,\n", + " 'month': 1,\n", + " 'age_symbol': None,\n", + " 'min_visit': 5,\n", + " 'gradient_accumulation_steps': 1\n", + "}\n", + "\n", + "optim_param = {\n", + " 'lr': 3e-5,\n", + " 'warmup_proportion': 0.1,\n", + " 'weight_decay': 0.01\n", + "}\n", + "\n", + "train_params = {\n", + " 'batch_size': 256,\n", + " 'use_cuda': True,\n", + " 'max_len_seq': global_params['max_seq_len'],\n", + " 'device': 'cuda:0'\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "BertVocab = load_obj(file_config['vocab'])\n", + "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_parquet(file_config['data'])\n", + "# remove patients with visits less than min visit\n", + "data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))\n", + "data = data[data['length'] >= global_params['min_visit']]\n", + "data = data.reset_index(drop=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')\n", + "trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model_config = {\n", + " 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n", + " 'hidden_size': 288, # word embedding and seg embedding hidden size\n", + " 'seg_vocab_size': 2, # number of vocab for seg embedding\n", + " 'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n", + " 'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens\n", + " 'hidden_dropout_prob': 0.1, # dropout rate\n", + " 'num_hidden_layers': 6, # number of multi-head attention layers required\n", + " 'num_attention_heads': 12, # number of attention heads\n", + " 'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n", + " 'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n", + " 'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n", + " 'initializer_range': 0.02, # parameter weight initializer range\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conf = BertConfig(model_config)\n", + "model = BertForMaskedLM(conf)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = model.to(train_params['device'])\n", + "optim = adam(params=list(model.named_parameters()), config=optim_param)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def cal_acc(label, pred):\n", + " logs = nn.LogSoftmax()\n", + " label=label.cpu().numpy()\n", + " ind = np.where(label!=-1)[0]\n", + " truepred = pred.detach().cpu().numpy()\n", + " truepred = truepred[ind]\n", + " truelabel = label[ind]\n", + " truepred = logs(torch.tensor(truepred))\n", + " outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]\n", + " precision = skm.precision_score(truelabel, outs, average='micro')\n", + " return precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def train(e, loader):\n", + " tr_loss = 0\n", + " temp_loss = 0\n", + " nb_tr_examples, nb_tr_steps = 0, 0\n", + " cnt= 0\n", + " start = time.time()\n", + "\n", + " for step, batch in enumerate(loader):\n", + " cnt +=1\n", + " batch = tuple(t.to(train_params['device']) for t in batch)\n", + " age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch\n", + " loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)\n", + " if global_params['gradient_accumulation_steps'] >1:\n", + " loss = loss/global_params['gradient_accumulation_steps']\n", + " loss.backward()\n", + " \n", + " temp_loss += loss.item()\n", + " tr_loss += loss.item()\n", + " \n", + " nb_tr_examples += input_ids.size(0)\n", + " nb_tr_steps += 1\n", + " \n", + " if step % 200==0:\n", + " print(\"epoch: {}\\t| cnt: {}\\t|Loss: {}\\t| precision: {:.4f}\\t| time: {:.2f}\".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))\n", + " temp_loss = 0\n", + " start = time.time()\n", + " \n", + " if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n", + " optim.step()\n", + " optim.zero_grad()\n", + "\n", + " print(\"** ** * Saving fine - tuned model ** ** * \")\n", + " model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n", + " create_folder(file_config['model_path'])\n", + " output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])\n", + "\n", + " torch.save(model_to_save.state_dict(), output_model_file)\n", + " \n", + " cost = time.time() - start\n", + " return tr_loss, cost" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f = open(os.path.join(file_config['model_path'], file_config['file_name']), \"w\")\n", + "f.write('{}\\t{}\\t{}\\n'.format('epoch', 'loss', 'time'))\n", + "for e in range(50):\n", + " loss, time_cost = train(e, trainload)\n", + " loss = loss/data_len\n", + " f.write('{}\\t{}\\t{}\\n'.format(e, loss, time_cost))\n", + "f.close() " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}