Switch to side-by-side view

--- /dev/null
+++ b/stay_admission/stay_downstream.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import json 
+import pickle
+import pandas as pd
+import numpy as np
+import sparse
+import torch
+import model
+from tqdm import tqdm
+from torch import nn, optim
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import sklearn
+from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, cohen_kappa_score, f1_score
+import matplotlib.pyplot as plt
+from baseline import *
+from operations import *
+from train import *
+from model import *
+import sys
+import csv
+import os
+
+
print(sys.argv[2:])


# CLI contract: script.py <ignored> EPOCHS LR BATCH TASK DEVICE
# (sys.argv[1] is skipped; argv[2:] must hold exactly five values or the
# unpacking below raises ValueError.)
EPOCHS, LR, BATCH, TASK, DEVICE = sys.argv[2:]
EPOCHS, LR, BATCH, TASK, DEVICE = int(EPOCHS), float(LR), int(BATCH), str(TASK), str(DEVICE)

# One result directory per downstream task.
PATH = "hmp_icu" + TASK + '/'

# The result file name encodes the hyper-parameter combination.
RESULT_NAME = str(EPOCHS) + '_' + str(LR) + '_' + str(BATCH) + '_' + '.csv'

# Fix: os.listdir raises FileNotFoundError when PATH does not exist yet
# (first run for a new task), and the result file is later written into
# PATH as well — create the directory up front.
os.makedirs(PATH, exist_ok=True)

if RESULT_NAME in os.listdir(PATH):
    print("conducted experiments")
+else:
+
+
+    device = torch.device(DEVICE)
+
+    data = pickle.load(open(TASK + "_pred.p", 'rb'))
+    
+
+    split_mark = int(len(data)*0.8), int(len(data)*0.9)
+    
+    
+
+    def collate_batch(batch_data):
+
+
+
+            X = torch.tensor(np.array([i[0] for i in batch_data])).to(torch.float32).to(device)
+            S = torch.tensor(np.array([i[1] for i in batch_data])).to(torch.float32).to(device)
+            label = torch.tensor(np.array([i[2] for i in batch_data])).to(torch.float32).to(device) 
+            return [X, S,  label]
+    def collate_batch_ts(batch_data):
+
+
+
+            X = torch.tensor(np.array([i[0] for i in batch_data])).to(torch.float32).to(device)
+            label = torch.tensor(np.array([i[4] for i in batch_data])).to(torch.float32).to(device) 
+            return [X, label]
+        
+        
+
+    
+    
+    
+    
+    
+    # multimodal encoder evaluation
+    
+    test = DataLoader(data[split_mark[1]:], batch_size = BATCH, shuffle = True, collate_fn=collate_batch)
+    train = DataLoader(data[:split_mark[0]], batch_size = BATCH, shuffle = True, collate_fn=collate_batch)
+    valid = DataLoader(data[split_mark[0]:split_mark[1]], batch_size = BATCH, shuffle = True, collate_fn=collate_batch)
+    
+    
+    
+    MedHMP = HMP(48, 1318, 256, 0.2).to(device)
+    MedHMP.train()
+    enc = torch.load("model/admission_pretrained.p").state_dict()
+    model_dict = MedHMP.state_dict()
+    state_dict = {k.split('enc.')[-1]:v for k,v in enc.items() if k.split('enc.')[-1] in model_dict.keys()}
+    MedHMP.state_dict().update(state_dict)
+    MedHMP.load_state_dict(state_dict, strict = False)
+    MedHMP, _ = hmp_trainer(MedHMP, train, valid, test,EPOCHS, LR, BATCH,  device , encoder = 'HMP', patience = 5)
+    
+    
+    file = open(PATH +str(EPOCHS) + '_' + str(LR) + '_' + str(BATCH)+ '_' +'.csv','w',encoding  = 'gbk')
+    csv_w = csv.writer(file)
+    
+    
+
+
+        
+    metrics = list(eval_metric(test, MedHMP,  device, 'MedHMP'))[:-1]
+    csv_w.writerow(["MedHMP"] + metrics)
+    
+
+    
+    
+  
+    file.close()
+    
+    globals().clear()