--- a
+++ b/examples/run_TBEHRT.ipynb
@@ -0,0 +1,931 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "import sys\n",
+    "sys.path.insert(0, '/home/rnshishir/deepmed/TBEHRT_pl/')\n",
+    "\n",
+    "import os\n",
+    "from torch.utils.data import DataLoader\n",
+    "from sklearn.model_selection import KFold\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "\n",
+    "from  pytorch_pretrained_bert import optimizer\n",
+    "import sklearn.metrics as skm\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "from src.utils import *\n",
+    "from src.model import *\n",
+    "from src.data import *\n",
+    "\n",
+    "from torch import optim as toptimizer\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def get_beta(batch_idx, m, beta_type):\n",
+    "    if beta_type == \"Blundell\":\n",
+    "        beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)\n",
+    "    elif beta_type == \"Soenderby\":\n",
+    "        beta = min(epoch / (num_epochs // 4), 1)\n",
+    "    elif beta_type == \"Standard\":\n",
+    "        beta = 1 / m\n",
+    "    else:\n",
+    "        beta = 0\n",
+    "    return beta\n",
+    "\n",
+    "\n",
+    "def trainunsup(e, sched, patienceMetric, MEM=True):\n",
+    "    sampled = datatrain.reset_index(drop=True)\n",
+    "    #\n",
+    "\n",
+    "    Dset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
+    "                                 age = 'age', year = 'year' , static= 'static' , \n",
+    "                                 max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label',  \n",
+    "                                 yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n",
+    "    \n",
+    "        \n",
+    "        \n",
+    "    trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n",
+    "                           sampler=None)\n",
+    "\n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    oldloss = 10 ** 10\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "\n",
+    "        batch = tuple(t.to(global_params['device']) for t in batch)\n",
+    "\n",
+    "        age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
+    "\n",
+    "        masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
+    "            input_idsMLM,\n",
+    "            age_ids,\n",
+    "            segment_ids,\n",
+    "            posi_ids,\n",
+    "            year_ids,\n",
+    "\n",
+    "            attention_mask=attMask,\n",
+    "            masked_lm_labels=masked_label,\n",
+    "            outcomeT=outcome_label,\n",
+    "            treatmentCLabel=treatment_label,\n",
+    "            fullEval=False,\n",
+    "            vaelabel=vaelabel)\n",
+    "        vaeloss = vaelosspure['loss']\n",
+    "\n",
+    "        totalL = masked_lm_loss\n",
+    "        if global_params['gradient_accumulation_steps'] > 1:\n",
+    "            totalL = totalL / global_params['gradient_accumulation_steps']\n",
+    "        totalL.backward()\n",
+    "        treatFull = treatOut\n",
+    "        treatLabelFull = treatLabel\n",
+    "        treatLabelFull = treatLabelFull.cpu().detach()\n",
+    "\n",
+    "        outFull = out\n",
+    "\n",
+    "        outLabelFull = outLabel\n",
+    "        treatindex = treatindex.cpu().detach().numpy()\n",
+    "        zeroind = np.where(treatindex == 0)\n",
+    "        outzero = outFull[0][zeroind]\n",
+    "        outzeroLabel = outLabelFull[zeroind]\n",
+    "\n",
+    "\n",
+    "        temp_loss += totalL.item()\n",
+    "        tr_loss += totalL.item()\n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "\n",
+    "        if step % 600 == 0:\n",
+    "            print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n",
+    "                   keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n",
+    "            if oldloss < vaelosspure['loss']:\n",
+    "                patienceMetric = patienceMetric + 1\n",
+    "                if patienceMetric >= 10:\n",
+    "                    sched.step()\n",
+    "                    print(\"LR: \", sched.get_lr())\n",
+    "                    patienceMetric = 0\n",
+    "            oldloss = vaelosspure['loss']\n",
+    "\n",
+    "        if step % 200 == 0:\n",
+    "            precOut0 = -1\n",
+    "            if len(zeroind[0]) > 0:\n",
+    "                precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n",
+    "\n",
+    "            print(\n",
+    "                \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n",
+    "                    e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n",
+    "                    cal_acc(treatLabelFull, treatFull, False)))\n",
+    "            temp_loss = 0\n",
+    "\n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "\n",
+    "    # Save a trained model\n",
+    "    del sampled, Dset, trainload\n",
+    "    return sched, patienceMetric\n",
+    "\n",
+    "\n",
+    "def train_multi(e, MEM=True):\n",
+    "    sampled = datatrain.reset_index(drop=True)\n",
+    "\n",
+    "    Dset =  TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
+    "                                 age = 'age', year = 'year' , static= 'static' , \n",
+    "                                 max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label',  \n",
+    "                                 yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n",
+    "    \n",
+    "        \n",
+    "        \n",
+    "    trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n",
+    "                           sampler=None)\n",
+    "    \n",
+    "    \n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "\n",
+    "        batch = tuple(t.to(global_params['device']) for t in batch)\n",
+    "\n",
+    "        age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
+    "        masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
+    "            input_idsMLM,\n",
+    "            age_ids,\n",
+    "            segment_ids,\n",
+    "            posi_ids,\n",
+    "            year_ids,\n",
+    "\n",
+    "            attention_mask=attMask,\n",
+    "            masked_lm_labels=masked_label,\n",
+    "            outcomeT=outcome_label,\n",
+    "            treatmentCLabel=treatment_label,\n",
+    "            fullEval=False,\n",
+    "            vaelabel=vaelabel)\n",
+    "\n",
+    "        vaeloss = vaelosspure['loss']\n",
+    "        totalL = 1 * (lossT) + 0 + (global_params['fac'] * masked_lm_loss)\n",
+    "        if global_params['gradient_accumulation_steps'] > 1:\n",
+    "            totalL = totalL / global_params['gradient_accumulation_steps']\n",
+    "        totalL.backward()\n",
+    "        treatFull = treatOut\n",
+    "        treatLabelFull = treatLabel\n",
+    "        treatLabelFull = treatLabelFull.cpu().detach()\n",
+    "\n",
+    "        outFull = out\n",
+    "\n",
+    "        outLabelFull = outLabel\n",
+    "        treatindex = treatindex.cpu().detach().numpy()\n",
+    "        zeroind = np.where(treatindex == 0)\n",
+    "        outzero = outFull[0][zeroind]\n",
+    "        outzeroLabel = outLabelFull[zeroind]\n",
+    "\n",
+    "        temp_loss += totalL.item()\n",
+    "        tr_loss += totalL.item()\n",
+    "        nb_tr_examples += input_ids.size(0)\n",
+    "        nb_tr_steps += 1\n",
+    "\n",
+    "        if step % 200 == 0:\n",
+    "            precOut0 = -1\n",
+    "\n",
+    "            if len(zeroind[0]) > 0:\n",
+    "                precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n",
+    "\n",
+    "            print(\n",
+    "                \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n",
+    "                    e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n",
+    "                    cal_acc(treatLabelFull, treatFull, False)))\n",
+    "\n",
+    "            print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n",
+    "                   keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n",
+    "            temp_loss = 0\n",
+    "\n",
+    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
+    "            optim.step()\n",
+    "            optim.zero_grad()\n",
+    "\n",
+    "    del sampled, Dset, trainload\n",
+    "\n",
+    "\n",
+    "def evaluation_multi_repeats():\n",
+    "    model.eval()\n",
+    "    y = []\n",
+    "    y_label = []\n",
+    "    t_label = []\n",
+    "    t_output = []\n",
+    "    count = 0\n",
+    "    totalL = 0\n",
+    "    for step, batch in enumerate(testload):\n",
+    "        model.eval()\n",
+    "        count = count + 1\n",
+    "        batch = tuple(t.to(global_params['device']) for t in batch)\n",
+    "\n",
+    "        age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
+    "        with torch.no_grad():\n",
+    "\n",
+    "            masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
+    "                input_idsMLM,\n",
+    "                age_ids,\n",
+    "                segment_ids,\n",
+    "                posi_ids,\n",
+    "                year_ids,\n",
+    "\n",
+    "                attention_mask=attMask,\n",
+    "                masked_lm_labels=masked_label,\n",
+    "                outcomeT=outcome_label,\n",
+    "                treatmentCLabel=treatment_label, vaelabel=vaelabel)\n",
+    "\n",
+    "        totalL = totalL + lossT.item() + 0 + (global_params['fac'] * masked_lm_loss)\n",
+    "        treatFull = treatOut\n",
+    "        treatLabelFull = treatLabel\n",
+    "        treatLabelFull = treatLabelFull.detach()\n",
+    "        outFull = out\n",
+    "        outLabelFull = outLabel\n",
+    "        treatindex = treatindex.cpu().detach().numpy()\n",
+    "        outPred = []\n",
+    "        outexpLab = []\n",
+    "        for el in range(global_params['treatments']):\n",
+    "            zeroind = np.where(treatindex == el)\n",
+    "            outPred.append(outFull[el][zeroind])\n",
+    "            outexpLab.append(outLabelFull[zeroind])\n",
+    "\n",
+    "\n",
+    "        y_label.append(torch.cat(outexpLab))\n",
+    "\n",
+    "        y.append(torch.cat(outPred))\n",
+    "\n",
+    "        treatOut = treatFull.cpu()\n",
+    "        treatLabel = treatLabelFull.cpu()\n",
+    "        if step % 200 == 0:\n",
+    "            print(step, \"tempLoss:\", totalL / count)\n",
+    "\n",
+    "        t_label.append(treatLabel)\n",
+    "        t_output.append(treatOut)\n",
+    "\n",
+    "    y_label = torch.cat(y_label, dim=0)\n",
+    "    y = torch.cat(y, dim=0)\n",
+    "    t_label = torch.cat(t_label, dim=0)\n",
+    "    treatO = torch.cat(t_output, dim=0)\n",
+    "\n",
+    "    tempprc, output, label = precision_test(y, y_label, False)\n",
+    "    treatPRC = cal_acc(t_label, treatO, False)\n",
+    "    tempprc2, output2, label2 = roc_auc(y, y_label, False)\n",
+    "\n",
+    "    print(\"LossEval: \", float(totalL) / float(count))\n",
+    "\n",
+    "    return tempprc, tempprc2, treatPRC, float(totalL) / float(count)\n",
+    "\n",
+    "\n",
+    "def fullEval_4analysis_multi(tr, te, filetest):\n",
+    "    if tr:\n",
+    "        sampled = datatrain.reset_index(drop=True)\n",
+    "\n",
+    "    if te:\n",
+    "        data = filetest\n",
+    "\n",
+    "        if tr:\n",
+    "            sampled = pd.concat([sampled, data]).reset_index(drop=True)\n",
+    "        else:\n",
+    "            sampled = data\n",
+    "    Fulltset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
+    "                                 age = 'age', year = 'year' , static= 'static' , \n",
+    "                                 max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label',  \n",
+    "                                 yvocab=YearVocab['token2idx'], list2avoid=None, MEM=False)\n",
+    "    \n",
+    "        \n",
+    "        \n",
+    "    fullDataLoad = DataLoader(dataset=Fulltset, batch_size=int(global_params['batch_size']), shuffle=False,\n",
+    "                              num_workers=0)\n",
+    "\n",
+    "    model.eval()\n",
+    "    y = []\n",
+    "    y_label = []\n",
+    "    t_label = []\n",
+    "    t_output = []\n",
+    "    count = 0\n",
+    "    totalL = 0\n",
+    "    eps_array = []\n",
+    "\n",
+    "    for yyy in range(model_config['num_treatment']):\n",
+    "        y.append([yyy])\n",
+    "        y_label.append([yyy])\n",
+    "\n",
+    "    print(y)\n",
+    "    for step, batch in enumerate(fullDataLoad):\n",
+    "        model.eval()\n",
+    "\n",
+    "        count = count + 1\n",
+    "        batch = tuple(t.to(global_params['device']) for t in batch)\n",
+    "\n",
+    "        age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaeloss = model(\n",
+    "                input_idsMLM,\n",
+    "                age_ids,\n",
+    "                segment_ids,\n",
+    "                posi_ids,\n",
+    "                year_ids,\n",
+    "\n",
+    "                attention_mask=attMask,\n",
+    "                masked_lm_labels=masked_label,\n",
+    "                outcomeT=outcome_label,\n",
+    "                treatmentCLabel=treatment_label, fullEval=True, vaelabel=vaelabel)\n",
+    "\n",
+    "\n",
+    "\n",
+    "        outFull = out\n",
+    "        outLabelFull = outLabel\n",
+    "\n",
+    "\n",
+    "        for el in range(global_params['treatments']):\n",
+    "            y[el].append(outFull[el].cpu())\n",
+    "            y_label[el].append(outLabelFull.cpu())\n",
+    "\n",
+    "        totalL = totalL + (1 * (lossT)).item()\n",
+    "\n",
+    "        if step % 200 == 0:\n",
+    "            print(step, \"tempLoss:\", totalL / count)\n",
+    "\n",
+    "        t_label.append(treatLabel)\n",
+    "        t_output.append(treatOut)\n",
+    "\n",
+    "    for idd, elem in enumerate(y):\n",
+    "        elem = torch.cat(elem[1:], dim=0)\n",
+    "        y[idd] = elem\n",
+    "    for idd, elem in enumerate(y_label):\n",
+    "        elem = torch.cat(elem[1:], dim=0)\n",
+    "        y_label[idd] = elem\n",
+    "\n",
+    "    t_label = torch.cat(t_label, dim=0)\n",
+    "    treatO = torch.cat(t_output, dim=0)\n",
+    "    treatPRC = cal_acc(t_label, treatO)\n",
+    "\n",
+    "    print(\"LossEval: \", float(totalL) / float(count), \"prec treat:\", treatPRC)\n",
+    "    return y, y_label, t_label, treatO, treatPRC, eps_array\n",
+    "\n",
+    "\n",
+    "def fullCONV(y, y_label, t_label, treatO):\n",
+    "    def convert_multihot(label, pred):\n",
+    "        label = label.cpu().numpy()\n",
+    "        truepred = pred.detach().cpu().numpy()\n",
+    "        truelabel = label\n",
+    "        newpred = []\n",
+    "        for i, x in enumerate(truelabel):\n",
+    "            temppred = []\n",
+    "            temppred.append(truepred[i][0])\n",
+    "            temppred.append(truepred[i][x[0]])\n",
+    "            newpred.append(temppred)\n",
+    "        return truelabel, np.array(truepred)\n",
+    "\n",
+    "    def convert_bin(logits, label, treatmentlabel2):\n",
+    "\n",
+    "        output = logits\n",
+    "        label, output = label.cpu().numpy(), output.detach().cpu().numpy()\n",
+    "        label = label[treatmentlabel2[0]]\n",
+    "\n",
+    "        return label, output\n",
+    "\n",
+    "    treatmentlabel2, treatment2 = convert_multihot(t_label, treatO)\n",
+    "    y = torch.cat(y, dim=0).view(global_params['treatments'], -1)\n",
+    "    y = y.transpose(1, 0)\n",
+    "    y_label = torch.cat(y_label, dim=0).view(global_params['treatments'], -1)\n",
+    "    y_label = y_label.transpose(1, 0)\n",
+    "    y2 = []\n",
+    "    y2label = []\n",
+    "    for i, elem in enumerate(y):\n",
+    "        j, k = convert_bin(elem, y_label[i], treatmentlabel2[i])\n",
+    "        y2.append(k)\n",
+    "        y2label.append(j)\n",
+    "    y2 = np.array(y2)\n",
+    "    y2label = np.array(y2label)\n",
+    "    y2label = np.expand_dims(y2label, -1)\n",
+    "\n",
+    "    return y2, y2label, treatmentlabel2, treatment2\n"
+   ]
+  },
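+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged illustration (not part of the original pipeline): how the KL-annealing\n",
+    "# weight returned by get_beta evolves over one epoch of m mini-batches. The\n",
+    "# 'Blundell' scheme weights early batches most heavily; 'Standard' is uniform.\n",
+    "m_demo = 8  # assumed number of mini-batches, for illustration only\n",
+    "print('Blundell:', [round(get_beta(b, m_demo, 'Blundell'), 4) for b in range(m_demo)])\n",
+    "print('Standard:', [round(get_beta(b, m_demo, 'Standard'), 4) for b in range(m_demo)])\n"
+   ]
+  },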
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "\n",
+    "file_config = {\n",
+    "       'data':  'test.parquet',\n",
+    "}\n",
+    "optim_config = {\n",
+    "    'lr': 1e-4,\n",
+    "    'warmup_proportion': 0.1\n",
+    "}\n",
+    "\n",
+    "\n",
+    "BertVocab = {}\n",
+    "token2idx = {'MASK': 4,\n",
+    "  'CLS': 3,\n",
+    "  'SEP': 2,\n",
+    "  'UNK': 1,\n",
+    "  'PAD': 0,\n",
+    "            'disease1':5,\n",
+    "             'disease2':6,\n",
+    "             'disease3':7,\n",
+    "             'disease4':8,\n",
+    "             'disease5':9,\n",
+    "             'disease6':10,\n",
+    "             'medication1':11,\n",
+    "             'medication2':12,\n",
+    "             'medication3':13,\n",
+    "             'medication4':14,\n",
+    "             'medication5':15,\n",
+    "             'medication6':16,\n",
+    "            }\n",
+    "idx2token = {}\n",
+    "for x in token2idx:\n",
+    "    idx2token[token2idx[x]]=x\n",
+    "BertVocab['token2idx']= token2idx\n",
+    "BertVocab['idx2token']= idx2token\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "YearVocab = {'token2idx': {'PAD': 0,\n",
+    "  '1987': 1,\n",
+    "  '1988': 2,\n",
+    "  '1989': 3,\n",
+    "  '1990': 4,\n",
+    "  '1991': 5,\n",
+    "  '1992': 6,\n",
+    "  '1993': 7,\n",
+    "  '1994': 8,\n",
+    "  '1995': 9,\n",
+    "  '1996': 10,\n",
+    "  '1997': 11,\n",
+    "  '1998': 12,\n",
+    "  '1999': 13,\n",
+    "  '2000': 14,\n",
+    "  '2001': 15,\n",
+    "  '2002': 16,\n",
+    "  '2003': 17,\n",
+    "  '2004': 18,\n",
+    "  '2005': 19,\n",
+    "  '2006': 20,\n",
+    "  '2007': 21,\n",
+    "  '2008': 22,\n",
+    "  '2009': 23,\n",
+    "  '2010': 24,\n",
+    "  '2011': 25,\n",
+    "  '2012': 26,\n",
+    "  '2013': 27,\n",
+    "  '2014': 28,\n",
+    "  '2015': 29,\n",
+    "  'UNK': 30},\n",
+    " 'idx2token': {0: 'PAD',\n",
+    "  1: '1987',\n",
+    "  2: '1988',\n",
+    "  3: '1989',\n",
+    "  4: '1990',\n",
+    "  5: '1991',\n",
+    "  6: '1992',\n",
+    "  7: '1993',\n",
+    "  8: '1994',\n",
+    "  9: '1995',\n",
+    "  10: '1996',\n",
+    "  11: '1997',\n",
+    "  12: '1998',\n",
+    "  13: '1999',\n",
+    "  14: '2000',\n",
+    "  15: '2001',\n",
+    "  16: '2002',\n",
+    "  17: '2003',\n",
+    "  18: '2004',\n",
+    "  19: '2005',\n",
+    "  20: '2006',\n",
+    "  21: '2007',\n",
+    "  22: '2008',\n",
+    "  23: '2009',\n",
+    "  24: '2010',\n",
+    "  25: '2011',\n",
+    "  26: '2012',\n",
+    "  27: '2013',\n",
+    "  28: '2014',\n",
+    "  29: '2015',\n",
+    "  30: 'UNK'}}\n",
+    "\n",
+    "\n",
+    "\n"
+   ]
+  },
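+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged illustration (added for this walkthrough, not part of the original notebook):\n",
+    "# mapping a toy visit sequence onto indices with BertVocab, falling back to 'UNK'\n",
+    "# for unseen codes, roughly the kind of lookup the data formation step performs.\n",
+    "toy_visit = ['CLS', 'disease1', 'medication3', 'SEP', 'unseen_code', 'SEP']\n",
+    "toy_ids = [BertVocab['token2idx'].get(tok, BertVocab['token2idx']['UNK']) for tok in toy_visit]\n",
+    "print(toy_ids)\n",
+    "print([BertVocab['idx2token'][i] for i in toy_ids])\n"
+   ]
+  },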
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "\n",
+    "global_params = {\n",
+    "    'batch_size': 128,\n",
+    "    'gradient_accumulation_steps': 1,\n",
+    "    'num_train_epochs': 3,\n",
+    "    'device': 'cuda:0',\n",
+    "    'output_dir': \"save_models\",\n",
+    "    'save_model': True,\n",
+    "    'max_len_seq': 250,\n",
+    "    'max_age': 110,\n",
+    "    'age_year': False,\n",
+    "    'age_symbol': None,\n",
+    "    'fac': 0.1,\n",
+    "    'diseaseI': 1,\n",
+    "    'treatments': 2\n",
+    "}\n",
+    "\n",
+    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], year=global_params['age_year'],\n",
+    "                        symbol=global_params['age_symbol'])\n",
+    "\n",
+    "model_config = {\n",
+    "    'vocab_size': len(BertVocab['token2idx'].keys()),  # number of disease + symbols for word embedding\n",
+    "    'hidden_size': 150,  # word embedding and seg embedding hidden size\n",
+    "    'seg_vocab_size': 2,  # number of vocab for seg embedding\n",
+    "    'age_vocab_size': len(ageVocab.keys()),  # number of vocab for age embedding\n",
+    "    'max_position_embedding': global_params['max_len_seq'],  # maximum number of tokens\n",
+    "    'hidden_dropout_prob': 0.3,  # dropout rate\n",
+    "    'num_hidden_layers': 4,  # number of multi-head attention layers required\n",
+    "    'num_attention_heads': 6,  # number of attention heads\n",
+    "    'attention_probs_dropout_prob': 0.4,  # multi-head attention dropout rate\n",
+    "    'intermediate_size': 108,  # the size of the \"intermediate\" layer in the transformer encoder\n",
+    "    'hidden_act': 'gelu',\n",
+    "    'initializer_range': 0.02,  # parameter weight initializer range\n",
+    "    'num_treatment': global_params['treatments'],\n",
+    "    'device': global_params['device'],\n",
+    "    'year_vocab_size': len(YearVocab['token2idx'].keys()),\n",
+    "\n",
+    "    'batch_size': global_params['batch_size'],\n",
+    "    'MEM': True,\n",
+    "    'poolingSize': 50,\n",
+    "    'unsupVAE': True,\n",
+    "    'unsupSize': ([[3,2]] *22) ,\n",
+    "    'vaelatentdim': 40,\n",
+    "    'vaehidden': 50,\n",
+    "    'vaeinchannels':39,\n",
+    "\n",
+    "\n",
+    "\n",
+    "}\n"
+   ]
+  },
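+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick consistency check (added here as a sketch; assumes the embedding sizes in\n",
+    "# model_config are meant to track the vocabularies defined above).\n",
+    "assert model_config['vocab_size'] == len(BertVocab['token2idx'])\n",
+    "assert model_config['year_vocab_size'] == len(YearVocab['token2idx'])\n",
+    "assert model_config['age_vocab_size'] == len(ageVocab)\n",
+    "print(model_config['vocab_size'], model_config['year_vocab_size'], model_config['age_vocab_size'])\n"
+   ]
+  },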
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Begin experiments....\n",
+      "_________________\n",
+      "fold___0\n",
+      "_________________\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[('loss', tensor(6.4376, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(823.7451, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.0632, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 0| Loss: 0.04597\t| MLM: 0.13298\t| TOutP: 0.30261\t|vaeloss: 6.43762\t|ExpP: 0.76562\n",
+      "[('loss', tensor(6.0075, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(768.0154, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-17.6470, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 1| Loss: 0.03695\t| MLM: 0.35556\t| TOutP: 0.20014\t|vaeloss: 6.00750\t|ExpP: 0.63281\n",
+      "epoch: 0| Loss: 0.00941\t| MLM: 0.49010\t| TOutP: 0.26410\t|vaeloss: 3.98089\t|ExpP: 0.75781\n",
+      "[('loss', tensor(3.9809, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(450.0008, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-1108.7188, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.9582, device='cuda:0')\n",
+      "LossEval:  0.912086296081543\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.912086296081543\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "epoch: 1| Loss: 0.00508\t| MLM: 0.56186\t| TOutP: 0.17931\t|vaeloss: 1.55077\t|ExpP: 0.99219\n",
+      "[('loss', tensor(1.5508, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(56.6759, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2628.6050, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.8717, device='cuda:0')\n",
+      "LossEval:  0.75910964012146\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.75910964012146\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[[0], [1]]\n",
+      "0 tempLoss: 0.7803983688354492\n",
+      "LossEval:  0.6677856683731079 prec treat: 0.9766666666666667\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "_________________\n",
+      "fold___1\n",
+      "_________________\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[('loss', tensor(6.3858, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(817.0728, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.8979, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 0| Loss: 0.04601\t| MLM: 0.11060\t| TOutP: 0.15367\t|vaeloss: 6.38584\t|ExpP: 0.42188\n",
+      "[('loss', tensor(6.0483, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(773.4330, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-13.9553, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 1| Loss: 0.03693\t| MLM: 0.38764\t| TOutP: 0.28702\t|vaeloss: 6.04828\t|ExpP: 0.23438\n",
+      "epoch: 0| Loss: 0.00949\t| MLM: 0.48663\t| TOutP: 0.33637\t|vaeloss: 4.00320\t|ExpP: 0.39844\n",
+      "[('loss', tensor(4.0032, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(461.0717, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-955.7684, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.9395, device='cuda:0')\n",
+      "LossEval:  0.8925997734069824\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.8925997734069824\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "epoch: 1| Loss: 0.00531\t| MLM: 0.55825\t| TOutP: 0.24316\t|vaeloss: 1.45487\t|ExpP: 0.95312\n",
+      "[('loss', tensor(1.4549, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(58.0446, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2375.7451, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.8748, device='cuda:0')\n",
+      "LossEval:  0.7546854019165039\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.7546854019165039\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[[0], [1]]\n",
+      "0 tempLoss: 0.7930977940559387\n",
+      "LossEval:  0.6729061722755432 prec treat: 0.9766666666666667\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "_________________\n",
+      "fold___2\n",
+      "_________________\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[('loss', tensor(6.8586, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(877.5925, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.7485, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 0| Loss: 0.04843\t| MLM: 0.04945\t| TOutP: 0.24977\t|vaeloss: 6.85859\t|ExpP: 0.81250\n",
+      "[('loss', tensor(6.1659, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(788.5856, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-12.1737, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 1| Loss: 0.03748\t| MLM: 0.35294\t| TOutP: 0.24342\t|vaeloss: 6.16592\t|ExpP: 0.74219\n",
+      "epoch: 0| Loss: 0.00960\t| MLM: 0.51531\t| TOutP: 0.13057\t|vaeloss: 4.42672\t|ExpP: 0.83594\n",
+      "[('loss', tensor(4.4267, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(515.2249, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-956.8342, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.9165, device='cuda:0')\n",
+      "LossEval:  0.8629927635192871\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.8629927635192871\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "epoch: 1| Loss: 0.00494\t| MLM: 0.63473\t| TOutP: 0.20097\t|vaeloss: 1.60833\t|ExpP: 0.97656\n",
+      "[('loss', tensor(1.6083, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(76.9737, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2388.9741, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.8581, device='cuda:0')\n",
+      "LossEval:  0.7425572872161865\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.7425572872161865\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[[0], [1]]\n",
+      "0 tempLoss: 0.7719957232475281\n",
+      "LossEval:  0.6564249277114869 prec treat: 0.9766666666666667\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "_________________\n",
+      "fold___3\n",
+      "_________________\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[('loss', tensor(6.4421, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(824.2659, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-6.0783, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 0| Loss: 0.04659\t| MLM: 0.05825\t| TOutP: 0.23844\t|vaeloss: 6.44211\t|ExpP: 0.56250\n",
+      "[('loss', tensor(6.5312, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(835.7316, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-4.9471, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 1| Loss: 0.03936\t| MLM: 0.37313\t| TOutP: 0.25152\t|vaeloss: 6.53122\t|ExpP: 0.38281\n",
+      "epoch: 0| Loss: 0.00999\t| MLM: 0.40314\t| TOutP: 0.28174\t|vaeloss: 4.91041\t|ExpP: 0.42969\n",
+      "[('loss', tensor(4.9104, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(588.9692, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-736.5710, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.9352, device='cuda:0')\n",
+      "LossEval:  0.888219165802002\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.888219165802002\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "epoch: 1| Loss: 0.00542\t| MLM: 0.46328\t| TOutP: 0.16435\t|vaeloss: 1.87306\t|ExpP: 0.98438\n",
+      "[('loss', tensor(1.8731, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(115.1043, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2310.2729, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.8745, device='cuda:0')\n",
+      "LossEval:  0.759437370300293\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.759437370300293\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[[0], [1]]\n",
+      "0 tempLoss: 0.7776303291320801\n",
+      "LossEval:  0.6625770092010498 prec treat: 0.9766666666666667\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "_________________\n",
+      "fold___4\n",
+      "_________________\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[('loss', tensor(7.1006, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(908.5834, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.5566, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 0| Loss: 0.04973\t| MLM: 0.08920\t| TOutP: 0.34168\t|vaeloss: 7.10062\t|ExpP: 0.75000\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[('loss', tensor(6.3710, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(815.0780, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-7.6161, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "epoch: 1| Loss: 0.03872\t| MLM: 0.42614\t| TOutP: 0.17729\t|vaeloss: 6.37098\t|ExpP: 0.48438\n",
+      "epoch: 0| Loss: 0.00975\t| MLM: 0.39409\t| TOutP: 0.24371\t|vaeloss: 4.48545\t|ExpP: 0.45312\n",
+      "[('loss', tensor(4.4854, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(531.1162, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-800.9330, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.9520, device='cuda:0')\n",
+      "LossEval:  0.9080421447753906\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.9080421447753906\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "epoch: 1| Loss: 0.00529\t| MLM: 0.64242\t| TOutP: 0.28668\t|vaeloss: 1.59132\t|ExpP: 0.99219\n",
+      "[('loss', tensor(1.5913, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(74.8522, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2387.9209, device='cuda:0', grad_fn=<NegBackward>))]\n",
+      "0 tempLoss: tensor(0.8671, device='cuda:0')\n",
+      "LossEval:  0.7486227989196778\n",
+      "** ** * Saving best fine - tuned model ** ** * \n",
+      "auc-mean:  -0.7486227989196778\n",
+      "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
+      "turning on the MEM....\n",
+      "full init completed...\n",
+      "[[0], [1]]\n",
+      "0 tempLoss: 0.7893564105033875\n",
+      "LossEval:  0.670855188369751 prec treat: 0.9766666666666667\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "\n",
+    "data = pd.read_parquet (file_config['data'])\n",
+    "\n",
+    "kf = KFold(n_splits = 5, shuffle = True, random_state = 2)\n",
+    "\n",
+    "print('Begin experiments....')\n",
+    "\n",
+    "\n",
+    "\n",
+    "for cutiter in (range(5)):\n",
+    "    print(\"_________________\\nfold___\" + str(cutiter) + \"\\n_________________\")\n",
+    "    data = pd.read_parquet (file_config['data'])\n",
+    "\n",
+    "    result = next(kf.split(data), None)\n",
+    "\n",
+    "    datatrain = data.iloc[result[0]].reset_index(drop=True)\n",
+    "    testdata =  data.iloc[result[1]].reset_index(drop=True)\n",
+    "\n",
+    "    tset = TBEHRT_data_formation(BertVocab['token2idx'], testdata, code= 'code', \n",
+    "                                 age = 'age', year = 'year' , static= 'static' , \n",
+    "                                 max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label',  \n",
+    "                                 yvocab=YearVocab['token2idx'], list2avoid=None, MEM=False)\n",
+    "    \n",
+    "        \n",
+    "   \n",
+    "    testload = DataLoader(dataset=tset, batch_size=int(global_params['batch_size']), shuffle=False, num_workers=0)\n",
+    "\n",
+    "\n",
+    "    model_config['klpar']= float(1.0/(len(datatrain)/global_params['batch_size']))\n",
+    "    conf = BertConfig(model_config)\n",
+    "    model = TBEHRT(conf, 1)\n",
+    "\n",
+    "    optim = optimizer.adam(params=list(model.named_parameters()), config=optim_config)\n",
+    "\n",
+    "    model_to_save_name =  'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".bin\"\n",
+    "\n",
+    "    import warnings\n",
+    "\n",
+    "    warnings.filterwarnings(action='ignore')\n",
+    "    scheduler = toptimizer.lr_scheduler.ExponentialLR(optim, 0.95, last_epoch=-1)\n",
+    "    patience = 0\n",
+    "    best_pre = -100000000000000000000\n",
+    "    LossC = 0.1\n",
+    "    #\n",
+    "    for e in range(2):\n",
+    "        scheduler , patience= trainunsup(e, scheduler, patience)\n",
+    "\n",
+    "    for e in range(2):\n",
+    "        train_multi(e)\n",
+    "        auc, auroc, auc2, loss = evaluation_multi_repeats()\n",
+    "        aucreal = -1 * loss\n",
+    "        if aucreal > best_pre:\n",
+    "            patience = 0\n",
+    "            # Save a trained model\n",
+    "            print(\"** ** * Saving best fine - tuned model ** ** * \")\n",
+    "            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
+    "            output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n",
+    "            create_folder(global_params['output_dir'])\n",
+    "            if global_params['save_model']:\n",
+    "                torch.save(model_to_save.state_dict(), output_model_file)\n",
+    "\n",
+    "            best_pre = aucreal\n",
+    "            print(\"auc-mean: \", aucreal)\n",
+    "        else:\n",
+    "            if patience % 2 == 0 and patience != 0:\n",
+    "                scheduler.step()\n",
+    "                print(\"LR: \", scheduler.get_lr())\n",
+    "\n",
+    "            patience = patience + 1\n",
+    "        print('auprc : {}, auroc : {}, Treat-auc : {}, time: {}'.format(auc, auroc, auc2, \"long.....\"))\n",
+    "\n",
+    "\n",
+    "\n",
+    "    LossC = 0.1\n",
+    "    conf = BertConfig(model_config)\n",
+    "    model = TBEHRT(conf, 1)\n",
+    "    optim = optimizer.VAEadam(params=list(model.named_parameters()), config=optim_config)\n",
+    "    output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n",
+    "    model = toLoad(model, output_model_file)\n",
+    "\n",
+    "\n",
+    "    y, y_label, t_label, treatO, tprc, eps = fullEval_4analysis_multi(False, True, testdata)\n",
+    "\n",
+    "    y2, y2label, treatmentlabel2, treatment2 = fullCONV(y, y_label, t_label, treatO)\n",
+    "\n",
+    "    NPSaveNAME =  'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".npz\"\n",
+    "\n",
+    "    np.savez(  NPSaveNAME,\n",
+    "             outcome=y2,\n",
+    "             outcome_label=y2label, treatment=treatment2, treatment_label=treatmentlabel2,\n",
+    "             epsilon=np.array([0]))\n",
+    "    del y, y_label, t_label, treatO, tprc, eps, y2, y2label, treatmentlabel2, treatment2, datatrain, conf, model, optim, output_model_file,  best_pre, LossC,\n",
+    "    print(\"\\n\\n\\n\\n\\n\")"
+   ]
+  },
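+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hedged post-hoc check (not part of the original run): reload the per-fold arrays\n",
+    "# written by np.savez above and inspect their shapes; the file name follows the\n",
+    "# NPSaveNAME pattern used in the training loop.\n",
+    "with np.load('TBEHRT_Test__CUT0.npz') as arrs:\n",
+    "    for key in ['outcome', 'outcome_label', 'treatment', 'treatment_label', 'epsilon']:\n",
+    "        print(key, arrs[key].shape)\n"
+   ]
+  },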
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "real3",
+   "language": "python",
+   "name": "py3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}