Diff of /task/MLM.ipynb [000000] .. [bad60c]


--- a
+++ b/task/MLM.ipynb
@@ -0,0 +1,297 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys \n",
+    "sys.path.insert(0, '../')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from common.common import create_folder\n",
+    "from common.pytorch import load_model\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "from model.utils import age_vocab\n",
+    "from common.common import load_obj\n",
+    "from dataLoader.MLM import MLMLoader\n",
+    "from torch.utils.data import DataLoader\n",
+    "import pandas as pd\n",
+    "from model.MLM import BertForMaskedLM\n",
+    "from model.optimiser import adam\n",
+    "import sklearn.metrics as skm\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import time\n",
+    "import torch.nn as nn\n",
+    "import os"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BertConfig(Bert.modeling.BertConfig):\n",
+    "    def __init__(self, config):\n",
+    "        super(BertConfig, self).__init__(\n",
+    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
+    "            hidden_size=config['hidden_size'],\n",
+    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
+    "            num_attention_heads=config.get('num_attention_heads'),\n",
+    "            intermediate_size=config.get('intermediate_size'),\n",
+    "            hidden_act=config.get('hidden_act'),\n",
+    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
+    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
+    "            max_position_embeddings = config.get('max_position_embedding'),\n",
+    "            initializer_range=config.get('initializer_range'),\n",
+    "        )\n",
+    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
+    "        self.age_vocab_size = config.get('age_vocab_size')\n",
+    "        \n",
+    "class TrainConfig(object):\n",
+    "    def __init__(self, config):\n",
+    "        self.batch_size = config.get('batch_size')\n",
+    "        self.use_cuda = config.get('use_cuda')\n",
+    "        self.max_len_seq = config.get('max_len_seq')\n",
+    "        self.train_loader_workers = config.get('train_loader_workers')\n",
+    "        self.test_loader_workers = config.get('test_loader_workers')\n",
+    "        self.device = config.get('device')\n",
+    "        self.output_dir = config.get('output_dir')\n",
+    "        self.output_name = config.get('output_name')\n",
+    "        self.best_name = config.get('best_name')"
+   ]
+  },
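+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The config wrappers above read every field with `dict.get`, so a missing key silently becomes `None` rather than raising. A minimal sketch (placeholder values, not the training settings used below):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "example = TrainConfig({'batch_size': 32, 'use_cuda': False, 'max_len_seq': 64})\n",
+    "print(example.batch_size)  # 32\n",
+    "print(example.device)      # None: 'device' was never supplied, .get() falls back silently"
+   ]
+  },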
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_config = {\n",
+    "    'vocab':'',  # vocabulary idx2token, token2idx\n",
+    "    'data': '',  # formated data \n",
+    "    'model_path': '', # where to save model\n",
+    "    'model_name': '', # model name\n",
+    "    'file_name': '',  # log path\n",
+    "}\n",
+    "create_folder(file_config['model_path'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "global_params = {\n",
+    "    'max_seq_len': 64,\n",
+    "    'max_age': 110,\n",
+    "    'month': 1,\n",
+    "    'age_symbol': None,\n",
+    "    'min_visit': 5,\n",
+    "    'gradient_accumulation_steps': 1\n",
+    "}\n",
+    "\n",
+    "optim_param = {\n",
+    "    'lr': 3e-5,\n",
+    "    'warmup_proportion': 0.1,\n",
+    "    'weight_decay': 0.01\n",
+    "}\n",
+    "\n",
+    "train_params = {\n",
+    "    'batch_size': 256,\n",
+    "    'use_cuda': True,\n",
+    "    'max_len_seq': global_params['max_seq_len'],\n",
+    "    'device': 'cuda:0'\n",
+    "}"
+   ]
+  },
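+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "With gradient accumulation, the optimiser steps once every `gradient_accumulation_steps` batches, so the effective batch size is the product of the two settings above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "effective_batch_size = train_params['batch_size'] * global_params['gradient_accumulation_steps']\n",
+    "print(effective_batch_size)  # 256 with the values above"
+   ]
+  },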
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BertVocab = load_obj(file_config['vocab'])\n",
+    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = pd.read_parquet(file_config['data'])\n",
+    "# remove patients with visits less than min visit\n",
+    "data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))\n",
+    "data = data[data['length'] >= global_params['min_visit']]\n",
+    "data = data.reset_index(drop=True)"
+   ]
+  },
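+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The filter above counts `'SEP'` tokens to measure the number of visits per patient. A toy record to illustrate the rule (hypothetical codes, not real CALIBER identifiers):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "toy = ['C1', 'C2', 'SEP', 'C3', 'SEP', 'C4', 'SEP']  # three visits\n",
+    "print(len([i for i in range(len(toy)) if toy[i] == 'SEP']))  # 3"
+   ]
+  },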
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')\n",
+    "trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)"
+   ]
+  },
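+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check on one batch, assuming `MLMLoader` yields the six tensors in the order they are unpacked inside `train()` below:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = next(iter(trainload))\n",
+    "print(input_ids.shape)  # expected: (batch_size, max_len_seq)"
+   ]
+  },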
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_config = {\n",
+    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
+    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
+    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
+    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
+    "    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens\n",
+    "    'hidden_dropout_prob': 0.1, # dropout rate\n",
+    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
+    "    'num_attention_heads': 12, # number of attention heads\n",
+    "    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n",
+    "    'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
+    "    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
+    "    'initializer_range': 0.02, # parameter weight initializer range\n",
+    "}"
+   ]
+  },
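+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "BERT splits the hidden state evenly across attention heads, so `hidden_size` must be divisible by `num_attention_heads` (here 288 / 12 = 24 dimensions per head):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "assert model_config['hidden_size'] % model_config['num_attention_heads'] == 0"
+   ]
+  },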
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "conf = BertConfig(model_config)\n",
+    "model = BertForMaskedLM(conf)"
+   ]
+  },
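+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A rough size check on the freshly built model; useful when comparing configurations:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n_params = sum(p.numel() for p in model.parameters())\n",
+    "print('{:,} parameters'.format(n_params))"
+   ]
+  },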
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = model.to(train_params['device'])\n",
+    "optim = adam(params=list(model.named_parameters()), config=optim_param)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def cal_acc(label, pred):\n",
+    "    logs = nn.LogSoftmax()\n",
+    "    label=label.cpu().numpy()\n",
+    "    ind = np.where(label!=-1)[0]\n",
+    "    truepred = pred.detach().cpu().numpy()\n",
+    "    truepred = truepred[ind]\n",
+    "    truelabel = label[ind]\n",
+    "    truepred = logs(torch.tensor(truepred))\n",
+    "    outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]\n",
+    "    precision = skm.precision_score(truelabel, outs, average='micro')\n",
+    "    return precision"
+   ]
+  },
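+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A toy check of `cal_acc`: the position labelled -1 is ignored, and for the two scored positions the logits pick classes 1 and 0, matching the labels, so precision is 1.0:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "label = torch.tensor([1, -1, 0])\n",
+    "pred = torch.tensor([[0.1, 0.9], [0.5, 0.5], [0.8, 0.2]])\n",
+    "print(cal_acc(label, pred))  # 1.0"
+   ]
+  },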
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(e, loader):\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    cnt= 0\n",
+    "    start = time.time()\n",
+    "\n",
+    "    for step, batch in enumerate(loader):\n",
+    "        cnt +=1\n",
+    "        batch = tuple(t.to(train_params['device']) for t in batch)\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch\n",
+    "        loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)\n",
+    "        if global_params['gradient_accumulation_steps'] >1:\n",
+    "            loss = loss/global_params['gradient_accumulation_steps']\n",
+    "        loss.backward()\n",
+    "        \n",
+    "        temp_loss += loss.item()\n",
+    "        tr_loss += loss.item()\n",
+    "        \n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "        \n",
+    "        if step % 200==0:\n",
+    "            print(\"epoch: {}\\t| cnt: {}\\t|Loss: {}\\t| precision: {:.4f}\\t| time: {:.2f}\".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))\n",
+    "            temp_loss = 0\n",
+    "            start = time.time()\n",
+    "            \n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "\n",
+    "    print(\"** ** * Saving fine - tuned model ** ** * \")\n",
+    "    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
+    "    create_folder(file_config['model_path'])\n",
+    "    output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])\n",
+    "\n",
+    "    torch.save(model_to_save.state_dict(), output_model_file)\n",
+    "        \n",
+    "    cost = time.time() - start\n",
+    "    return tr_loss, cost"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "f = open(os.path.join(file_config['model_path'], file_config['file_name']), \"w\")\n",
+    "f.write('{}\\t{}\\t{}\\n'.format('epoch', 'loss', 'time'))\n",
+    "for e in range(50):\n",
+    "    loss, time_cost = train(e, trainload)\n",
+    "    loss = loss/data_len\n",
+    "    f.write('{}\\t{}\\t{}\\n'.format(e, loss, time_cost))\n",
+    "f.close()    "
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}