--- /dev/null
+++ b/stay_admission/admission_pretraining.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+import pickle
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from info_nce import InfoNCE
+
+import model  # local package defining HADM_AE; also needed to un-pickle saved model objects
+
+
+MODEL_PATH = "model/"
+
+# Fall back to CPU when no CUDA device is available.
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+with open("data/admission_pretrain_data.p", "rb") as f:
+    data = pickle.load(f)
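+# Each pickled sample is assumed to be a 5-tuple (see collate_batch below):
+#   icd:  multi-hot (or count) ICD-code vector for the admission
+#   drug: multi-hot (or count) drug-code vector
+#   X:    per-stay feature sequences, stacked along axis 0
+#   S:    per-stay static/summary sequences
+#   text: precomputed clinical-note embedding (already a torch tensor)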
+
+
+def collate_batch(batch_data):
+    """Assemble a batch and move every field to the target device."""
+    icd = torch.tensor([i[0] for i in batch_data]).to(torch.float32).to(device)
+    drug = torch.tensor([i[1] for i in batch_data]).to(torch.float32).to(device)
+    X = torch.tensor(np.array([np.stack(i[2], axis=0) for i in batch_data])).to(torch.float32).to(device)
+    S = torch.tensor(np.array([np.stack(i[3], axis=0) for i in batch_data])).to(torch.float32).to(device)
+    text = torch.stack([i[4] for i in batch_data]).to(torch.float32).to(device)
+    return [icd, drug, X, S, text]
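+# Batches are assembled directly on the GPU inside collate_batch, so keep the
+# DataLoader's default num_workers=0 (CUDA tensors and worker processes don't mix).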
+data_loader = DataLoader(data, batch_size=4096, shuffle=True, collate_fn=collate_batch)
+
+EPOCHS = 300
+
+HADM_AE = model.HADM_AE(vocab_size1=7686 + 1, vocab_size2=1701 + 1, d_model=256,
+                        dropout=0.1, dropout_emb=0.1, length=48).to(device)
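+# The +1 on each vocabulary size presumably reserves the extra trailing slot
+# that the masking code below always leaves visible (mask[:, -1]).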
+criterion_mcp = torch.nn.MSELoss()
+criterion_cl = InfoNCE()
+optimizer = torch.optim.Adam(HADM_AE.parameters(), lr=2e-5, weight_decay=1e-8)
+# Warm-start the ICU encoder from the stay-level pretrained model.
+enc = torch.load(MODEL_PATH + "stay_pretrained.p", map_location=device).state_dict()
+model_dict = HADM_AE.state_dict()
+
+# Rename the stay-level encoder keys to this model's submodule path, keeping
+# only the keys that actually exist in HADM_AE.
+state_dict = {k.replace("encoder.", "enc.ICU_Encoder."): v
+              for k, v in enc.items()
+              if k.replace("encoder.", "enc.ICU_Encoder.") in model_dict}
+print(state_dict.keys())
+
+HADM_AE.load_state_dict(state_dict, strict=False)
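+# strict=False: only the renamed encoder weights are overwritten; every other
+# parameter of HADM_AE keeps its fresh initialization.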
+
+
+for epoch in tqdm(range(EPOCHS)):
+
+    loss = 0
+
+    for batch_data in data_loader:
+        icd, drug, X, S, text = batch_data
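+
+        # Masked code modeling (MCM): hide ~15% of each admission's ICD and
+        # drug codes and train the model to reconstruct them. In the masks
+        # below, True marks positions that stay visible.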
+        mask_icd = torch.rand(icd.shape, device=device) > 0.15
+        mask_icd[:, -1] = True  # keep the (presumably special) last position visible
+        masked_icd = ~mask_icd * icd             # reconstruction target: hidden codes
+        nums_masked_icd = masked_icd.sum(dim=1).unsqueeze(1)
+        unmasked_icd = icd * mask_icd            # model input: visible codes
+
+        mask_drug = torch.rand(drug.shape, device=device) > 0.15
+        mask_drug[:, -1] = True
+        masked_drug = ~mask_drug * drug
+        nums_masked_drug = masked_drug.sum(dim=1).unsqueeze(1)
+        unmasked_drug = drug * mask_drug
+
+        optimizer.zero_grad()
+
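+        # Forward pass: paired (embedding, representation) views for the note
+        # text (doc), ICD stream (x1) and drug stream (x2), plus the MCM
+        # reconstructions of the masked code vectors.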
+        (doc_emb, doc_rep, x1, x1_rep, x2, x2_rep,
+         mcm_x1_rep, mcm_x2_rep) = HADM_AE(icd, drug, nums_masked_icd, nums_masked_drug,
+                                           X, S, text, unmasked_icd, unmasked_drug)
+        # MCM reconstruction loss (MSE against the hidden codes).
+        mcm_loss = (criterion_mcp(mcm_x1_rep, masked_icd) + criterion_mcp(mcm_x2_rep, masked_drug)) / 2
+        # InfoNCE contrastive loss aligning each embedding with its paired view.
+        cl_loss = (criterion_cl(doc_emb, doc_rep) + criterion_cl(x1, x1_rep) + criterion_cl(x2, x2_rep)) / 3
+
+        train_loss = mcm_loss + 0.1 * cl_loss
+        train_loss.backward()
+
+        optimizer.step()
+
+        loss += train_loss.item()
+
+    print("one step mcm:", mcm_loss)
+    print("one step cl:", cl_loss)
+    print("Loss = ", loss)
+
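+    # Save the full module object each epoch (mirrors how stay_pretrained.p is
+    # loaded above); reloading requires the `model` package to be importable.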
+    torch.save(HADM_AE, MODEL_PATH + "admission_pretrained.p")