--- /dev/null
+++ b/patient/health_risk_prediction.py
@@ -0,0 +1,378 @@
+import argparse
+import csv
+import os
+import pickle
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from sklearn.metrics import (accuracy_score, auc, cohen_kappa_score, f1_score,
+                             precision_recall_curve, precision_score,
+                             recall_score, roc_auc_score)
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+
+from models.adacare import AdaCare
+from models.baseline import *
+from models.og_dataset import *
+from utils.utils import bool_flag, check_path, export_config
+
+
+def eval_metric(eval_set, model, encoder):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    model.eval()
+    with torch.no_grad():
+        y_true = np.array([])
+        y_pred = np.array([])
+        y_score = np.array([])
+        for i, data in enumerate(eval_set):
+            labels, ehr, mask, txt, mask_txt, lengths, time_step, code_mask = data
+            # AdaCare consumes only the code sequence; the other encoders also
+            # take the visit mask, lengths, time stamps and code mask.
+            if encoder == 'adacare':
+                logits = model(ehr, device)
+            else:
+                logits = model(ehr, mask, lengths, time_step, code_mask)
+            scores = torch.softmax(logits, dim=-1)
+            scores = scores.data.cpu().numpy()
+            labels = labels.data.cpu().numpy()
+            score = scores[:, 1]  # probability of the positive class
+            pred = scores.argmax(1)
+            y_true = np.concatenate((y_true, labels))
+            y_pred = np.concatenate((y_pred, pred))
+            y_score = np.concatenate((y_score, score))
+    accuracy = accuracy_score(y_true, y_pred)
+    precision = precision_score(y_true, y_pred)
+    recall = recall_score(y_true, y_pred)
+    f1 = f1_score(y_true, y_pred)
+    roc_auc = roc_auc_score(y_true, y_score)
+    lr_precision, lr_recall, _ = precision_recall_curve(y_true, y_score)
+    pr_auc = auc(lr_recall, lr_precision)
+    kappa = cohen_kappa_score(y_true, y_pred)
+    return accuracy, precision, recall, f1, roc_auc, pr_auc, kappa
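+
+# A minimal sketch of the PR AUC that eval_metric reports, on toy arrays. It
+# uses only the sklearn functions imported above; the numbers are illustrative.
+#   y_true = np.array([0, 1, 1, 0, 1])
+#   y_score = np.array([0.1, 0.8, 0.6, 0.4, 0.3])
+#   p, r, _ = precision_recall_curve(y_true, y_score)
+#   print(auc(r, p))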
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--cuda', default=True, type=bool_flag, nargs='?', const=True, help='use GPU')
+    parser.add_argument('--seed', default=0, type=int, help='random seed')
+    parser.add_argument('-bs', '--batch_size', default=32, type=int)
+    parser.add_argument('-me', '--max_epochs_before_stop', default=10, type=int,
+                        help='stop training if the dev metric has not improved for this many epochs')
+    parser.add_argument('--encoder', default='hita', choices=['hita', 'lstm', 'gruself', 'retain', 'adacare'])
+    parser.add_argument('--d_model', default=256, type=int, help='dimension of hidden layers')
+    parser.add_argument('--dropout', default=0.8, type=float, help='dropout rate of hidden layers')
+    parser.add_argument('--dropout_emb', default=0.8, type=float, help='dropout rate of embedding layers')
+    parser.add_argument('--num_layers', default=2, type=int, help='number of transformer layers of the EHR encoder')
+    parser.add_argument('--num_heads', default=4, type=int, help='number of attention heads')
+    parser.add_argument('--max_len', default=50, type=int, help='max number of visits in an EHR sequence')
+    parser.add_argument('--max_num_codes', default=20, type=int, help='max number of ICD codes in each visit')
+    parser.add_argument('--max_num_blks', default=120, type=int, help='max number of blocks in each visit')
+    parser.add_argument('--blk_emb_path', default='./data/processed/block_embedding.npy',
+                        help='path to the block embeddings')
+    parser.add_argument('--blk_vocab_path', default='./data/processed/block_vocab.txt')
+    parser.add_argument('--target_disease', default='Heart_failure',
+                        choices=['Heart_failure', 'COPD', 'Kidney', 'Dementia', 'Amnesia', 'mimic'])
+    parser.add_argument('--target_att_heads', default=4, type=int, help='number of target-disease attention heads')
+    parser.add_argument('--mem_size', default=20, type=int, help='memory size')
+    parser.add_argument('--mem_update_size', default=15, type=int, help='memory update size')
+    parser.add_argument('-lr', '--learning_rate', default=0.00001, type=float, help='learning rate')
+    parser.add_argument('--weight_decay', default=0.01, type=float)
+    parser.add_argument('--max_grad_norm', default=1.0, type=float, help='max grad norm (0 to disable)')
+    parser.add_argument('--warmup_steps', default=200, type=int)
+    parser.add_argument('--n_epochs', default=30, type=int)
+    parser.add_argument('--log_interval', default=20, type=int)
+    parser.add_argument('--mode', default='train', choices=['train', 'eval', 'pred', 'gen'],
+                        help="run mode; only 'train' and 'pred' are implemented in this script")
+    parser.add_argument('--save_dir', default='./saved_models/', help='model output directory')
+    parser.add_argument('--pretrain', default=False, type=bool_flag, nargs='?', const=True,
+                        help='use a pretrained model')
+    args = parser.parse_args()
+    if args.mode == 'train':
+        train(args)
+    elif args.mode == 'pred':
+        pred(args)
+    else:
+        raise ValueError('mode {} is not implemented'.format(args.mode))
+
+
+def train(args):
+    print(args)
+    path = "patient_level/"
+    os.makedirs(path, exist_ok=True)
+    results_name = '{}_{}_{}_.csv'.format(args.pretrain, args.target_disease, args.encoder)
+    if results_name in os.listdir(path):
+        print('experiment already conducted')
+        return
+
+    config_path = os.path.join(args.save_dir, 'config.json')
+    model_path = os.path.join(args.save_dir, 'model.pt')
+    log_path = os.path.join(args.save_dir, 'log.csv')
+    export_config(args, config_path)
+    check_path(model_path)
+    with open(log_path, 'w') as fout:
+        fout.write('step,train_pr_auc,dev_pr_auc,test_pr_auc,test_f1,test_kappa\n')
+
+    blk_emb = np.load(args.blk_emb_path)
+    blk_pad_id = len(blk_emb) - 1
+    if args.target_disease == 'Heart_failure':
+        code2id = pickle.load(open('./data/hf/hf_code2idx.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/hf/hf'
+        emb_path = './data/processed/heart_failure.npy'
+    elif args.target_disease == 'COPD':
+        code2id = pickle.load(open('./data/copd/copd_code2idx.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/copd/copd'
+        emb_path = './data/processed/COPD.npy'
+    elif args.target_disease == 'Kidney':
+        code2id = pickle.load(open('./data/kidney/kidney_code2idx.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/kidney/kidney'
+        emb_path = './data/processed/kidney_disease.npy'
+    elif args.target_disease == 'Dementia':
+        code2id = pickle.load(open('./data/dementia/dementia_code2idx.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/dementia/dementia'
+        emb_path = './data/processed/dementia.npy'
+    elif args.target_disease == 'Amnesia':
+        code2id = pickle.load(open('./data/amnesia/amnesia_code2idx.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/amnesia/amnesia'
+        emb_path = './data/processed/amnesia.npy'
+    elif args.target_disease == 'mimic':
+        code2id = pickle.load(open('./data/mimic/mimic_code2idx_sps.pickle', 'rb'))
+        pad_id = len(code2id) + 1
+        data_path = './data/mimic/mimic'
+        emb_path = './data/processed/mimic.npy'
+    else:
+        raise ValueError('Invalid disease')
+
+    torch.manual_seed(args.seed)
+    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
+    # MIMIC has no accompanying text pickles, so it uses the code-only dataset.
+    if args.target_disease == 'mimic':
+        train_dataset = MyDataset3(data_path + '_training_sps.pickle', args.max_len,
+                                   args.max_num_codes, args.max_num_blks, pad_id, device)
+        dev_dataset = MyDataset3(data_path + '_validation_sps.pickle', args.max_len,
+                                 args.max_num_codes, args.max_num_blks, pad_id, device)
+        test_dataset = MyDataset3(data_path + '_testing_sps.pickle', args.max_len,
+                                  args.max_num_codes, args.max_num_blks, pad_id, device)
+    else:
+        train_dataset = MyDataset2(data_path + '_training_sps.pickle', data_path + '_training_txt.pickle',
+                                   args.max_len, args.max_num_codes, args.max_num_blks, pad_id, blk_pad_id, device)
+        dev_dataset = MyDataset2(data_path + '_validation_sps.pickle', data_path + '_validation_txt.pickle',
+                                 args.max_len, args.max_num_codes, args.max_num_blks, pad_id, blk_pad_id, device)
+        test_dataset = MyDataset2(data_path + '_testing_sps.pickle', data_path + '_testing_txt.pickle',
+                                  args.max_len, args.max_num_codes, args.max_num_blks, pad_id, blk_pad_id, device)
+    train_dataloader = DataLoader(train_dataset, args.batch_size, shuffle=True, collate_fn=collate_fn)
+    dev_dataloader = DataLoader(dev_dataset, args.batch_size, shuffle=False, collate_fn=collate_fn)
+    test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False, collate_fn=collate_fn)
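+
+    # Optional sanity-check sketch: peek at one batch to confirm the collate
+    # layout that the loops below assume. The eight-tuple ordering comes from
+    # collate_fn in models.og_dataset; exact shapes depend on max_len and
+    # max_num_codes.
+    #   labels, ehr, mask, txt, mask_txt, lengths, time_step, code_mask = next(iter(train_dataloader))
+    #   print(labels.shape, ehr.shape, lengths.shape)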
+
+    if args.encoder == 'hita':
+        model = HitaNet(pad_id, args.d_model, args.dropout, args.dropout_emb, args.num_layers, args.num_heads,
+                        args.max_len, pretrain=args.pretrain)
+    elif args.encoder == 'lstm':
+        model = LSTM_encoder(pad_id, args.d_model, args.dropout, args.dropout_emb, args.num_layers, args.num_heads,
+                             args.max_len, pretrain=args.pretrain)
+    elif args.encoder == 'gruself':
+        model = GRUSelf(pad_id, args.d_model, args.dropout, args.dropout_emb, args.num_layers, args.num_heads,
+                        args.max_len, pretrain=args.pretrain)
+    elif args.encoder == 'retain':
+        model = Retain(pad_id, args.d_model, args.dropout, args.dropout_emb, args.num_layers, args.num_heads,
+                       args.max_len, pretrain=args.pretrain)
+    elif args.encoder == 'adacare':
+        model = AdaCare(pad_id, hidden_dim=args.d_model, kernel_size=2, kernel_num=64, input_dim=args.d_model,
+                        output_dim=1, dropout=args.dropout, r_v=4, r_c=4, activation='sigmoid',
+                        pretrain=args.pretrain)
+    else:
+        raise ValueError('Invalid encoder')
+    model.to(device)
+
+    # Exempt biases and LayerNorm parameters from weight decay.
+    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+    grouped_parameters = [
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+         'weight_decay': args.weight_decay, 'lr': args.learning_rate},
+        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+         'weight_decay': 0.0, 'lr': args.learning_rate},
+    ]
+    optim = Adam(grouped_parameters)
+    loss_func = nn.CrossEntropyLoss(reduction='mean')
+
+    print('parameters:')
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
+        else:
+            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
+    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print('\ttotal:', num_params)
+    print()
+    print('-' * 71)
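+
+    # The --warmup_steps flag is parsed above but never consumed in this file.
+    # A minimal sketch of how it could drive linear warmup with PyTorch's
+    # LambdaLR (hypothetical wiring, not part of the current training loop):
+    #   from torch.optim.lr_scheduler import LambdaLR
+    #   scheduler = LambdaLR(optim, lambda step: min(1.0, (step + 1) / args.warmup_steps))
+    # and then call scheduler.step() after each optim.step() below.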
+
+    global_step, best_dev_epoch = 0, 0
+    best_dev_f1, final_test_pr_auc, total_loss = 0.0, 0.0, 0.0
+    best_kappa, best_f1 = 0.0, 0.0
+    for epoch_id in range(args.n_epochs):
+        print('epoch: {:5} '.format(epoch_id))
+        model.train()
+        start_time = time.time()
+        for i, data in enumerate(train_dataloader):
+            labels, ehr, mask, txt, mask_txt, lengths, time_step, code_mask = data
+            optim.zero_grad()
+            if args.encoder == 'adacare':
+                outputs = model(ehr, device)
+            else:
+                outputs = model(ehr, mask, lengths, time_step, code_mask)
+            loss = loss_func(outputs, labels)
+            loss.backward()
+            # Rescale the per-sample mean loss so partial final batches are
+            # weighted consistently in the running average.
+            total_loss += (loss.item() / labels.size(0)) * args.batch_size
+            if args.max_grad_norm > 0:
+                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+            optim.step()
+            if (global_step + 1) % args.log_interval == 0:
+                total_loss /= args.log_interval
+                ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
+                print('| step {:5} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, total_loss, ms_per_batch))
+                total_loss = 0.0
+                start_time = time.time()
+            global_step += 1
+
+        model.eval()
+        train_acc, tr_precision, tr_recall, tr_f1, tr_roc_auc, tr_pr_auc, tr_kappa = eval_metric(train_dataloader, model, args.encoder)
+        dev_acc, d_precision, d_recall, d_f1, d_roc_auc, d_pr_auc, d_kappa = eval_metric(dev_dataloader, model, args.encoder)
+        test_acc, t_precision, t_recall, t_f1, t_roc_auc, t_pr_auc, t_kappa = eval_metric(test_dataloader, model, args.encoder)
+        print('-' * 71)
+        print('| step {:5} | train_acc {:7.4f} | dev_acc {:7.4f} | test_acc {:7.4f} '.format(global_step, train_acc, dev_acc, test_acc))
+        print('| step {:5} | train_precision {:7.4f} | dev_precision {:7.4f} | test_precision {:7.4f} '.format(global_step, tr_precision, d_precision, t_precision))
+        print('| step {:5} | train_recall {:7.4f} | dev_recall {:7.4f} | test_recall {:7.4f} '.format(global_step, tr_recall, d_recall, t_recall))
+        print('| step {:5} | train_f1 {:7.4f} | dev_f1 {:7.4f} | test_f1 {:7.4f} '.format(global_step, tr_f1, d_f1, t_f1))
+        print('| step {:5} | train_auc {:7.4f} | dev_auc {:7.4f} | test_auc {:7.4f} '.format(global_step, tr_roc_auc, d_roc_auc, t_roc_auc))
+        print('| step {:5} | train_pr {:7.4f} | dev_pr {:7.4f} | test_pr {:7.4f} '.format(global_step, tr_pr_auc, d_pr_auc, t_pr_auc))
+        print('| step {:5} | train_kappa {:7.4f} | dev_kappa {:7.4f} | test_kappa {:7.4f} '.format(global_step, tr_kappa, d_kappa, t_kappa))
+        print('-' * 71)
+
+        # Model selection is driven by dev F1; the test PR AUC, F1 and kappa of
+        # the selected epoch are kept for the final report.
+        if d_f1 >= best_dev_f1:
+            best_dev_f1 = d_f1
+            final_test_pr_auc = t_pr_auc
+            best_dev_epoch = epoch_id
+            best_f1 = t_f1
+            best_kappa = t_kappa
+            torch.save([model, args], model_path)
+            with open(log_path, 'a') as fout:
+                fout.write('{},{},{},{},{},{}\n'.format(global_step, tr_pr_auc, d_pr_auc, t_pr_auc, t_f1, t_kappa))
+            print(f'model saved to {model_path}')
+        if epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
+            break
+
+    with open(path + results_name, 'w', encoding='gbk') as results_file:
+        csv_w = csv.writer(results_file)
+        csv_w.writerow([final_test_pr_auc, best_f1, best_kappa])
+    print()
+    print('training ends in {} steps'.format(global_step))
+    print('best dev f1: {:.4f} (at epoch {})'.format(best_dev_f1, best_dev_epoch))
+    print('final test pr_auc: {:.4f}'.format(final_test_pr_auc))
+    print(final_test_pr_auc, best_f1, best_kappa)
+    print()
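+
+
+# Sketch: inspecting the training log written above. Column names follow the
+# header written to log.csv in train(); pandas here is an assumption for the
+# example, not a dependency of the script itself.
+#   import pandas as pd
+#   log = pd.read_csv('./saved_models/log.csv')
+#   print(log.sort_values('dev_pr_auc', ascending=False).head())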
+
+
+def pred(args):
+    model_path = os.path.join(args.save_dir, 'model.pt')
+    model, old_args = torch.load(model_path)
+    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
+    model.to(device)
+    model.eval()
+    blk_emb = np.load(old_args.blk_emb_path)
+    blk_pad_id = len(blk_emb) - 1
+    # NOTE: train() pads with len(code2id) + 1 while this path uses
+    # len(code2id); keep the two conventions in sync if either changes.
+    if old_args.target_disease == 'Heart_failure':
+        code2id = pickle.load(open('./data/hf/hf_code2idx.pickle', 'rb'))
+        id2code = {int(v): k for k, v in code2id.items()}
+        code2topic = pickle.load(open('./data/hf/hf_code2topic.pickle', 'rb'))
+        pad_id = len(code2id)
+        data_path = './data/hf/hf'
+    elif old_args.target_disease == 'COPD':
+        code2id = pickle.load(open('./data/copd/copd_code2idx.pickle', 'rb'))
+        id2code = {int(v): k for k, v in code2id.items()}
+        code2topic = pickle.load(open('./data/copd/copd_code2topic.pickle', 'rb'))
+        pad_id = len(code2id)
+        data_path = './data/copd/copd'
+    elif old_args.target_disease == 'Kidney':
+        code2id = pickle.load(open('./data/kidney/kidney_code2idx.pickle', 'rb'))
+        id2code = {int(v): k for k, v in code2id.items()}
+        code2topic = pickle.load(open('./data/kidney/kidney_code2topic.pickle', 'rb'))
+        pad_id = len(code2id)
+        data_path = './data/kidney/kidney'
+    elif old_args.target_disease == 'Amnesia':
+        code2id = pickle.load(open('./data/amnesia/amnesia_code2idx.pickle', 'rb'))
+        id2code = {int(v): k for k, v in code2id.items()}
+        code2topic = pickle.load(open('./data/amnesia/amnesia_code2topic.pickle', 'rb'))
+        pad_id = len(code2id)
+        data_path = './data/amnesia/amnesia'
+    elif old_args.target_disease == 'Dementia':
+        code2id = pickle.load(open('./data/dementia/dementia_code2idx.pickle', 'rb'))
+        id2code = {int(v): k for k, v in code2id.items()}
+        code2topic = pickle.load(open('./data/dementia/dementia_code2topic.pickle', 'rb'))
+        pad_id = len(code2id)
+        data_path = './data/dementia/dementia'
+    else:
+        raise ValueError('Invalid disease')
+    dev_dataset = MyDataset(data_path + '_validation_sps.pickle', data_path + '_validation_txt.pickle',
+                            old_args.max_len, old_args.max_num_codes, old_args.max_num_blks, pad_id, blk_pad_id,
+                            device)
+    test_dataset = MyDataset(data_path + '_testing_sps.pickle', data_path + '_testing_txt.pickle', old_args.max_len,
+                             old_args.max_num_codes, old_args.max_num_blks, pad_id, blk_pad_id, device)
+    dev_dataloader = DataLoader(dev_dataset, args.batch_size, shuffle=False, collate_fn=collate_fn)
+    test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False, collate_fn=collate_fn)
+    # dev_acc, d_precision, d_recall, d_f1, d_roc_auc, d_pr_auc, d_kappa = eval_metric(dev_dataloader, model, old_args.encoder)
+    test_acc, t_precision, t_recall, t_f1, t_roc_auc, t_pr_auc, t_kappa = eval_metric(test_dataloader, model,
+                                                                                      old_args.encoder)
+    log_path = os.path.join(args.save_dir, 'result.csv')
+    with open(log_path, 'w') as fout:
+        fout.write('test_auc,test_f1,test_pre,test_recall,test_pr_auc,test_kappa\n')
+        fout.write('{},{},{},{},{},{}\n'.format(t_roc_auc, t_f1, t_precision, t_recall, t_pr_auc, t_kappa))
+    with torch.no_grad():
+        y_true = np.array([])
+        y_pred = np.array([])
+        y_score = np.array([])
+        for i, data in enumerate(test_dataloader):
+            labels, ehr, mask, txt, mask_txt, lengths, time_step, code_mask = data
+            if old_args.encoder == 'adacare':
+                logits = model(ehr, device)
+            else:
+                logits = model(ehr, mask, lengths, time_step, code_mask)
+            scores = torch.softmax(logits, dim=-1)
+            scores = scores.data.cpu().numpy()
+            labels = labels.data.cpu().numpy()
+            score = scores[:, 1]
+            pred = scores.argmax(1)
+            y_true = np.concatenate((y_true, labels))
+            y_pred = np.concatenate((y_pred, pred))
+            y_score = np.concatenate((y_score, score))
+    with open(os.path.join(args.save_dir, 'prediction.csv'), 'w') as fout2:
+        fout2.write('prediction,score,label\n')
+        for i in range(len(y_true)):
+            fout2.write('{},{},{}\n'.format(y_pred[i], y_score[i], y_true[i]))
+
+
+if __name__ == '__main__':
+    main()
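+
+# Usage sketch (paths assume the ./data and ./saved_models layout used above):
+#   python patient/health_risk_prediction.py --mode train --target_disease COPD --encoder lstm
+#   python patient/health_risk_prediction.py --mode pred --save_dir ./saved_models/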