
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys \n",
    "sys.path.insert(0, '../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from common.common import create_folder\n",
    "from common.pytorch import load_model\n",
    "import pytorch_pretrained_bert as Bert\n",
    "from model.utils import age_vocab\n",
    "from common.common import load_obj\n",
    "from dataLoader.MLM import MLMLoader\n",
    "from torch.utils.data import DataLoader\n",
    "import pandas as pd\n",
    "from model.MLM import BertForMaskedLM\n",
    "from model.optimiser import adam\n",
    "import sklearn.metrics as skm\n",
    "import numpy as np\n",
    "import torch\n",
    "import time\n",
    "import torch.nn as nn\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertConfig(Bert.modeling.BertConfig):\n",
    "    def __init__(self, config):\n",
    "        super(BertConfig, self).__init__(\n",
    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
    "            hidden_size=config['hidden_size'],\n",
    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
    "            num_attention_heads=config.get('num_attention_heads'),\n",
    "            intermediate_size=config.get('intermediate_size'),\n",
    "            hidden_act=config.get('hidden_act'),\n",
    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
    "            max_position_embeddings = config.get('max_position_embedding'),\n",
    "            initializer_range=config.get('initializer_range'),\n",
    "        )\n",
    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
    "        self.age_vocab_size = config.get('age_vocab_size')\n",
    "        \n",
    "class TrainConfig(object):\n",
    "    def __init__(self, config):\n",
    "        self.batch_size = config.get('batch_size')\n",
    "        self.use_cuda = config.get('use_cuda')\n",
    "        self.max_len_seq = config.get('max_len_seq')\n",
    "        self.train_loader_workers = config.get('train_loader_workers')\n",
    "        self.test_loader_workers = config.get('test_loader_workers')\n",
    "        self.device = config.get('device')\n",
    "        self.output_dir = config.get('output_dir')\n",
    "        self.output_name = config.get('output_name')\n",
    "        self.best_name = config.get('best_name')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_config = {\n",
    "    'vocab':'',  # vocabulary idx2token, token2idx\n",
    "    'data': '',  # formated data \n",
    "    'model_path': '', # where to save model\n",
    "    'model_name': '', # model name\n",
    "    'file_name': '',  # log path\n",
    "}\n",
    "create_folder(file_config['model_path'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "global_params = {\n",
    "    'max_seq_len': 64,\n",
    "    'max_age': 110,\n",
    "    'month': 1,\n",
    "    'age_symbol': None,\n",
    "    'min_visit': 5,\n",
    "    'gradient_accumulation_steps': 1\n",
    "}\n",
    "\n",
    "optim_param = {\n",
    "    'lr': 3e-5,\n",
    "    'warmup_proportion': 0.1,\n",
    "    'weight_decay': 0.01\n",
    "}\n",
    "\n",
    "train_params = {\n",
    "    'batch_size': 256,\n",
    "    'use_cuda': True,\n",
    "    'max_len_seq': global_params['max_seq_len'],\n",
    "    'device': 'cuda:0'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BertVocab = load_obj(file_config['vocab'])\n",
    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_parquet(file_config['data'])\n",
    "# remove patients with visits less than min visit\n",
    "data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))\n",
    "data = data[data['length'] >= global_params['min_visit']]\n",
    "data = data.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')\n",
    "trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = {\n",
    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
    "    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens\n",
    "    'hidden_dropout_prob': 0.1, # dropout rate\n",
    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
    "    'num_attention_heads': 12, # number of attention heads\n",
    "    'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n",
    "    'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
    "    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
    "    'initializer_range': 0.02, # parameter weight initializer range\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = BertConfig(model_config)\n",
    "model = BertForMaskedLM(conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(train_params['device'])\n",
    "optim = adam(params=list(model.named_parameters()), config=optim_param)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_acc(label, pred):\n",
    "    logs = nn.LogSoftmax()\n",
    "    label=label.cpu().numpy()\n",
    "    ind = np.where(label!=-1)[0]\n",
    "    truepred = pred.detach().cpu().numpy()\n",
    "    truepred = truepred[ind]\n",
    "    truelabel = label[ind]\n",
    "    truepred = logs(torch.tensor(truepred))\n",
    "    outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]\n",
    "    precision = skm.precision_score(truelabel, outs, average='micro')\n",
    "    return precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(e, loader):\n",
    "    tr_loss = 0\n",
    "    temp_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    cnt= 0\n",
    "    start = time.time()\n",
    "\n",
    "    for step, batch in enumerate(loader):\n",
    "        cnt +=1\n",
    "        batch = tuple(t.to(train_params['device']) for t in batch)\n",
    "        age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch\n",
    "        loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)\n",
    "        if global_params['gradient_accumulation_steps'] >1:\n",
    "            loss = loss/global_params['gradient_accumulation_steps']\n",
    "        loss.backward()\n",
    "        \n",
    "        temp_loss += loss.item()\n",
    "        tr_loss += loss.item()\n",
    "        \n",
    "        nb_tr_examples += input_ids.size(0)\n",
    "        nb_tr_steps += 1\n",
    "        \n",
    "        if step % 200==0:\n",
    "            print(\"epoch: {}\\t| cnt: {}\\t|Loss: {}\\t| precision: {:.4f}\\t| time: {:.2f}\".format(e, cnt, temp_loss/2000, cal_acc(label, pred), time.time()-start))\n",
    "            temp_loss = 0\n",
    "            start = time.time()\n",
    "            \n",
    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
    "            optim.step()\n",
    "            optim.zero_grad()\n",
    "\n",
    "    print(\"** ** * Saving fine - tuned model ** ** * \")\n",
    "    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
    "    create_folder(file_config['model_path'])\n",
    "    output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])\n",
    "\n",
    "    torch.save(model_to_save.state_dict(), output_model_file)\n",
    "        \n",
    "    cost = time.time() - start\n",
    "    return tr_loss, cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = open(os.path.join(file_config['model_path'], file_config['file_name']), \"w\")\n",
    "f.write('{}\\t{}\\t{}\\n'.format('epoch', 'loss', 'time'))\n",
    "for e in range(50):\n",
    "    loss, time_cost = train(e, trainload)\n",
    "    loss = loss/data_len\n",
    "    f.write('{}\\t{}\\t{}\\n'.format(e, loss, time_cost))\n",
    "f.close()    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}