--- a
+++ b/examples/run_TBEHRT.ipynb
@@ -0,0 +1,931 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "import sys\n",
+    "sys.path.insert(0, '/home/rnshishir/deepmed/TBEHRT_pl/')\n",
+    "\n",
+    "import os\n",
+    "from torch.utils.data import DataLoader\n",
+    "from sklearn.model_selection import KFold\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import pytorch_pretrained_bert as Bert\n",
+    "\n",
+    "from pytorch_pretrained_bert import optimizer\n",
+    "import sklearn.metrics as skm\n",
+    "from torch.utils.data.dataset import Dataset\n",
+    "from src.utils import *\n",
+    "from src.model import *\n",
+    "from src.data import *\n",
+    "\n",
+    "from torch import optim as toptimizer\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "def get_beta(batch_idx, m, beta_type, epoch=None, num_epochs=None):\n",
+    "    # Weighting ('beta') schedule for the KL term; the 'Soenderby' warm-up needs epoch and num_epochs.\n",
+    "    if beta_type == \"Blundell\":\n",
+    "        beta = 2 ** (m - (batch_idx + 1)) / (2 ** m - 1)\n",
+    "    elif beta_type == \"Soenderby\":\n",
+    "        beta = min(epoch / (num_epochs // 4), 1)\n",
+    "    elif beta_type == \"Standard\":\n",
+    "        beta = 1 / m\n",
+    "    else:\n",
+    "        beta = 0\n",
+    "    return beta\n",
+    "\n",
+    "\n",
+    "def trainunsup(e, sched, patienceMetric, MEM=True):\n",
+    "    # Unsupervised pre-training pass: optimises the masked-LM loss and monitors the VAE loss for LR scheduling.\n",
+    "    sampled = datatrain.reset_index(drop=True)\n",
+    "\n",
+    "    Dset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code='code',\n",
+    "                                 age='age', year='year', static='static',\n",
+    "                                 max_len=global_params['max_len_seq'], expColumn='explabel', outcomeColumn='label',\n",
+    "                                 yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n",
+    "\n",
+    "\n",
+    "    trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n",
+    "                           sampler=None)\n",
+    "\n",
+    "    model.train()\n",
+    "    tr_loss = 0\n",
+    "    temp_loss = 0\n",
+    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
+    "    oldloss = 10 ** 10\n",
+    "    for step, batch in enumerate(trainload):\n",
+    "\n",
+    "        batch = tuple(t.to(global_params['device']) for t in batch)\n",
+    "\n",
+    "        age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n",
+    "\n",
+    "        masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n",
+    "            input_idsMLM,\n",
+    "            age_ids,\n",
+    "            segment_ids,\n",
+    "            posi_ids,\n",
+    "            year_ids,\n",
+    "\n",
+    "            attention_mask=attMask,\n",
+    "            masked_lm_labels=masked_label,\n",
+    "            outcomeT=outcome_label,\n",
+    "            treatmentCLabel=treatment_label,\n",
+    "            fullEval=False,\n",
+    "            vaelabel=vaelabel)\n",
+    "        vaeloss = vaelosspure['loss']\n",
+    "\n",
+    "        totalL = masked_lm_loss\n",
+    "        if global_params['gradient_accumulation_steps'] > 1:\n",
+    "            totalL = totalL / global_params['gradient_accumulation_steps']\n",
+    "        totalL.backward()\n",
+    "        treatFull = treatOut\n",
+    "        treatLabelFull = treatLabel\n",
+    "        treatLabelFull = treatLabelFull.cpu().detach()\n",
+    "\n",
+    "        outFull = out\n",
+    "\n",
+    "        outLabelFull = outLabel\n",
+    "        treatindex = treatindex.cpu().detach().numpy()\n",
+    "        zeroind = np.where(treatindex == 0)\n",
+    "        outzero = outFull[0][zeroind]\n",
+    "        outzeroLabel = outLabelFull[zeroind]\n",
+    "\n",
+    "\n",
+    "        temp_loss += totalL.item()\n",
+    "        tr_loss += totalL.item()\n",
+    "        
nb_tr_examples += input_ids.size(0)\n", + " nb_tr_steps += 1\n", + "\n", + " if step % 600 == 0:\n", + " print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n", + " keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n", + " if oldloss < vaelosspure['loss']:\n", + " patienceMetric = patienceMetric + 1\n", + " if patienceMetric >= 10:\n", + " sched.step()\n", + " print(\"LR: \", sched.get_lr())\n", + " patienceMetric = 0\n", + " oldloss = vaelosspure['loss']\n", + "\n", + " if step % 200 == 0:\n", + " precOut0 = -1\n", + " if len(zeroind[0]) > 0:\n", + " precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n", + "\n", + " print(\n", + " \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n", + " e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n", + " cal_acc(treatLabelFull, treatFull, False)))\n", + " temp_loss = 0\n", + "\n", + " if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n", + " optim.step()\n", + " optim.zero_grad()\n", + "\n", + " # Save a trained model\n", + " del sampled, Dset, trainload\n", + " return sched, patienceMetric\n", + "\n", + "\n", + "def train_multi(e, MEM=True):\n", + " sampled = datatrain.reset_index(drop=True)\n", + "\n", + " Dset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n", + " age = 'age', year = 'year' , static= 'static' , \n", + " max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n", + " yvocab=YearVocab['token2idx'], list2avoid=None, MEM=MEM)\n", + " \n", + " \n", + " \n", + " trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3,\n", + " sampler=None)\n", + " \n", + " \n", + " model.train()\n", + " tr_loss = 0\n", + " temp_loss = 0\n", + " nb_tr_examples, nb_tr_steps = 0, 0\n", + " for step, batch in enumerate(trainload):\n", + "\n", + " batch = tuple(t.to(global_params['device']) for t in batch)\n", + "\n", + " age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n", + " masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n", + " input_idsMLM,\n", + " age_ids,\n", + " segment_ids,\n", + " posi_ids,\n", + " year_ids,\n", + "\n", + " attention_mask=attMask,\n", + " masked_lm_labels=masked_label,\n", + " outcomeT=outcome_label,\n", + " treatmentCLabel=treatment_label,\n", + " fullEval=False,\n", + " vaelabel=vaelabel)\n", + "\n", + " vaeloss = vaelosspure['loss']\n", + " totalL = 1 * (lossT) + 0 + (global_params['fac'] * masked_lm_loss)\n", + " if global_params['gradient_accumulation_steps'] > 1:\n", + " totalL = totalL / global_params['gradient_accumulation_steps']\n", + " totalL.backward()\n", + " treatFull = treatOut\n", + " treatLabelFull = treatLabel\n", + " treatLabelFull = treatLabelFull.cpu().detach()\n", + "\n", + " outFull = out\n", + "\n", + " outLabelFull = outLabel\n", + " treatindex = treatindex.cpu().detach().numpy()\n", + " zeroind = np.where(treatindex == 0)\n", + " outzero = outFull[0][zeroind]\n", + " outzeroLabel = outLabelFull[zeroind]\n", + "\n", + " temp_loss += totalL.item()\n", + " tr_loss += totalL.item()\n", + " nb_tr_examples += input_ids.size(0)\n", + " nb_tr_steps += 1\n", + "\n", + " if step % 200 == 0:\n", + " precOut0 = -1\n", + "\n", + " if len(zeroind[0]) > 0:\n", + " precOut0, _, _ = OutcomePrecision(outzero, outzeroLabel, False)\n", + "\n", + 
" print(\n", + " \"epoch: {0}| Loss: {1:6.5f}\\t| MLM: {2:6.5f}\\t| TOutP: {3:6.5f}\\t|vaeloss: {4:6.5f}\\t|ExpP: {5:6.5f}\".format(\n", + " e, temp_loss / 200, cal_acc(label, pred), precOut0, vaeloss,\n", + " cal_acc(treatLabelFull, treatFull, False)))\n", + "\n", + " print([(keyvae, valvae) for (keyvae, valvae) in vaelosspure.items() if\n", + " keyvae in ['loss', 'Reconstruction_Loss', 'KLD']])\n", + " temp_loss = 0\n", + "\n", + " if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n", + " optim.step()\n", + " optim.zero_grad()\n", + "\n", + " del sampled, Dset, trainload\n", + "\n", + "\n", + "def evaluation_multi_repeats():\n", + " model.eval()\n", + " y = []\n", + " y_label = []\n", + " t_label = []\n", + " t_output = []\n", + " count = 0\n", + " totalL = 0\n", + " for step, batch in enumerate(testload):\n", + " model.eval()\n", + " count = count + 1\n", + " batch = tuple(t.to(global_params['device']) for t in batch)\n", + "\n", + " age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n", + " with torch.no_grad():\n", + "\n", + " masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaelosspure = model(\n", + " input_idsMLM,\n", + " age_ids,\n", + " segment_ids,\n", + " posi_ids,\n", + " year_ids,\n", + "\n", + " attention_mask=attMask,\n", + " masked_lm_labels=masked_label,\n", + " outcomeT=outcome_label,\n", + " treatmentCLabel=treatment_label, vaelabel=vaelabel)\n", + "\n", + " totalL = totalL + lossT.item() + 0 + (global_params['fac'] * masked_lm_loss)\n", + " treatFull = treatOut\n", + " treatLabelFull = treatLabel\n", + " treatLabelFull = treatLabelFull.detach()\n", + " outFull = out\n", + " outLabelFull = outLabel\n", + " treatindex = treatindex.cpu().detach().numpy()\n", + " outPred = []\n", + " outexpLab = []\n", + " for el in range(global_params['treatments']):\n", + " zeroind = np.where(treatindex == el)\n", + " outPred.append(outFull[el][zeroind])\n", + " outexpLab.append(outLabelFull[zeroind])\n", + "\n", + "\n", + " y_label.append(torch.cat(outexpLab))\n", + "\n", + " y.append(torch.cat(outPred))\n", + "\n", + " treatOut = treatFull.cpu()\n", + " treatLabel = treatLabelFull.cpu()\n", + " if step % 200 == 0:\n", + " print(step, \"tempLoss:\", totalL / count)\n", + "\n", + " t_label.append(treatLabel)\n", + " t_output.append(treatOut)\n", + "\n", + " y_label = torch.cat(y_label, dim=0)\n", + " y = torch.cat(y, dim=0)\n", + " t_label = torch.cat(t_label, dim=0)\n", + " treatO = torch.cat(t_output, dim=0)\n", + "\n", + " tempprc, output, label = precision_test(y, y_label, False)\n", + " treatPRC = cal_acc(t_label, treatO, False)\n", + " tempprc2, output2, label2 = roc_auc(y, y_label, False)\n", + "\n", + " print(\"LossEval: \", float(totalL) / float(count))\n", + "\n", + " return tempprc, tempprc2, treatPRC, float(totalL) / float(count)\n", + "\n", + "\n", + "def fullEval_4analysis_multi(tr, te, filetest):\n", + " if tr:\n", + " sampled = datatrain.reset_index(drop=True)\n", + "\n", + " if te:\n", + " data = filetest\n", + "\n", + " if tr:\n", + " sampled = pd.concat([sampled, data]).reset_index(drop=True)\n", + " else:\n", + " sampled = data\n", + " Fulltset = TBEHRT_data_formation(BertVocab['token2idx'], sampled, code= 'code', \n", + " age = 'age', year = 'year' , static= 'static' , \n", + " max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n", + " yvocab=YearVocab['token2idx'], list2avoid=None, 
MEM=False)\n", + " \n", + " \n", + " \n", + " fullDataLoad = DataLoader(dataset=Fulltset, batch_size=int(global_params['batch_size']), shuffle=False,\n", + " num_workers=0)\n", + "\n", + " model.eval()\n", + " y = []\n", + " y_label = []\n", + " t_label = []\n", + " t_output = []\n", + " count = 0\n", + " totalL = 0\n", + " eps_array = []\n", + "\n", + " for yyy in range(model_config['num_treatment']):\n", + " y.append([yyy])\n", + " y_label.append([yyy])\n", + "\n", + " print(y)\n", + " for step, batch in enumerate(fullDataLoad):\n", + " model.eval()\n", + "\n", + " count = count + 1\n", + " batch = tuple(t.to(global_params['device']) for t in batch)\n", + "\n", + " age_ids, input_ids, input_idsMLM, posi_ids, segment_ids, year_ids, attMask, masked_label, outcome_label, treatment_label, vaelabel = batch\n", + "\n", + " with torch.no_grad():\n", + " masked_lm_loss, lossT, pred, label, treatOut, treatLabel, out, outLabel, treatindex, targreg, vaeloss = model(\n", + " input_idsMLM,\n", + " age_ids,\n", + " segment_ids,\n", + " posi_ids,\n", + " year_ids,\n", + "\n", + " attention_mask=attMask,\n", + " masked_lm_labels=masked_label,\n", + " outcomeT=outcome_label,\n", + " treatmentCLabel=treatment_label, fullEval=True, vaelabel=vaelabel)\n", + "\n", + "\n", + "\n", + " outFull = out\n", + " outLabelFull = outLabel\n", + "\n", + "\n", + " for el in range(global_params['treatments']):\n", + " y[el].append(outFull[el].cpu())\n", + " y_label[el].append(outLabelFull.cpu())\n", + "\n", + " totalL = totalL + (1 * (lossT)).item()\n", + "\n", + " if step % 200 == 0:\n", + " print(step, \"tempLoss:\", totalL / count)\n", + "\n", + " t_label.append(treatLabel)\n", + " t_output.append(treatOut)\n", + "\n", + " for idd, elem in enumerate(y):\n", + " elem = torch.cat(elem[1:], dim=0)\n", + " y[idd] = elem\n", + " for idd, elem in enumerate(y_label):\n", + " elem = torch.cat(elem[1:], dim=0)\n", + " y_label[idd] = elem\n", + "\n", + " t_label = torch.cat(t_label, dim=0)\n", + " treatO = torch.cat(t_output, dim=0)\n", + " treatPRC = cal_acc(t_label, treatO)\n", + "\n", + " print(\"LossEval: \", float(totalL) / float(count), \"prec treat:\", treatPRC)\n", + " return y, y_label, t_label, treatO, treatPRC, eps_array\n", + "\n", + "\n", + "def fullCONV(y, y_label, t_label, treatO):\n", + " def convert_multihot(label, pred):\n", + " label = label.cpu().numpy()\n", + " truepred = pred.detach().cpu().numpy()\n", + " truelabel = label\n", + " newpred = []\n", + " for i, x in enumerate(truelabel):\n", + " temppred = []\n", + " temppred.append(truepred[i][0])\n", + " temppred.append(truepred[i][x[0]])\n", + " newpred.append(temppred)\n", + " return truelabel, np.array(truepred)\n", + "\n", + " def convert_bin(logits, label, treatmentlabel2):\n", + "\n", + " output = logits\n", + " label, output = label.cpu().numpy(), output.detach().cpu().numpy()\n", + " label = label[treatmentlabel2[0]]\n", + "\n", + " return label, output\n", + "\n", + " treatmentlabel2, treatment2 = convert_multihot(t_label, treatO)\n", + " y = torch.cat(y, dim=0).view(global_params['treatments'], -1)\n", + " y = y.transpose(1, 0)\n", + " y_label = torch.cat(y_label, dim=0).view(global_params['treatments'], -1)\n", + " y_label = y_label.transpose(1, 0)\n", + " y2 = []\n", + " y2label = []\n", + " for i, elem in enumerate(y):\n", + " j, k = convert_bin(elem, y_label[i], treatmentlabel2[i])\n", + " y2.append(k)\n", + " y2label.append(j)\n", + " y2 = np.array(y2)\n", + " y2label = np.array(y2label)\n", + " y2label = np.expand_dims(y2label, -1)\n", + 
"\n", + " return y2, y2label, treatmentlabel2, treatment2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "file_config = {\n", + " 'data': 'test.parquet',\n", + "}\n", + "optim_config = {\n", + " 'lr': 1e-4,\n", + " 'warmup_proportion': 0.1\n", + "}\n", + "\n", + "\n", + "BertVocab = {}\n", + "token2idx = {'MASK': 4,\n", + " 'CLS': 3,\n", + " 'SEP': 2,\n", + " 'UNK': 1,\n", + " 'PAD': 0,\n", + " 'disease1':5,\n", + " 'disease2':6,\n", + " 'disease3':7,\n", + " 'disease4':8,\n", + " 'disease5':9,\n", + " 'disease6':10,\n", + " 'medication1':11,\n", + " 'medication2':12,\n", + " 'medication3':13,\n", + " 'medication4':14,\n", + " 'medication5':15,\n", + " 'medication6':16,\n", + " }\n", + "idx2token = {}\n", + "for x in token2idx:\n", + " idx2token[token2idx[x]]=x\n", + "BertVocab['token2idx']= token2idx\n", + "BertVocab['idx2token']= idx2token\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "YearVocab = {'token2idx': {'PAD': 0,\n", + " '1987': 1,\n", + " '1988': 2,\n", + " '1989': 3,\n", + " '1990': 4,\n", + " '1991': 5,\n", + " '1992': 6,\n", + " '1993': 7,\n", + " '1994': 8,\n", + " '1995': 9,\n", + " '1996': 10,\n", + " '1997': 11,\n", + " '1998': 12,\n", + " '1999': 13,\n", + " '2000': 14,\n", + " '2001': 15,\n", + " '2002': 16,\n", + " '2003': 17,\n", + " '2004': 18,\n", + " '2005': 19,\n", + " '2006': 20,\n", + " '2007': 21,\n", + " '2008': 22,\n", + " '2009': 23,\n", + " '2010': 24,\n", + " '2011': 25,\n", + " '2012': 26,\n", + " '2013': 27,\n", + " '2014': 28,\n", + " '2015': 29,\n", + " 'UNK': 30},\n", + " 'idx2token': {0: 'PAD',\n", + " 1: '1987',\n", + " 2: '1988',\n", + " 3: '1989',\n", + " 4: '1990',\n", + " 5: '1991',\n", + " 6: '1992',\n", + " 7: '1993',\n", + " 8: '1994',\n", + " 9: '1995',\n", + " 10: '1996',\n", + " 11: '1997',\n", + " 12: '1998',\n", + " 13: '1999',\n", + " 14: '2000',\n", + " 15: '2001',\n", + " 16: '2002',\n", + " 17: '2003',\n", + " 18: '2004',\n", + " 19: '2005',\n", + " 20: '2006',\n", + " 21: '2007',\n", + " 22: '2008',\n", + " 23: '2009',\n", + " 24: '2010',\n", + " 25: '2011',\n", + " 26: '2012',\n", + " 27: '2013',\n", + " 28: '2014',\n", + " 29: '2015',\n", + " 30: 'UNK'}}\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "global_params = {\n", + " 'batch_size': 128,\n", + " 'gradient_accumulation_steps': 1,\n", + " 'num_train_epochs': 3,\n", + " 'device': 'cuda:0',\n", + " 'output_dir': \"save_models\",\n", + " 'save_model': True,\n", + " 'max_len_seq': 250,\n", + " 'max_age': 110,\n", + " 'age_year': False,\n", + " 'age_symbol': None,\n", + " 'fac': 0.1,\n", + " 'diseaseI': 1,\n", + " 'treatments': 2\n", + "}\n", + "\n", + "ageVocab, _ = age_vocab(max_age=global_params['max_age'], year=global_params['age_year'],\n", + " symbol=global_params['age_symbol'])\n", + "\n", + "model_config = {\n", + " 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n", + " 'hidden_size': 150, # word embedding and seg embedding hidden size\n", + " 'seg_vocab_size': 2, # number of vocab for seg embedding\n", + " 'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n", + " 'max_position_embedding': global_params['max_len_seq'], # maximum number of tokens\n", + " 'hidden_dropout_prob': 0.3, # dropout rate\n", + " 'num_hidden_layers': 4, # number of multi-head attention layers required\n", + " 'num_attention_heads': 6, 
# number of attention heads\n", + " 'attention_probs_dropout_prob': 0.4, # multi-head attention dropout rate\n", + " 'intermediate_size': 108, # the size of the \"intermediate\" layer in the transformer encoder\n", + " 'hidden_act': 'gelu',\n", + " 'initializer_range': 0.02, # parameter weight initializer range\n", + " 'num_treatment': global_params['treatments'],\n", + " 'device': global_params['device'],\n", + " 'year_vocab_size': len(YearVocab['token2idx'].keys()),\n", + "\n", + " 'batch_size': global_params['batch_size'],\n", + " 'MEM': True,\n", + " 'poolingSize': 50,\n", + " 'unsupVAE': True,\n", + " 'unsupSize': ([[3,2]] *22) ,\n", + " 'vaelatentdim': 40,\n", + " 'vaehidden': 50,\n", + " 'vaeinchannels':39,\n", + "\n", + "\n", + "\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Begin experiments....\n", + "_________________\n", + "fold___0\n", + "_________________\n", + "turning on the MEM....\n", + "full init completed...\n", + "[('loss', tensor(6.4376, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(823.7451, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.0632, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 0| Loss: 0.04597\t| MLM: 0.13298\t| TOutP: 0.30261\t|vaeloss: 6.43762\t|ExpP: 0.76562\n", + "[('loss', tensor(6.0075, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(768.0154, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-17.6470, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 1| Loss: 0.03695\t| MLM: 0.35556\t| TOutP: 0.20014\t|vaeloss: 6.00750\t|ExpP: 0.63281\n", + "epoch: 0| Loss: 0.00941\t| MLM: 0.49010\t| TOutP: 0.26410\t|vaeloss: 3.98089\t|ExpP: 0.75781\n", + "[('loss', tensor(3.9809, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(450.0008, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-1108.7188, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.9582, device='cuda:0')\n", + "LossEval: 0.912086296081543\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.912086296081543\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "epoch: 1| Loss: 0.00508\t| MLM: 0.56186\t| TOutP: 0.17931\t|vaeloss: 1.55077\t|ExpP: 0.99219\n", + "[('loss', tensor(1.5508, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(56.6759, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2628.6050, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.8717, device='cuda:0')\n", + "LossEval: 0.75910964012146\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.75910964012146\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "turning on the MEM....\n", + "full init completed...\n", + "[[0], [1]]\n", + "0 tempLoss: 0.7803983688354492\n", + "LossEval: 0.6677856683731079 prec treat: 0.9766666666666667\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_________________\n", + "fold___1\n", + "_________________\n", + "turning on the MEM....\n", + "full init completed...\n", + "[('loss', tensor(6.3858, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(817.0728, 
device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.8979, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 0| Loss: 0.04601\t| MLM: 0.11060\t| TOutP: 0.15367\t|vaeloss: 6.38584\t|ExpP: 0.42188\n", + "[('loss', tensor(6.0483, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(773.4330, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-13.9553, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 1| Loss: 0.03693\t| MLM: 0.38764\t| TOutP: 0.28702\t|vaeloss: 6.04828\t|ExpP: 0.23438\n", + "epoch: 0| Loss: 0.00949\t| MLM: 0.48663\t| TOutP: 0.33637\t|vaeloss: 4.00320\t|ExpP: 0.39844\n", + "[('loss', tensor(4.0032, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(461.0717, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-955.7684, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.9395, device='cuda:0')\n", + "LossEval: 0.8925997734069824\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.8925997734069824\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "epoch: 1| Loss: 0.00531\t| MLM: 0.55825\t| TOutP: 0.24316\t|vaeloss: 1.45487\t|ExpP: 0.95312\n", + "[('loss', tensor(1.4549, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(58.0446, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2375.7451, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.8748, device='cuda:0')\n", + "LossEval: 0.7546854019165039\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.7546854019165039\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "turning on the MEM....\n", + "full init completed...\n", + "[[0], [1]]\n", + "0 tempLoss: 0.7930977940559387\n", + "LossEval: 0.6729061722755432 prec treat: 0.9766666666666667\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_________________\n", + "fold___2\n", + "_________________\n", + "turning on the MEM....\n", + "full init completed...\n", + "[('loss', tensor(6.8586, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(877.5925, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.7485, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 0| Loss: 0.04843\t| MLM: 0.04945\t| TOutP: 0.24977\t|vaeloss: 6.85859\t|ExpP: 0.81250\n", + "[('loss', tensor(6.1659, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(788.5856, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-12.1737, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 1| Loss: 0.03748\t| MLM: 0.35294\t| TOutP: 0.24342\t|vaeloss: 6.16592\t|ExpP: 0.74219\n", + "epoch: 0| Loss: 0.00960\t| MLM: 0.51531\t| TOutP: 0.13057\t|vaeloss: 4.42672\t|ExpP: 0.83594\n", + "[('loss', tensor(4.4267, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(515.2249, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-956.8342, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.9165, device='cuda:0')\n", + "LossEval: 0.8629927635192871\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.8629927635192871\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "epoch: 1| Loss: 0.00494\t| MLM: 0.63473\t| TOutP: 0.20097\t|vaeloss: 1.60833\t|ExpP: 0.97656\n", + "[('loss', tensor(1.6083, device='cuda:0', 
grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(76.9737, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2388.9741, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.8581, device='cuda:0')\n", + "LossEval: 0.7425572872161865\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.7425572872161865\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "turning on the MEM....\n", + "full init completed...\n", + "[[0], [1]]\n", + "0 tempLoss: 0.7719957232475281\n", + "LossEval: 0.6564249277114869 prec treat: 0.9766666666666667\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_________________\n", + "fold___3\n", + "_________________\n", + "turning on the MEM....\n", + "full init completed...\n", + "[('loss', tensor(6.4421, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(824.2659, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-6.0783, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 0| Loss: 0.04659\t| MLM: 0.05825\t| TOutP: 0.23844\t|vaeloss: 6.44211\t|ExpP: 0.56250\n", + "[('loss', tensor(6.5312, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(835.7316, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-4.9471, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 1| Loss: 0.03936\t| MLM: 0.37313\t| TOutP: 0.25152\t|vaeloss: 6.53122\t|ExpP: 0.38281\n", + "epoch: 0| Loss: 0.00999\t| MLM: 0.40314\t| TOutP: 0.28174\t|vaeloss: 4.91041\t|ExpP: 0.42969\n", + "[('loss', tensor(4.9104, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(588.9692, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-736.5710, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.9352, device='cuda:0')\n", + "LossEval: 0.888219165802002\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.888219165802002\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "epoch: 1| Loss: 0.00542\t| MLM: 0.46328\t| TOutP: 0.16435\t|vaeloss: 1.87306\t|ExpP: 0.98438\n", + "[('loss', tensor(1.8731, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(115.1043, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2310.2729, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.8745, device='cuda:0')\n", + "LossEval: 0.759437370300293\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.759437370300293\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "turning on the MEM....\n", + "full init completed...\n", + "[[0], [1]]\n", + "0 tempLoss: 0.7776303291320801\n", + "LossEval: 0.6625770092010498 prec treat: 0.9766666666666667\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "_________________\n", + "fold___4\n", + "_________________\n", + "turning on the MEM....\n", + "full init completed...\n", + "[('loss', tensor(7.1006, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(908.5834, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-5.5566, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 0| Loss: 0.04973\t| MLM: 0.08920\t| TOutP: 0.34168\t|vaeloss: 7.10062\t|ExpP: 0.75000\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[('loss', tensor(6.3710, device='cuda:0', grad_fn=<DivBackward0>)), 
('Reconstruction_Loss', tensor(815.0780, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-7.6161, device='cuda:0', grad_fn=<NegBackward>))]\n", + "epoch: 1| Loss: 0.03872\t| MLM: 0.42614\t| TOutP: 0.17729\t|vaeloss: 6.37098\t|ExpP: 0.48438\n", + "epoch: 0| Loss: 0.00975\t| MLM: 0.39409\t| TOutP: 0.24371\t|vaeloss: 4.48545\t|ExpP: 0.45312\n", + "[('loss', tensor(4.4854, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(531.1162, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-800.9330, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.9520, device='cuda:0')\n", + "LossEval: 0.9080421447753906\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.9080421447753906\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "epoch: 1| Loss: 0.00529\t| MLM: 0.64242\t| TOutP: 0.28668\t|vaeloss: 1.59132\t|ExpP: 0.99219\n", + "[('loss', tensor(1.5913, device='cuda:0', grad_fn=<DivBackward0>)), ('Reconstruction_Loss', tensor(74.8522, device='cuda:0', grad_fn=<AddBackward0>)), ('KLD', tensor(-2387.9209, device='cuda:0', grad_fn=<NegBackward>))]\n", + "0 tempLoss: tensor(0.8671, device='cuda:0')\n", + "LossEval: 0.7486227989196778\n", + "** ** * Saving best fine - tuned model ** ** * \n", + "auc-mean: -0.7486227989196778\n", + "auprc : 0.21718102508178844, auroc : 0.4948404108139781, Treat-auc : 0.9766666666666667, time: long.....\n", + "turning on the MEM....\n", + "full init completed...\n", + "[[0], [1]]\n", + "0 tempLoss: 0.7893564105033875\n", + "LossEval: 0.670855188369751 prec treat: 0.9766666666666667\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "\n", + "data = pd.read_parquet (file_config['data'])\n", + "\n", + "kf = KFold(n_splits = 5, shuffle = True, random_state = 2)\n", + "\n", + "print('Begin experiments....')\n", + "\n", + "\n", + "\n", + "for cutiter in (range(5)):\n", + " print(\"_________________\\nfold___\" + str(cutiter) + \"\\n_________________\")\n", + " data = pd.read_parquet (file_config['data'])\n", + "\n", + " result = next(kf.split(data), None)\n", + "\n", + " datatrain = data.iloc[result[0]].reset_index(drop=True)\n", + " testdata = data.iloc[result[1]].reset_index(drop=True)\n", + "\n", + " tset = TBEHRT_data_formation(BertVocab['token2idx'], testdata, code= 'code', \n", + " age = 'age', year = 'year' , static= 'static' , \n", + " max_len=global_params['max_len_seq'],expColumn='explabel', outcomeColumn='label', \n", + " yvocab=YearVocab['token2idx'], list2avoid=None, MEM=False)\n", + " \n", + " \n", + " \n", + " testload = DataLoader(dataset=tset, batch_size=int(global_params['batch_size']), shuffle=False, num_workers=0)\n", + "\n", + "\n", + " model_config['klpar']= float(1.0/(len(datatrain)/global_params['batch_size']))\n", + " conf = BertConfig(model_config)\n", + " model = TBEHRT(conf, 1)\n", + "\n", + " optim = optimizer.adam(params=list(model.named_parameters()), config=optim_config)\n", + "\n", + " model_to_save_name = 'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".bin\"\n", + "\n", + " import warnings\n", + "\n", + " warnings.filterwarnings(action='ignore')\n", + " scheduler = toptimizer.lr_scheduler.ExponentialLR(optim, 0.95, last_epoch=-1)\n", + " patience = 0\n", + " best_pre = -100000000000000000000\n", + " LossC = 0.1\n", + " #\n", + " for e in range(2):\n", + " scheduler , patience= trainunsup(e, scheduler, patience)\n", + "\n", + " for e in range(2):\n", + " train_multi(e)\n", + " 
auc, auroc, auc2, loss = evaluation_multi_repeats()\n", + " aucreal = -1 * loss\n", + " if aucreal > best_pre:\n", + " patience = 0\n", + " # Save a trained model\n", + " print(\"** ** * Saving best fine - tuned model ** ** * \")\n", + " model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n", + " output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n", + " create_folder(global_params['output_dir'])\n", + " if global_params['save_model']:\n", + " torch.save(model_to_save.state_dict(), output_model_file)\n", + "\n", + " best_pre = aucreal\n", + " print(\"auc-mean: \", aucreal)\n", + " else:\n", + " if patience % 2 == 0 and patience != 0:\n", + " scheduler.step()\n", + " print(\"LR: \", scheduler.get_lr())\n", + "\n", + " patience = patience + 1\n", + " print('auprc : {}, auroc : {}, Treat-auc : {}, time: {}'.format(auc, auroc, auc2, \"long.....\"))\n", + "\n", + "\n", + "\n", + " LossC = 0.1\n", + " conf = BertConfig(model_config)\n", + " model = TBEHRT(conf, 1)\n", + " optim = optimizer.VAEadam(params=list(model.named_parameters()), config=optim_config)\n", + " output_model_file = os.path.join(global_params['output_dir'], model_to_save_name)\n", + " model = toLoad(model, output_model_file)\n", + "\n", + "\n", + " y, y_label, t_label, treatO, tprc, eps = fullEval_4analysis_multi(False, True, testdata)\n", + "\n", + " y2, y2label, treatmentlabel2, treatment2 = fullCONV(y, y_label, t_label, treatO)\n", + "\n", + " NPSaveNAME = 'TBEHRT_Test' + \"__CUT\" + str(cutiter) + \".npz\"\n", + "\n", + " np.savez( NPSaveNAME,\n", + " outcome=y2,\n", + " outcome_label=y2label, treatment=treatment2, treatment_label=treatmentlabel2,\n", + " epsilon=np.array([0]))\n", + " del y, y_label, t_label, treatO, tprc, eps, y2, y2label, treatmentlabel2, treatment2, datatrain, conf, model, optim, output_model_file, best_pre, LossC,\n", + " print(\"\\n\\n\\n\\n\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "real3", + "language": "python", + "name": "py3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}