a b/stay_admission/admission_downstream.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
5
import json 
6
import pickle
7
import pandas as pd
8
import numpy as np
9
import sparse
10
import torch
11
import model
12
from tqdm import tqdm
13
from torch import nn, optim
14
from torch.utils.data import DataLoader
15
import torch.nn.functional as F
16
import sklearn
17
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, cohen_kappa_score, f1_score
18
import matplotlib.pyplot as plt
19
from baseline import *
20
from operations import *
21
from train import *
22
from model import *
23
import sys
24
import csv
25
import os
26
27
28
29
30
31
32
# Command-line configuration. Values are read from sys.argv[2:], i.e.
# sys.argv[1] is skipped (NOTE(review): presumably consumed by a launcher --
# confirm against how this script is invoked). Exactly six values are
# expected, in this order: EPOCHS LR BATCH SEED TASK DEVICE; any other count
# makes the unpacking below raise ValueError.
print(sys.argv[2:])
EPOCHS, LR, BATCH, SEED, TASK, DEVICE = sys.argv[2:]
EPOCHS, LR, BATCH, SEED = int(EPOCHS), float(LR), int(BATCH), int(SEED)
TASK, DEVICE = str(TASK), str(DEVICE)

# Per-task directory that holds one CSV per (EPOCHS, LR, BATCH) setting.
PATH = "hmp_results/" + TASK + '/'

# Create the results directory up front: os.listdir() raises
# FileNotFoundError on a missing directory, and without this the first run
# for a new TASK could only fail (here, or later when the CSV is written).
os.makedirs(PATH, exist_ok=True)

# Skip the run when a result file for this hyper-parameter setting exists.
if str(EPOCHS) + '_' + str(LR) + '_' + str(BATCH) + '_' + '.csv' in os.listdir(PATH):
    print("conducted experiments")
else:
    device = torch.device(DEVICE)

    # NOTE: pickle can execute arbitrary code while loading; only use a
    # data file produced by a trusted preprocessing step.
    # The context manager closes the handle the original left open.
    with open("data/hmp_admission.p", 'rb') as data_file:
        data = pickle.load(data_file)

    # Index boundaries at 80% and 90% of the data: [0, 80%) is used for
    # training, [80%, 90%) for validation and [90%, end) for testing.
    split_mark = int(len(data) * 0.8), int(len(data) * 0.9)
49
    
50
    
51
52
    def collate_batch(batch_data):
        """Collate a list of samples into a list of batched tensors.

        Each sample is indexed positionally: elements 0-3 are converted to
        float32 tensors, elements 4-6 are stacked with ``torch.stack`` (so
        they must already be tensors), and the last element is used as the
        label. Every tensor is moved to the enclosing-scope ``device``.

        Returns:
            ``[icd, drug, X, S, input_ids, attention_mask, token_type_ids,
            label]`` in that order.
        """
        # NOTE(review): per-field shapes are not visible here; elements 2
        # and 3 are assumed to be sequences of equally shaped arrays since
        # they go through np.stack -- confirm against the dataset builder.

        def as_float(values):
            # Build a tensor, cast to float32, then move to the device.
            return torch.tensor(values).to(torch.float32).to(device)

        def stacked(position):
            # Stack the per-sample tensors found at ``position``.
            return torch.stack([sample[position] for sample in batch_data]).to(device)

        icd = as_float([sample[0] for sample in batch_data])
        drug = as_float([sample[1] for sample in batch_data])
        X = as_float(np.array([np.stack(sample[2], axis=0) for sample in batch_data]))
        S = as_float(np.array([np.stack(sample[3], axis=0) for sample in batch_data]))
        input_ids = stacked(4)
        attention_mask = stacked(5)
        token_type_ids = stacked(6)
        label = as_float(np.array([sample[-1] for sample in batch_data]))

        return [icd, drug, X, S, input_ids, attention_mask, token_type_ids, label]
65
    def collate_batch_ts(batch_data):
        """Collate samples into ``[X, label]`` float32 tensors.

        Element 0 of every sample becomes ``X`` and element 4 becomes
        ``label``; both are moved to the enclosing-scope ``device``.
        """
        # NOTE(review): this reads the label from index 4, unlike
        # collate_batch which reads the last element -- presumably it
        # targets a differently laid out dataset; confirm before reuse.
        features = np.array([sample[0] for sample in batch_data])
        targets = np.array([sample[4] for sample in batch_data])

        X = torch.tensor(features).to(torch.float32).to(device)
        label = torch.tensor(targets).to(torch.float32).to(device)
        return [X, label]
72
        
73
74
    
75
    
76
    
77
    
78
    
79
    
80
    
81
    
82
    # Multimodal encoder evaluation: build the data loaders, initialise the
    # classifier from the pretrained checkpoint, fine-tune it, then write the
    # test metrics to this setting's CSV file.

    # Slices follow split_mark: [:80%] train, [80%:90%] valid, [90%:] test.
    test = DataLoader(data[split_mark[1]:], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)
    train = DataLoader(data[:split_mark[0]], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)
    valid = DataLoader(data[split_mark[0]:split_mark[1]], batch_size=BATCH, shuffle=True, collate_fn=collate_batch)

    MedHMP = HADM_CLS(7686+1, 1701+1, 256, 0.8).to(device)
    MedHMP.train()

    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded when running on another (e.g. CPU-only); without it torch.load
    # tries to restore tensors onto the device they were saved from.
    # NOTE(review): the checkpoint is a whole pickled module (state_dict() is
    # called on the loaded object); recent PyTorch releases may require
    # weights_only=False for this -- confirm for the installed version.
    enc = torch.load("model/admission_pretrained.p", map_location=device).state_dict()
    model_dict = MedHMP.state_dict()

    # Keep only pretrained entries whose name, once everything up to and
    # including the last 'enc.' is dropped, matches a parameter of MedHMP.
    state_dict = {
        k.split('enc.')[-1]: v
        for k, v in enc.items()
        if k.split('enc.')[-1] in model_dict
    }

    # strict=False: parameters absent from the checkpoint keep their fresh
    # initialisation. (The former `MedHMP.state_dict().update(...)` call was
    # dropped: it only modified a temporary dict and never touched the model.)
    MedHMP.load_state_dict(state_dict, strict=False)
    MedHMP, _ = adm_trainer(MedHMP, train, valid, test, EPOCHS, LR, BATCH, SEED, device, encoder='HMP', patience=5)

    # Compute the metrics BEFORE creating the result file: the existence of
    # that file marks the experiment as done, so a failed evaluation must not
    # leave an empty CSV behind. The last returned metric is not recorded.
    metrics = list(eval_metric_admission(test, MedHMP, device, 'HMP'))[:-1]

    # newline='' is what the csv module expects for files it writes to; the
    # context manager guarantees the file is flushed and closed.
    with open(PATH + str(EPOCHS) + '_' + str(LR) + '_' + str(BATCH) + '_' + '.csv', 'w', encoding='gbk', newline='') as file:
        csv_w = csv.writer(file)
        csv_w.writerow(["MedHMP"] + metrics)

    # Drops every module-level name (model, data, loaders, imports).
    # NOTE(review): presumably meant to release memory when the script is run
    # inside a long-lived interpreter -- nothing may follow this line.
    globals().clear()
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123