#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Downstream fine-tuning of the pretrained MedHMP admission encoder.

Usage:
    python admission_downstream.py <ignored> EPOCHS LR BATCH SEED TASK DEVICE

Loads the pickled admission dataset, splits it 80/10/10 into
train/valid/test, initialises an HADM_CLS model with the pretrained
encoder weights, fine-tunes it with ``adm_trainer`` and writes the
evaluation metrics to ``hmp_results/<TASK>/<EPOCHS>_<LR>_<BATCH>_.csv``.
The run is skipped if that result file already exists.
"""

import json
import pickle
import pandas as pd
import numpy as np
import sparse
import torch
import model
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import sklearn
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, cohen_kappa_score, f1_score
import matplotlib.pyplot as plt
from baseline import *
from operations import *
from train import *
from model import *
import sys
import csv
import os


print(sys.argv[2:])
# Hyperparameters start at argv[2]; argv[1] is intentionally ignored.
EPOCHS, LR, BATCH, SEED, TASK, DEVICE = sys.argv[2:]
EPOCHS, LR, BATCH, SEED, TASK, DEVICE = (
    int(EPOCHS), float(LR), int(BATCH), int(SEED), str(TASK), str(DEVICE))

PATH = "hmp_results/" + TASK + '/'
# Single source of truth for the result filename (the trailing '_' is kept
# for backward compatibility with previously written result files).
RESULT_FILE = str(EPOCHS) + '_' + str(LR) + '_' + str(BATCH) + '_' + '.csv'

if RESULT_FILE in os.listdir(PATH):
    # This hyperparameter combination was already run; do nothing.
    print("conducted experiments")
else:
    device = torch.device(DEVICE)

    # NOTE(review): pickle.load executes arbitrary code on malicious input —
    # only load data files from a trusted source.
    with open("data/hmp_admission.p", 'rb') as data_file:
        data = pickle.load(data_file)

    # Boundaries for an 80% train / 10% valid / 10% test split.
    split_mark = int(len(data) * 0.8), int(len(data) * 0.9)

    def collate_batch(batch_data):
        """Collate multimodal admission samples into float32 device tensors.

        Each sample is assumed to be laid out as
        (icd, drug, X, S, input_ids, attention_mask, token_type_ids, ..., label)
        — TODO confirm against the dataset builder.
        """
        icd = torch.tensor([i[0] for i in batch_data]).to(torch.float32).to(device)
        drug = torch.tensor([i[1] for i in batch_data]).to(torch.float32).to(device)
        X = torch.tensor(np.array([np.stack(i[2], axis=0) for i in batch_data])).to(torch.float32).to(device)
        S = torch.tensor(np.array([np.stack(i[3], axis=0) for i in batch_data])).to(torch.float32).to(device)
        input_ids = torch.stack([i[4] for i in batch_data]).to(device)
        attention_mask = torch.stack([i[5] for i in batch_data]).to(device)
        token_type_ids = torch.stack([i[6] for i in batch_data]).to(device)
        label = torch.tensor(np.array([i[-1] for i in batch_data])).to(torch.float32).to(device)
        return [icd, drug, X, S, input_ids, attention_mask, token_type_ids, label]

    def collate_batch_ts(batch_data):
        """Collate time-series-only samples (features at index 0, label at index 4)."""
        X = torch.tensor(np.array([i[0] for i in batch_data])).to(torch.float32).to(device)
        label = torch.tensor(np.array([i[4] for i in batch_data])).to(torch.float32).to(device)
        return [X, label]

    # Multimodal encoder evaluation.
    test = DataLoader(data[split_mark[1]:], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)
    train = DataLoader(data[:split_mark[0]], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)
    valid = DataLoader(data[split_mark[0]:split_mark[1]], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)

    # Vocabulary sizes (+1, presumably for a padding index), hidden size and
    # dropout match the pretraining setup — TODO confirm against its config.
    MedHMP = HADM_CLS(7686 + 1, 1701 + 1, 256, 0.8).to(device)
    MedHMP.train()

    # Transfer pretrained encoder weights: strip the 'enc.' prefix and keep
    # only keys that exist in the downstream model; strict=False lets the
    # newly initialised classification-head parameters stay untouched.
    # NOTE(review): torch.load of a full model object unpickles arbitrary
    # code — only load checkpoints from a trusted source.
    enc = torch.load("model/admission_pretrained.p").state_dict()
    model_dict = MedHMP.state_dict()
    state_dict = {k.split('enc.')[-1]: v for k, v in enc.items()
                  if k.split('enc.')[-1] in model_dict.keys()}
    # BUG FIX: the original additionally called
    # MedHMP.state_dict().update(state_dict), which mutates a throwaway copy
    # and has no effect on the model; load_state_dict below does the real work.
    MedHMP.load_state_dict(state_dict, strict=False)

    MedHMP, _ = adm_trainer(MedHMP, train, valid, test, EPOCHS, LR, BATCH, SEED,
                            device, encoder='HMP', patience=5)

    # Drop the last element of the metric tuple (not persisted by the
    # original either) and write one CSV row. Metrics are computed before the
    # file is opened so a failed evaluation cannot leave behind an empty
    # result file that would make the run look "conducted".
    metrics = list(eval_metric_admission(test, MedHMP, device, 'HMP'))[:-1]
    with open(PATH + RESULT_FILE, 'w', encoding='gbk') as result_file:
        csv.writer(result_file).writerow(["MedHMP"] + metrics)
    # The original ended with globals().clear(); removed — it wipes every
    # module-level name (including imports) and serves no purpose here.