--- a
+++ b/task/NextVIsit-12month.ipynb
@@ -0,0 +1,621 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys \n",
+    "sys.path.insert(0, '../')\n",
+    "\n",
+    "from common.common import create_folder,load_obj\n",
+    "# from data import bert,dataframe,utils\n",
+    "from dataLoader.utils import seq_padding,code2index, position_idx, index_seg\n",
+    "from torch.utils.data import DataLoader\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "import os\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "from model.utils import age_vocab\n",
+    "from model import optimiser\n",
+    "import sklearn.metrics as skm\n",
+    "import math\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "import random\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import time\n",
+    "\n",
+    "# from data.utils import seq_padding, index_seg, position_idx, age_vocab, random_mask, code2index\n",
+    "# from sklearn.metrics import roc_auc_score"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# File Parameters"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "file_config = {\n",
+    "    'vocab':'',  # token2idx idx2token\n",
+    "    'train': '',\n",
+    "    'test': '',\n",
+    "}\n",
+    "\n",
+    "optim_config = {\n",
+    "    'lr': 3e-5,\n",
+    "    'warmup_proportion': 0.1,\n",
+    "    'weight_decay': 0.01\n",
+    "}\n",
+    "\n",
+    "global_params = {\n",
+    "    'batch_size': 128,\n",
+    "    'gradient_accumulation_steps': 1,\n",
+    "    'device': 'cuda:1',\n",
+    "    'output_dir': '',  # output folder\n",
+    "    'best_name': '', # output model name\n",
+    "    'save_model': True,\n",
+    "    'max_len_seq': 100,\n",
+    "    'max_age': 110,\n",
+    "    'month': 1,\n",
+    "    'age_symbol': None,\n",
+    "    'min_visit': 5\n",
+    "}\n",
+    "\n",
+    "pretrainModel = '' # pretrained MLM path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "create_folder(global_params['output_dir'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BertVocab = load_obj(file_config['vocab'])\n",
+    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def format_label_vocab(token2idx):\n",
+    "    token2idx = token2idx.copy()\n",
+    "    del token2idx['PAD']\n",
+    "    del token2idx['SEP']\n",
+    "    del token2idx['CLS']\n",
+    "    del token2idx['MASK']\n",
+    "    token = list(token2idx.keys())\n",
+    "    labelVocab = {}\n",
+    "    for i,x in enumerate(token):\n",
+    "        labelVocab[x] = i\n",
+    "    return labelVocab\n",
+    "\n",
+    "Vocab_diag = format_label_vocab(BertVocab['token2idx'])"
+   ]
+  },
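+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For illustration, with a toy `token2idx` (made-up codes) the helper drops the special symbols and re-indexes the remaining disease codes from 0:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# toy example with made-up codes, not the real vocabulary\n",
+    "toy = {'PAD': 0, 'SEP': 1, 'CLS': 2, 'MASK': 3, 'D250': 4, 'D401': 5}\n",
+    "format_label_vocab(toy)  # -> {'D250': 0, 'D401': 1}"
+   ]
+  },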
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_config = {\n",
+    "    'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
+    "    'hidden_size': 288, # word embedding and seg embedding hidden size\n",
+    "    'seg_vocab_size': 2, # number of vocab for seg embedding\n",
+    "    'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
+    "    'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
+    "    'hidden_dropout_prob': 0.2, # dropout rate\n",
+    "    'num_hidden_layers': 6, # number of multi-head attention layers required\n",
+    "    'num_attention_heads': 12, # number of attention heads\n",
+    "    'attention_probs_dropout_prob': 0.22, # multi-head attention dropout rate\n",
+    "    'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n",
+    "    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n",
+    "    'initializer_range': 0.02, # parameter weight initializer range\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Set Up Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class NextVisit(Dataset):\n",
+    "    def __init__(self, token2idx, diag2idx, age2idx,dataframe, max_len, max_age=110, min_visit=5):\n",
+    "        # dataframe preproecssing\n",
+    "        # filter out the patient with number of visits less than min_visit\n",
+    "        self.vocab = token2idx\n",
+    "        self.label_vocab = diag2idx\n",
+    "        self.max_len = max_len\n",
+    "        self.code = dataframe.code\n",
+    "        self.age = dataframe.age\n",
+    "        self.label = dataframe.label\n",
+    "        self.patid = dataframe.patid\n",
+    "\n",
+    "        self.age2idx = age2idx\n",
+    "\n",
+    "    def __getitem__(self, index):\n",
+    "        \"\"\"\n",
+    "        return: age, code, position, segmentation, mask, label\n",
+    "        \"\"\"\n",
+    "        # cut data\n",
+    "        age = self.age[index]\n",
+    "        code = self.code[index]\n",
+    "        label = self.label[index]\n",
+    "        patid = self.patid[index]\n",
+    "\n",
+    "        # extract data\n",
+    "        age = age[(-self.max_len+1):]\n",
+    "        code = code[(-self.max_len+1):]\n",
+    "\n",
+    "        # avoid data cut with first element to be 'SEP'\n",
+    "        if code[0] != 'SEP':\n",
+    "            code = np.append(np.array(['CLS']), code)\n",
+    "            age = np.append(np.array(age[0]), age)\n",
+    "        else:\n",
+    "            code[0] = 'CLS'\n",
+    "\n",
+    "        # mask 0:len(code) to 1, padding to be 0\n",
+    "        mask = np.ones(self.max_len)\n",
+    "        mask[len(code):] = 0\n",
+    "\n",
+    "        # pad age sequence and code sequence\n",
+    "        age = seq_padding(age, self.max_len, token2idx=self.age2idx)\n",
+    "\n",
+    "        tokens, code = code2index(code, self.vocab)\n",
+    "        _, label = code2index(label, self.label_vocab)\n",
+    "\n",
+    "        # get position code and segment code\n",
+    "        tokens = seq_padding(tokens, self.max_len)\n",
+    "        position = position_idx(tokens)\n",
+    "        segment = index_seg(tokens)\n",
+    "\n",
+    "        # pad code and label\n",
+    "        code = seq_padding(code, self.max_len, symbol=self.vocab['PAD'])\n",
+    "        label = seq_padding(label, self.max_len, symbol=-1)\n",
+    "\n",
+    "        return torch.LongTensor(age), torch.LongTensor(code), torch.LongTensor(position), torch.LongTensor(segment), \\\n",
+    "               torch.LongTensor(mask), torch.LongTensor(label), torch.LongTensor([int(patid)])\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.code)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class BertConfig(Bert.modeling.BertConfig):\n",
+    "    def __init__(self, config):\n",
+    "        super(BertConfig, self).__init__(\n",
+    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
+    "            hidden_size=config['hidden_size'],\n",
+    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
+    "            num_attention_heads=config.get('num_attention_heads'),\n",
+    "            intermediate_size=config.get('intermediate_size'),\n",
+    "            hidden_act=config.get('hidden_act'),\n",
+    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
+    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
+    "            max_position_embeddings = config.get('max_position_embedding'),\n",
+    "            initializer_range=config.get('initializer_range'),\n",
+    "        )\n",
+    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
+    "        self.age_vocab_size = config.get('age_vocab_size')\n",
+    "\n",
+    "class BertEmbeddings(nn.Module):\n",
+    "    \"\"\"Construct the embeddings from word, segment, age\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, config):\n",
+    "        super(BertEmbeddings, self).__init__()\n",
+    "        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n",
+    "        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)\n",
+    "        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)\n",
+    "        self.posi_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size).\\\n",
+    "            from_pretrained(embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))\n",
+    "\n",
+    "        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)\n",
+    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
+    "\n",
+    "    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):\n",
+    "        if seg_ids is None:\n",
+    "            seg_ids = torch.zeros_like(word_ids)\n",
+    "        if age_ids is None:\n",
+    "            age_ids = torch.zeros_like(word_ids)\n",
+    "        if posi_ids is None:\n",
+    "            posi_ids = torch.zeros_like(word_ids)\n",
+    "\n",
+    "        word_embed = self.word_embeddings(word_ids)\n",
+    "        segment_embed = self.segment_embeddings(seg_ids)\n",
+    "        age_embed = self.age_embeddings(age_ids)\n",
+    "        posi_embeddings = self.posi_embeddings(posi_ids)\n",
+    "        \n",
+    "        if age:\n",
+    "            embeddings = word_embed + segment_embed + age_embed + posi_embeddings\n",
+    "        else:\n",
+    "            embeddings = word_embed + segment_embed + posi_embeddings\n",
+    "        embeddings = self.LayerNorm(embeddings)\n",
+    "        embeddings = self.dropout(embeddings)\n",
+    "        return embeddings\n",
+    "\n",
+    "    def _init_posi_embedding(self, max_position_embedding, hidden_size):\n",
+    "        def even_code(pos, idx):\n",
+    "            return np.sin(pos/(10000**(2*idx/hidden_size)))\n",
+    "\n",
+    "        def odd_code(pos, idx):\n",
+    "            return np.cos(pos/(10000**(2*idx/hidden_size)))\n",
+    "\n",
+    "        # initialize position embedding table\n",
+    "        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)\n",
+    "\n",
+    "        # reset table parameters with hard encoding\n",
+    "        # set even dimension\n",
+    "        for pos in range(max_position_embedding):\n",
+    "            for idx in np.arange(0, hidden_size, step=2):\n",
+    "                lookup_table[pos, idx] = even_code(pos, idx)\n",
+    "        # set odd dimension\n",
+    "        for pos in range(max_position_embedding):\n",
+    "            for idx in np.arange(1, hidden_size, step=2):\n",
+    "                lookup_table[pos, idx] = odd_code(pos, idx)\n",
+    "\n",
+    "        return torch.tensor(lookup_table)\n",
+    "\n",
+    "class BertModel(Bert.modeling.BertPreTrainedModel):\n",
+    "    def __init__(self, config):\n",
+    "        super(BertModel, self).__init__(config)\n",
+    "        self.embeddings = BertEmbeddings(config=config)\n",
+    "        self.encoder = Bert.modeling.BertEncoder(config=config)\n",
+    "        self.pooler = Bert.modeling.BertPooler(config)\n",
+    "        self.apply(self.init_bert_weights)\n",
+    "\n",
+    "    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):\n",
+    "        if attention_mask is None:\n",
+    "            attention_mask = torch.ones_like(input_ids)\n",
+    "        if age_ids is None:\n",
+    "            age_ids = torch.zeros_like(input_ids)\n",
+    "        if seg_ids is None:\n",
+    "            seg_ids = torch.zeros_like(input_ids)\n",
+    "        if posi_ids is None:\n",
+    "            posi_ids = torch.zeros_like(input_ids)\n",
+    "\n",
+    "        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)\n",
+    "        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility\n",
+    "        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
+    "\n",
+    "        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)\n",
+    "        encoded_layers = self.encoder(embedding_output,\n",
+    "                                      extended_attention_mask,\n",
+    "                                      output_all_encoded_layers=output_all_encoded_layers)\n",
+    "        sequence_output = encoded_layers[-1]\n",
+    "        pooled_output = self.pooler(sequence_output)\n",
+    "        if not output_all_encoded_layers:\n",
+    "            encoded_layers = encoded_layers[-1]\n",
+    "        return encoded_layers, pooled_output\n",
+    "\n",
+    "class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):\n",
+    "    def __init__(self, config, num_labels):\n",
+    "        super(BertForMultiLabelPrediction, self).__init__(config)\n",
+    "        self.num_labels = num_labels\n",
+    "        self.bert = BertModel(config)\n",
+    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
+    "        self.classifier = nn.Linear(config.hidden_size, num_labels)\n",
+    "        self.apply(self.init_bert_weights)\n",
+    "\n",
+    "    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):\n",
+    "        _, pooled_output = self.bert(input_ids, age_ids ,seg_ids, posi_ids, attention_mask,\n",
+    "                                     output_all_encoded_layers=False)\n",
+    "        pooled_output = self.dropout(pooled_output)\n",
+    "        logits = self.classifier(pooled_output)\n",
+    "\n",
+    "        if labels is not None:\n",
+    "            loss_fct = nn.MultiLabelSoftMarginLoss()\n",
+    "            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))\n",
+    "            return loss, logits\n",
+    "        else:\n",
+    "            return logits"
+   ]
+  },
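+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick shape sanity check (toy sizes, untrained weights): the classifier head maps the pooled 'CLS' representation to one logit per diagnosis label, so logits come out as `[batch, num_labels]`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# toy check with random ids and hypothetical vocab sizes\n",
+    "toy_conf = BertConfig({**model_config, 'vocab_size': 10, 'age_vocab_size': 5})\n",
+    "toy_model = BertForMultiLabelPrediction(toy_conf, num_labels=3)\n",
+    "toy_ids = torch.randint(0, 10, (2, model_config['max_position_embedding']))\n",
+    "toy_model(toy_ids).shape  # -> torch.Size([2, 3])"
+   ]
+  },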
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = pd.read_parquet(file_config['train']).reset_index(drop=True)\n",
+    "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])\n",
+    "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data = pd.read_parquet(file_config['test']).reset_index(drop=True)\n",
+    "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
+    "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab,dataframe=data, max_len=global_params['max_len_seq'])\n",
+    "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Set Up Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# del model\n",
+    "conf = BertConfig(model_config)\n",
+    "model = BertForMultiLabelPrediction(conf, num_labels=len(Vocab_diag.keys()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "# load pretrained model and update weights\n",
+    "pretrained_dict = torch.load(pretrainModel)\n",
+    "model_dict = model.state_dict()\n",
+    "# 1. filter out unnecessary keys\n",
+    "pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
+    "# 2. overwrite entries in the existing state dict\n",
+    "model_dict.update(pretrained_dict) \n",
+    "# 3. load the new state dict\n",
+    "model.load_state_dict(model_dict)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = model.to(global_params['device'])\n",
+    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Evaluation Matrix"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sklearn\n",
+    "def precision(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output=sig(logits)\n",
+    "    label, output=label.cpu(), output.detach().cpu()\n",
+    "    tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
+    "    return tempprc, output, label\n",
+    "\n",
+    "def precision_test(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output=sig(logits)\n",
+    "    tempprc= sklearn.metrics.average_precision_score(label.numpy(),output.numpy(), average='samples')\n",
+    "#     roc = sklearn.metrics.roc_auc_score()\n",
+    "    return tempprc, output, label\n",
+    "\n",
+    "def auroc_test(logits, label):\n",
+    "    sig = nn.Sigmoid()\n",
+    "    output=sig(logits)\n",
+    "    tempprc= sklearn.metrics.roc_auc_score(label.numpy(),output.numpy(), average='samples')\n",
+    "#     roc = sklearn.metrics.roc_auc_score()\n",
+    "    return tempprc"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Multi-hot Label Encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.preprocessing import MultiLabelBinarizer\n",
+    "mlb = MultiLabelBinarizer(classes=list(Vocab_diag.values()))\n",
+    "mlb.fit([[each] for each in list(Vocab_diag.values())])"
+   ]
+  },
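+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick check of the multi-hot encoding (toy class indices): `mlb.transform` turns each list of label indices into a fixed-length 0/1 vector over the diagnosis vocabulary. The -1 padding symbol is not a known class, so it is dropped with an 'unknown class' warning, which is why warnings are filtered before training."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# toy binarizer over a hypothetical 4-class vocabulary\n",
+    "toy_mlb = MultiLabelBinarizer(classes=[0, 1, 2, 3])\n",
+    "toy_mlb.fit([[0], [1], [2], [3]])\n",
+    "toy_mlb.transform([[0, 2, -1]])  # -> array([[1, 0, 1, 0]]), with an 'unknown class' warning"
+   ]
+  },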
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Train and Test"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def train(e):\n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    cnt = 0\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "        cnt +=1\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "        \n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "        \n",
+    "        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
+    "        \n",
+    "        if global_params['gradient_accumulation_steps'] >1:\n",
+    "            loss = loss/global_params['gradient_accumulation_steps']\n",
+    "        loss.backward()\n",
+    "        \n",
+    "        temp_loss += loss.item()\n",
+    "        tr_loss += loss.item()\n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "        \n",
+    "        if step % 2000==0:\n",
+    "            prec, a, b = precision(logits, targets)\n",
+    "            print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, cnt,temp_loss/2000, prec))\n",
+    "            temp_loss = 0\n",
+    "        \n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "        \n",
+    "\n",
+    "def evaluation():\n",
+    "    model.eval()\n",
+    "    y = []\n",
+    "    y_label = []\n",
+    "    for step, batch in enumerate(testload):\n",
+    "        model.eval()\n",
+    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
+    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
+    "        \n",
+    "        age_ids = age_ids.to(global_params['device'])\n",
+    "        input_ids = input_ids.to(global_params['device'])\n",
+    "        posi_ids = posi_ids.to(global_params['device'])\n",
+    "        segment_ids = segment_ids.to(global_params['device'])\n",
+    "        attMask = attMask.to(global_params['device'])\n",
+    "        targets = targets.to(global_params['device'])\n",
+    "        \n",
+    "        with torch.no_grad():\n",
+    "            loss, logits = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, labels=targets)\n",
+    "        logits = logits.cpu()\n",
+    "        targets = targets.cpu()\n",
+    "\n",
+    "        y_label.append(targets)\n",
+    "        y.append(logits)\n",
+    "\n",
+    "    y_label = torch.cat(y_label, dim=0)\n",
+    "    y = torch.cat(y, dim=0)\n",
+    "\n",
+    "    tempprc, output, label = precision_test(y, y_label)\n",
+    "    auroc = auroc_test(y, y_label)\n",
+    "    return tempprc, auroc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings(action='ignore')\n",
+    "optim_config = {\n",
+    "    'lr': 9e-6,\n",
+    "    'warmup_proportion': 0.1\n",
+    "}\n",
+    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)\n",
+    "\n",
+    "best_pre = 0\n",
+    "for e in range(50):\n",
+    "    train(e)\n",
+    "    auc, roc= evaluation()\n",
+    "    if auc >best_pre:\n",
+    "        # Save a trained model\n",
+    "        print(\"** ** * Saving fine - tuned model ** ** * \")\n",
+    "        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
+    "        output_model_file = os.path.join(global_params['output_dir'],global_params['best_name'])\n",
+    "        create_folder(global_params['output_dir'])\n",
+    "        if global_params['save_model']:\n",
+    "            torch.save(model_to_save.state_dict(), output_model_file)\n",
+    "        best_pre = auc\n",
+    "    print('precision : {}, auroc: {},'.format(auc, roc))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}