{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\n",
"import sys\n",
"sys.path.insert(0, '/home/rnshishir/deepmed/TBEHRT_pl/')\n",
"\n",
"import os\n",
"from torch.utils.data import DataLoader\n",
"from sklearn.model_selection import KFold\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pytorch_pretrained_bert as Bert\n",
"\n",
"from pytorch_pretrained_bert import optimizer\n",
"import sklearn.metrics as skm\n",
"from torch.utils.data.dataset import Dataset\n",
"from src.utils import *\n",
"from src.model import *\n",
"from src.data import *\n",
"\n",
"from torch import optim as toptimizer\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def get_beta(batch_idx, m, beta_type):\n",
" if beta_type == \"Blundell\":\n",
" beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)\n",
" elif beta_type == \"Soenderby\":\n",
" beta = min(epoch / (num_epochs // 4), 1)\n",
" elif beta_type == \"Standard\":\n",
" beta = 1 / m\n",
" else:\n",
" beta = 0\n",
" return beta\n",
"\n",
"\n",
"def trainunsup(e, sched, patienceMetric, MEM=True):\n",
" sampled = datatrain.reset_index(drop=True)\n",
" #\n",
"\n",
" Dset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
" age = 'age', year = 'year' , static= 'static' , \n",
" max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n",
" yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n",
" \n",
" \n",
" \n",
" trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n",
" sampler=None)\n",
"\n",
" model.train()\n",
" tr_loss = 0\n",
" temp_loss = 0\n",
" nb_tr_examples, nb_tr_steps = 0, 0\n",
" oldloss = 10 ** 10\n",
" for step, batch in enumerate(trainload):\n",
"\n",
" batch = tuple(t.to(global_params['device']) for t in batch)\n",
"\n",
" age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
"\n",
" masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
" input_idsMLM,\n",
" age_ids,\n",
" segment_ids,\n",
" posi_ids,\n",
" year_ids,\n",
"\n",
" attention_mask=attMask,\n",
" masked_lm_labels=masked_label,\n",
" outcomeT=outcome_label,\n",
" treatmentCLabel=treatment_label,\n",
" fullEval=False,\n",
" vaelabel=vaelabel)\n",
" vaeloss = vaelosspure['loss']\n",
"\n",
" totalL = masked_lm_loss\n",
" if global_params['gradient_accumulation_steps'] > 1:\n",
" totalL = totalL / global_params['gradient_accumulation_steps']\n",
" totalL.backward()\n",
" treatFull = treatOut\n",
" treatLabelFull = treatLabel\n",
" treatLabelFull = treatLabelFull.cpu().detach()\n",
"\n",
" outFull = out\n",
"\n",
" outLabelFull = outLabel\n",
" treatindex = treatindex.cpu().detach().numpy()\n",
" zeroind = np.where(treatindex == 0)\n",
" outzero = outFull[0][zeroind]\n",
" outzeroLabel = outLabelFull[zeroind]\n",
"\n",
"\n",
" temp_loss += totalL.item()\n",
" tr_loss += totalL.item()\n",
" nb_tr_examples += input_ids.size(0)\n",
" nb_tr_steps += 1\n",
"\n",
" if step % 600 == 0:\n",
" print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n",
" keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n",
" if oldloss < vaelosspure['loss']:\n",
" patienceMetric = patienceMetric + 1\n",
" if patienceMetric >= 10:\n",
" sched.step()\n",
" print(\"LR: \", sched.get_lr())\n",
" patienceMetric = 0\n",
" oldloss = vaelosspure['loss']\n",
"\n",
" if step % 200 == 0:\n",
" precOut0 = -1\n",
" if len(zeroind[0]) > 0:\n",
" precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n",
"\n",
" print(\n",
" \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n",
" e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n",
" cal_acc(treatLabelFull, treatFull, False)))\n",
" temp_loss = 0\n",
"\n",
" if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
" optim.step()\n",
" optim.zero_grad()\n",
"\n",
" # Save a trained model\n",
" del sampled, Dset, trainload\n",
" return sched, patienceMetric\n",
"\n",
"\n",
"def train_multi(e, MEM=True):\n",
" sampled = datatrain.reset_index(drop=True)\n",
"\n",
" Dset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
" age = 'age', year = 'year' , static= 'static' , \n",
" max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n",
" yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n",
" \n",
" \n",
" \n",
" trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n",
" sampler=None)\n",
" \n",
" \n",
" model.train()\n",
" tr_loss = 0\n",
" temp_loss = 0\n",
" nb_tr_examples, nb_tr_steps = 0, 0\n",
" for step, batch in enumerate(trainload):\n",
"\n",
" batch = tuple(t.to(global_params['device']) for t in batch)\n",
"\n",
" age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
" masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
" input_idsMLM,\n",
" age_ids,\n",
" segment_ids,\n",
" posi_ids,\n",
" year_ids,\n",
"\n",
" attention_mask=attMask,\n",
" masked_lm_labels=masked_label,\n",
" outcomeT=outcome_label,\n",
" treatmentCLabel=treatment_label,\n",
" fullEval=False,\n",
" vaelabel=vaelabel)\n",
"\n",
" vaeloss = vaelosspure['loss']\n",
" totalL = 1 * (lossT) + 0 + (global_params['fac'] * masked_lm_loss)\n",
" if global_params['gradient_accumulation_steps'] > 1:\n",
" totalL = totalL / global_params['gradient_accumulation_steps']\n",
" totalL.backward()\n",
" treatFull = treatOut\n",
" treatLabelFull = treatLabel\n",
" treatLabelFull = treatLabelFull.cpu().detach()\n",
"\n",
" outFull = out\n",
"\n",
" outLabelFull = outLabel\n",
" treatindex = treatindex.cpu().detach().numpy()\n",
" zeroind = np.where(treatindex == 0)\n",
" outzero = outFull[0][zeroind]\n",
" outzeroLabel = outLabelFull[zeroind]\n",
"\n",
" temp_loss += totalL.item()\n",
" tr_loss += totalL.item()\n",
" nb_tr_examples += input_ids.size(0)\n",
" nb_tr_steps += 1\n",
"\n",
" if step % 200 == 0:\n",
" precOut0 = -1\n",
"\n",
" if len(zeroind[0]) > 0:\n",
" precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n",
"\n",
" print(\n",
" \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n",
" e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n",
" cal_acc(treatLabelFull, treatFull, False)))\n",
"\n",
" print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n",
" keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n",
" temp_loss = 0\n",
"\n",
" if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
" optim.step()\n",
" optim.zero_grad()\n",
"\n",
" del sampled, Dset, trainload\n",
"\n",
"\n",
"def evaluation_multi_repeats():\n",
" model.eval()\n",
" y = []\n",
" y_label = []\n",
" t_label = []\n",
" t_output = []\n",
" count = 0\n",
" totalL = 0\n",
" for step, batch in enumerate(testload):\n",
" model.eval()\n",
" count = count + 1\n",
" batch = tuple(t.to(global_params['device']) for t in batch)\n",
"\n",
" age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
" with torch.no_grad():\n",
"\n",
" masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
" input_idsMLM,\n",
" age_ids,\n",
" segment_ids,\n",
" posi_ids,\n",
" year_ids,\n",
"\n",
" attention_mask=attMask,\n",
" masked_lm_labels=masked_label,\n",
" outcomeT=outcome_label,\n",
" treatmentCLabel=treatment_label, vaelabel=vaelabel)\n",
"\n",
" totalL = totalL + lossT.item() + 0 + (global_params['fac'] * masked_lm_loss)\n",
" treatFull = treatOut\n",
" treatLabelFull = treatLabel\n",
" treatLabelFull = treatLabelFull.detach()\n",
" outFull = out\n",
" outLabelFull = outLabel\n",
" treatindex = treatindex.cpu().detach().numpy()\n",
" outPred = []\n",
" outexpLab = []\n",
" for el in range(global_params['treatments']):\n",
" zeroind = np.where(treatindex == el)\n",
" outPred.append(outFull[el][zeroind])\n",
" outexpLab.append(outLabelFull[zeroind])\n",
"\n",
"\n",
" y_label.append(torch.cat(outexpLab))\n",
"\n",
" y.append(torch.cat(outPred))\n",
"\n",
" treatOut = treatFull.cpu()\n",
" treatLabel = treatLabelFull.cpu()\n",
" if step % 200 == 0:\n",
" print(step, \"tempLoss:\", totalL / count)\n",
"\n",
" t_label.append(treatLabel)\n",
" t_output.append(treatOut)\n",
"\n",
" y_label = torch.cat(y_label, dim=0)\n",
" y = torch.cat(y, dim=0)\n",
" t_label = torch.cat(t_label, dim=0)\n",
" treatO = torch.cat(t_output, dim=0)\n",
"\n",
" tempprc, output, label = precision_test(y, y_label, False)\n",
" treatPRC = cal_acc(t_label, treatO, False)\n",
" tempprc2, output2, label2 = roc_auc(y, y_label, False)\n",
"\n",
" print(\"LossEval: \", float(totalL) / float(count))\n",
"\n",
" return tempprc, tempprc2, treatPRC, float(totalL) / float(count)\n",
"\n",
"\n",
"def fullEval_4analysis_multi(tr, te, filetest):\n",
" if tr:\n",
" sampled = datatrain.reset_index(drop=True)\n",
"\n",
" if te:\n",
" data = filetest\n",
"\n",
" if tr:\n",
" sampled = pd.concat([sampled, data]).reset_index(drop=True)\n",
" else:\n",
" sampled = data\n",
" Fulltset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n",
" age = 'age', year = 'year' , static= 'static' , \n",
" max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n",
" yvocab=YearVocab['token2idx'], list2avoid=None, MEM=False)\n",
" \n",
" \n",
" \n",
" fullDataLoad = DataLoader(dataset=Fulltset, batch_size=int(global_params['batch_size']), shuffle=False,\n",
" num_workers=0)\n",
"\n",
" model.eval()\n",
" y = []\n",
" y_label = []\n",
" t_label = []\n",
" t_output = []\n",
" count = 0\n",
" totalL = 0\n",
" eps_array = []\n",
"\n",
" for yyy in range(model_config['num_treatment']):\n",
" y.append([yyy])\n",
" y_label.append([yyy])\n",
"\n",
" print(y)\n",
" for step, batch in enumerate(fullDataLoad):\n",
" model.eval()\n",
"\n",
" count = count + 1\n",
" batch = tuple(t.to(global_params['device']) for t in batch)\n",
"\n",
" age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
"\n",
" with torch.no_grad():\n",
" masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaeloss = model(\n",
" input_idsMLM,\n",
" age_ids,\n",
" segment_ids,\n",
" posi_ids,\n",
" year_ids,\n",
"\n",
" attention_mask=attMask,\n",
" masked_lm_labels=masked_label,\n",
" outcomeT=outcome_label,\n",
" treatmentCLabel=treatment_label, fullEval=True, vaelabel=vaelabel)\n",
"\n",
"\n",
"\n",
" outFull = out\n",
" outLabelFull = outLabel\n",
"\n",
"\n",
" for el in range(global_params['treatments']):\n",
" y[el].append(outFull[el].cpu())\n",
" y_label[el].append(outLabelFull.cpu())\n",
"\n",
" totalL = totalL + (1 * (lossT)).item()\n",
"\n",
" if step % 200 == 0:\n",
" print(step, \"tempLoss:\", totalL / count)\n",
"\n",
" t_label.append(treatLabel)\n",
" t_output.append(treatOut)\n",
"\n",
" for idd, elem in enumerate(y):\n",
" elem = torch.cat(elem[1:], dim=0)\n",
" y[idd] = elem\n",
" for idd, elem in enumerate(y_label):\n",
" elem = torch.cat(elem[1:], dim=0)\n",
" y_label[idd] = elem\n",
"\n",
" t_label = torch.cat(t_label, dim=0)\n",
" treatO = torch.cat(t_output, dim=0)\n",
" treatPRC = cal_acc(t_label, treatO)\n",
"\n",
" print(\"LossEval: \", float(totalL) / float(count), \"prec treat:\", treatPRC)\n",
" return y, y_label, t_label, treatO, treatPRC, eps_array\n",
"\n",
"\n",
"def fullCONV(y, y_label, t_label, treatO):\n",
" def convert_multihot(label, pred):\n",
" label = label.cpu().numpy()\n",
" truepred = pred.detach().cpu().numpy()\n",
" truelabel = label\n",
" newpred = []\n",
" for i, x in enumerate(truelabel):\n",
" temppred = []\n",
" temppred.append(truepred[i][0])\n",
" temppred.append(truepred[i][x[0]])\n",
" newpred.append(temppred)\n",
" return truelabel, np.array(truepred)\n",
"\n",
" def convert_bin(logits, label, treatmentlabel2):\n",
"\n",
" output = logits\n",
" label, output = label.cpu().numpy(), output.detach().cpu().numpy()\n",
" label = label[treatmentlabel2[0]]\n",
"\n",
" return label, output\n",
"\n",
" treatmentlabel2, treatment2 = convert_multihot(t_label, treatO)\n",
" y = torch.cat(y, dim=0).view(global_params['treatments'], -1)\n",
" y = y.transpose(1, 0)\n",
" y_label = torch.cat(y_label, dim=0).view(global_params['treatments'], -1)\n",
" y_label = y_label.transpose(1, 0)\n",
" y2 = []\n",
" y2label = []\n",
" for i, elem in enumerate(y):\n",
" j, k = convert_bin(elem, y_label[i], treatmentlabel2[i])\n",
" y2.append(k)\n",
" y2label.append(j)\n",
" y2 = np.array(y2)\n",
" y2label = np.array(y2label)\n",
" y2label = np.expand_dims(y2label, -1)\n",
"\n",
" return y2, y2label, treatmentlabel2, treatment2\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"file_config = {\n",
" 'data': 'test.parquet',\n",
"}\n",
"optim_config = {\n",
" 'lr': 1e-4,\n",
" 'warmup_proportion': 0.1\n",
"}\n",
"\n",
"\n",
"BertVocab = {}\n",
"token2idx = {'MASK': 4,\n",
" 'CLS': 3,\n",
" 'SEP': 2,\n",
" 'UNK': 1,\n",
" 'PAD': 0,\n",
" 'disease1':5,\n",
" 'disease2':6,\n",
" 'disease3':7,\n",
" 'disease4':8,\n",
" 'disease5':9,\n",
" 'disease6':10,\n",
" 'medication1':11,\n",
" 'medication2':12,\n",
" 'medication3':13,\n",
" 'medication4':14,\n",
" 'medication5':15,\n",
" 'medication6':16,\n",
" }\n",
"idx2token = {}\n",
"for x in token2idx:\n",
" idx2token[token2idx[x]]=x\n",
"BertVocab['token2idx']= token2idx\n",
"BertVocab['idx2token']= idx2token\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"YearVocab = {'token2idx': {'PAD': 0,\n",
" '1987': 1,\n",
" '1988': 2,\n",
" '1989': 3,\n",
" '1990': 4,\n",
" '1991': 5,\n",
" '1992': 6,\n",
" '1993': 7,\n",
" '1994': 8,\n",
" '1995': 9,\n",
" '1996': 10,\n",
" '1997': 11,\n",
" '1998': 12,\n",
" '1999': 13,\n",
" '2000': 14,\n",
" '2001': 15,\n",
" '2002': 16,\n",
" '2003': 17,\n",
" '2004': 18,\n",
" '2005': 19,\n",
" '2006': 20,\n",
" '2007': 21,\n",
" '2008': 22,\n",
" '2009': 23,\n",
" '2010': 24,\n",
" '2011': 25,\n",
" '2012': 26,\n",
" '2013': 27,\n",
" '2014': 28,\n",
" '2015': 29,\n",
" 'UNK': 30},\n",
" 'idx2token': {0: 'PAD',\n",
" 1: '1987',\n",
" 2: '1988',\n",
" 3: '1989',\n",
" 4: '1990',\n",
" 5: '1991',\n",
" 6: '1992',\n",
" 7: '1993',\n",
" 8: '1994',\n",
" 9: '1995',\n",
" 10: '1996',\n",
" 11: '1997',\n",
" 12: '1998',\n",
" 13: '1999',\n",
" 14: '2000',\n",
" 15: '2001',\n",
" 16: '2002',\n",
" 17: '2003',\n",
" 18: '2004',\n",
" 19: '2005',\n",
" 20: '2006',\n",
" 21: '2007',\n",
" 22: '2008',\n",
" 23: '2009',\n",
" 24: '2010',\n",
" 25: '2011',\n",
" 26: '2012',\n",
" 27: '2013',\n",
" 28: '2014',\n",
" 29: '2015',\n",
" 30: 'UNK'}}\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"global_params = {\n",
" 'batch_size': 128,\n",
" 'gradient_accumulation_steps': 1,\n",
" 'num_train_epochs': 3,\n",
" 'device': 'cuda:0',\n",
" 'output_dir': \"save_models\",\n",
" 'save_model': True,\n",
" 'max_len_seq': 250,\n",
" 'max_age': 110,\n",
" 'age_year': False,\n",
" 'age_symbol': None,\n",
" 'fac': 0.1,\n",
" 'diseaseI': 1,\n",
" 'treatments': 2\n",
"}\n",
"\n",
"ageVocab, _ = age_vocab(max_age=global_params['max_age'], year=global_params['age_year'],\n",
" symbol=global_params['age_symbol'])\n",
"\n",
"model_config = {\n",
" 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n",
" 'hidden_size': 150, # word embedding and seg embedding hidden size\n",
" 'seg_vocab_size': 2, # number of vocab for seg embedding\n",
" 'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n",
" 'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n",
" 'hidden_dropout_prob': 0.3, # dropout rate\n",
" 'num_hidden_layers': 4, # number of multi-head attention layers required\n",
" 'num_attention_heads': 6, # number of attention heads\n",
" 'attention_probs_dropout_prob': 0.4, # multi-head attention dropout rate\n",
" 'intermediate_size': 108, # the size of the \"intermediate\" layer in the transformer encoder\n",
" 'hidden_act': 'gelu',\n",
" 'initializer_range': 0.02, # parameter weight initializer range\n",
" 'num_treatment': global_params['treatments'],\n",
" 'device': global_params['device'],\n",
" 'year_vocab_size': len(YearVocab['token2idx'].keys()),\n",
"\n",
" 'batch_size': global_params['batch_size'],\n",
" 'MEM': True,\n",
" 'poolingSize': 50,\n",
" 'unsupVAE': True,\n",
" 'unsupSize': ([[3,2]] *22) ,\n",
" 'vaelatentdim': 40,\n",
" 'vaehidden': 50,\n",
" 'vaeinchannels':39,\n",
"\n",
"\n",
"\n",
"}\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Begin experiments....\n",
"_________________\n",
"fold___0\n",
"_________________\n",
"turning on the MEM....\n",
"full init completed...\n",
"[('loss', tensor(6.4376, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(823.7451, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.0632, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 0| Loss: 0.04597\t| MLM: 0.13298\t| TOutP: 0.30261\t|vaeloss: 6.43762\t|ExpP: 0.76562\n",
"[('loss', tensor(6.0075, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(768.0154, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-17.6470, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 1| Loss: 0.03695\t| MLM: 0.35556\t| TOutP: 0.20014\t|vaeloss: 6.00750\t|ExpP: 0.63281\n",
"epoch: 0| Loss: 0.00941\t| MLM: 0.49010\t| TOutP: 0.26410\t|vaeloss: 3.98089\t|ExpP: 0.75781\n",
"[('loss', tensor(3.9809, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(450.0008, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-1108.7188, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.9582, device='cuda:0')\n",
"LossEval: 0.912086296081543\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.912086296081543\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"epoch: 1| Loss: 0.00508\t| MLM: 0.56186\t| TOutP: 0.17931\t|vaeloss: 1.55077\t|ExpP: 0.99219\n",
"[('loss', tensor(1.5508, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(56.6759, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2628.6050, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.8717, device='cuda:0')\n",
"LossEval: 0.75910964012146\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.75910964012146\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"turning on the MEM....\n",
"full init completed...\n",
"[[0], [1]]\n",
"0 tempLoss: 0.7803983688354492\n",
"LossEval: 0.6677856683731079 prec treat: 0.9766666666666667\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"_________________\n",
"fold___1\n",
"_________________\n",
"turning on the MEM....\n",
"full init completed...\n",
"[('loss', tensor(6.3858, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(817.0728, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.8979, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 0| Loss: 0.04601\t| MLM: 0.11060\t| TOutP: 0.15367\t|vaeloss: 6.38584\t|ExpP: 0.42188\n",
"[('loss', tensor(6.0483, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(773.4330, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-13.9553, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 1| Loss: 0.03693\t| MLM: 0.38764\t| TOutP: 0.28702\t|vaeloss: 6.04828\t|ExpP: 0.23438\n",
"epoch: 0| Loss: 0.00949\t| MLM: 0.48663\t| TOutP: 0.33637\t|vaeloss: 4.00320\t|ExpP: 0.39844\n",
"[('loss', tensor(4.0032, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(461.0717, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-955.7684, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.9395, device='cuda:0')\n",
"LossEval: 0.8925997734069824\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.8925997734069824\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"epoch: 1| Loss: 0.00531\t| MLM: 0.55825\t| TOutP: 0.24316\t|vaeloss: 1.45487\t|ExpP: 0.95312\n",
"[('loss', tensor(1.4549, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(58.0446, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2375.7451, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.8748, device='cuda:0')\n",
"LossEval: 0.7546854019165039\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.7546854019165039\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"turning on the MEM....\n",
"full init completed...\n",
"[[0], [1]]\n",
"0 tempLoss: 0.7930977940559387\n",
"LossEval: 0.6729061722755432 prec treat: 0.9766666666666667\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"_________________\n",
"fold___2\n",
"_________________\n",
"turning on the MEM....\n",
"full init completed...\n",
"[('loss', tensor(6.8586, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(877.5925, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.7485, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 0| Loss: 0.04843\t| MLM: 0.04945\t| TOutP: 0.24977\t|vaeloss: 6.85859\t|ExpP: 0.81250\n",
"[('loss', tensor(6.1659, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(788.5856, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-12.1737, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 1| Loss: 0.03748\t| MLM: 0.35294\t| TOutP: 0.24342\t|vaeloss: 6.16592\t|ExpP: 0.74219\n",
"epoch: 0| Loss: 0.00960\t| MLM: 0.51531\t| TOutP: 0.13057\t|vaeloss: 4.42672\t|ExpP: 0.83594\n",
"[('loss', tensor(4.4267, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(515.2249, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-956.8342, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.9165, device='cuda:0')\n",
"LossEval: 0.8629927635192871\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.8629927635192871\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"epoch: 1| Loss: 0.00494\t| MLM: 0.63473\t| TOutP: 0.20097\t|vaeloss: 1.60833\t|ExpP: 0.97656\n",
"[('loss', tensor(1.6083, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(76.9737, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2388.9741, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.8581, device='cuda:0')\n",
"LossEval: 0.7425572872161865\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.7425572872161865\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"turning on the MEM....\n",
"full init completed...\n",
"[[0], [1]]\n",
"0 tempLoss: 0.7719957232475281\n",
"LossEval: 0.6564249277114869 prec treat: 0.9766666666666667\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"_________________\n",
"fold___3\n",
"_________________\n",
"turning on the MEM....\n",
"full init completed...\n",
"[('loss', tensor(6.4421, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(824.2659, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-6.0783, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 0| Loss: 0.04659\t| MLM: 0.05825\t| TOutP: 0.23844\t|vaeloss: 6.44211\t|ExpP: 0.56250\n",
"[('loss', tensor(6.5312, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(835.7316, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-4.9471, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 1| Loss: 0.03936\t| MLM: 0.37313\t| TOutP: 0.25152\t|vaeloss: 6.53122\t|ExpP: 0.38281\n",
"epoch: 0| Loss: 0.00999\t| MLM: 0.40314\t| TOutP: 0.28174\t|vaeloss: 4.91041\t|ExpP: 0.42969\n",
"[('loss', tensor(4.9104, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(588.9692, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-736.5710, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.9352, device='cuda:0')\n",
"LossEval: 0.888219165802002\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.888219165802002\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"epoch: 1| Loss: 0.00542\t| MLM: 0.46328\t| TOutP: 0.16435\t|vaeloss: 1.87306\t|ExpP: 0.98438\n",
"[('loss', tensor(1.8731, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(115.1043, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2310.2729, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.8745, device='cuda:0')\n",
"LossEval: 0.759437370300293\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.759437370300293\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"turning on the MEM....\n",
"full init completed...\n",
"[[0], [1]]\n",
"0 tempLoss: 0.7776303291320801\n",
"LossEval: 0.6625770092010498 prec treat: 0.9766666666666667\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"_________________\n",
"fold___4\n",
"_________________\n",
"turning on the MEM....\n",
"full init completed...\n",
"[('loss', tensor(7.1006, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(908.5834, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.5566, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 0| Loss: 0.04973\t| MLM: 0.08920\t| TOutP: 0.34168\t|vaeloss: 7.10062\t|ExpP: 0.75000\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('loss', tensor(6.3710, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(815.0780, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-7.6161, device='cuda:0', grad_fn=<NegBackward>))]\n",
"epoch: 1| Loss: 0.03872\t| MLM: 0.42614\t| TOutP: 0.17729\t|vaeloss: 6.37098\t|ExpP: 0.48438\n",
"epoch: 0| Loss: 0.00975\t| MLM: 0.39409\t| TOutP: 0.24371\t|vaeloss: 4.48545\t|ExpP: 0.45312\n",
"[('loss', tensor(4.4854, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(531.1162, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-800.9330, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.9520, device='cuda:0')\n",
"LossEval: 0.9080421447753906\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.9080421447753906\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"epoch: 1| Loss: 0.00529\t| MLM: 0.64242\t| TOutP: 0.28668\t|vaeloss: 1.59132\t|ExpP: 0.99219\n",
"[('loss', tensor(1.5913, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(74.8522, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2387.9209, device='cuda:0', grad_fn=<NegBackward>))]\n",
"0 tempLoss: tensor(0.8671, device='cuda:0')\n",
"LossEval: 0.7486227989196778\n",
"** ** * Saving best fine - tuned model ** ** * \n",
"auc-mean: -0.7486227989196778\n",
"auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n",
"turning on the MEM....\n",
"full init completed...\n",
"[[0], [1]]\n",
"0 tempLoss: 0.7893564105033875\n",
"LossEval: 0.670855188369751 prec treat: 0.9766666666666667\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
]
}
],
"source": [
"\n",
"data = pd.read_parquet (file_config['data'])\n",
"\n",
"kf = KFold(n_splits = 5, shuffle = True, random_state = 2)\n",
"\n",
"print('Begin experiments....')\n",
"\n",
"\n",
"\n",
"for cutiter in (range(5)):\n",
" print(\"_________________\\nfold___\" + str(cutiter) + \"\\n_________________\")\n",
" data = pd.read_parquet (file_config['data'])\n",
"\n",
" result = next(kf.split(data), None)\n",
"\n",
" datatrain = data.iloc[result[0]].reset_index(drop=True)\n",
" testdata = data.iloc[result[1]].reset_index(drop=True)\n",
"\n",
" tset = TBEHRT_data_formation(BertVocab['token2idx'], testdata, code= 'code', \n",
" age = 'age', year = 'year' , static= 'static' , \n",
" max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n",
" yvocab=YearVocab['token2idx'], list2avoid=None, MEM=False)\n",
" \n",
" \n",
" \n",
" testload = DataLoader(dataset=tset, batch_size=int(global_params['batch_size']), shuffle=False, num_workers=0)\n",
"\n",
"\n",
" model_config['klpar']= float(1.0/(len(datatrain)/global_params['batch_size']))\n",
" conf = BertConfig(model_config)\n",
" model = TBEHRT(conf, 1)\n",
"\n",
" optim = optimizer.adam(params=list(model.named_parameters()), config=optim_config)\n",
"\n",
" model_to_save_name = 'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".bin\"\n",
"\n",
" import warnings\n",
"\n",
" warnings.filterwarnings(action='ignore')\n",
" scheduler = toptimizer.lr_scheduler.ExponentialLR(optim, 0.95, last_epoch=-1)\n",
" patience = 0\n",
" best_pre = -100000000000000000000\n",
" LossC = 0.1\n",
" #\n",
" for e in range(2):\n",
" scheduler , patience= trainunsup(e, scheduler, patience)\n",
"\n",
" for e in range(2):\n",
" train_multi(e)\n",
" auc, auroc, auc2, loss = evaluation_multi_repeats()\n",
" aucreal = -1 * loss\n",
" if aucreal > best_pre:\n",
" patience = 0\n",
" # Save a trained model\n",
" print(\"** ** * Saving best fine - tuned model ** ** * \")\n",
" model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n",
" output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n",
" create_folder(global_params['output_dir'])\n",
" if global_params['save_model']:\n",
" torch.save(model_to_save.state_dict(), output_model_file)\n",
"\n",
" best_pre = aucreal\n",
" print(\"auc-mean: \", aucreal)\n",
" else:\n",
" if patience % 2 == 0 and patience != 0:\n",
" scheduler.step()\n",
" print(\"LR: \", scheduler.get_lr())\n",
"\n",
" patience = patience + 1\n",
" print('auprc : {}, auroc : {}, Treat-auc : {}, time: {}'.format(auc, auroc, auc2, \"long.....\"))\n",
"\n",
"\n",
"\n",
" LossC = 0.1\n",
" conf = BertConfig(model_config)\n",
" model = TBEHRT(conf, 1)\n",
" optim = optimizer.VAEadam(params=list(model.named_parameters()), config=optim_config)\n",
" output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n",
" model = toLoad(model, output_model_file)\n",
"\n",
"\n",
" y, y_label, t_label, treatO, tprc, eps = fullEval_4analysis_multi(False, True, testdata)\n",
"\n",
" y2, y2label, treatmentlabel2, treatment2 = fullCONV(y, y_label, t_label, treatO)\n",
"\n",
" NPSaveNAME = 'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".npz\"\n",
"\n",
" np.savez( NPSaveNAME,\n",
" outcome=y2,\n",
" outcome_label=y2label, treatment=treatment2, treatment_label=treatmentlabel2,\n",
" epsilon=np.array([0]))\n",
" del y, y_label, t_label, treatO, tprc, eps, y2, y2label, treatmentlabel2, treatment2, datatrain, conf, model, optim, output_model_file, best_pre, LossC,\n",
" print(\"\\n\\n\\n\\n\\n\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "real3",
"language": "python",
"name": "py3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}