--- /dev/null
+++ b/stay_admission/admission_pretraining.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import pickle
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from info_nce import InfoNCE
+
+import model
+
+MODEL_PATH = "model/"
+
+# Fall back to CPU when no GPU is available.
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+with open("data/admission_pretrain_data.p", "rb") as f:
+    data = pickle.load(f)
+
+
+def collate_batch(batch_data):
+    """Stack one batch of admission tuples into float32 device tensors:
+    (ICD multi-hot, drug multi-hot, stay features X, stay statics S, text embedding)."""
+    icd = torch.tensor(np.array([i[0] for i in batch_data])).to(torch.float32).to(device)
+    drug = torch.tensor(np.array([i[1] for i in batch_data])).to(torch.float32).to(device)
+    X = torch.tensor(np.array([np.stack(i[2], axis=0) for i in batch_data])).to(torch.float32).to(device)
+    S = torch.tensor(np.array([np.stack(i[3], axis=0) for i in batch_data])).to(torch.float32).to(device)
+    text = torch.stack([i[4] for i in batch_data]).to(torch.float32).to(device)
+    return [icd, drug, X, S, text]
+
+
+data_loader = DataLoader(data, batch_size=4096, shuffle=True, collate_fn=collate_batch)
+
+EPOCHS = 300
+
+HADM_AE = model.HADM_AE(vocab_size1=7686 + 1, vocab_size2=1701 + 1, d_model=256,
+                        dropout=0.1, dropout_emb=0.1, length=48).to(device)
+criterion_mcm = torch.nn.MSELoss()  # masked code modelling (reconstruction)
+criterion_cl = InfoNCE()            # contrastive alignment
+optimizer = torch.optim.Adam(HADM_AE.parameters(), lr=2e-5, weight_decay=1e-8)
+
+# Warm-start the stay-level encoder from the stay-pretraining checkpoint.
+enc = torch.load(MODEL_PATH + "stay_pretrained.p").state_dict()
+model_dict = HADM_AE.state_dict()
+state_dict = {k.replace("encoder.", "enc.ICU_Encoder."): v
+              for k, v in enc.items()
+              if k.replace("encoder.", "enc.ICU_Encoder.") in model_dict}
+print("Transferred encoder keys:", list(state_dict.keys()))
+HADM_AE.load_state_dict(state_dict, strict=False)
+
+step = 0
+
+for epoch in tqdm(range(EPOCHS)):
+
+    loss = 0
+
+    for batch_data in data_loader:
+
+        icd, drug, X, S, text = batch_data
+
+        # Randomly hide ~15% of the ICD codes; the special last slot stays visible.
+        mask_icd = torch.rand(icd.shape, device=device) > 0.15
+        mask_icd[:, -1] = True
+        masked_icd = (~mask_icd) * icd   # reconstruction targets (hidden codes)
+        nums_masked_icd = masked_icd.sum(dim=1, keepdim=True)
+        unmasked_icd = icd * mask_icd    # model input (visible codes)
+
+        # Same masking scheme for the drug codes.
+        mask_drug = torch.rand(drug.shape, device=device) > 0.15
+        mask_drug[:, -1] = True
+        masked_drug = (~mask_drug) * drug
+        nums_masked_drug = masked_drug.sum(dim=1, keepdim=True)
+        unmasked_drug = drug * mask_drug
+
+        optimizer.zero_grad()
+
+        doc_emb, doc_rep, x1, x1_rep, x2, x2_rep, mcm_x1_rep, mcm_x2_rep = HADM_AE(
+            icd, drug, nums_masked_icd, nums_masked_drug, X, S, text,
+            unmasked_icd, unmasked_drug)
+
+        # Masked code modelling: reconstruct the hidden ICD/drug codes.
+        mcm_loss = (criterion_mcm(mcm_x1_rep, masked_icd) +
+                    criterion_mcm(mcm_x2_rep, masked_drug)) / 2
+        # Contrastive alignment between each view and its representation.
+        cl_loss = (criterion_cl(doc_emb, doc_rep) +
+                   criterion_cl(x1, x1_rep) +
+                   criterion_cl(x2, x2_rep)) / 3
+
+        train_loss = mcm_loss + 0.1 * cl_loss
+        train_loss.backward()
+        optimizer.step()
+
+        loss += train_loss.item()
+        step += 1
+        print("step", step, "| mcm:", mcm_loss.item(),
+              "| cl:", cl_loss.item(), "| running epoch loss:", loss)
+
+    torch.save(HADM_AE, MODEL_PATH + "admission_pretrained.p")
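+
+# NOTE (assumed interface, inferred from the call sites above): model.HADM_AE is
+# expected to expose a forward of the form
+#     forward(icd, drug, nums_masked_icd, nums_masked_drug, X, S, text,
+#             unmasked_icd, unmasked_drug)
+# returning (doc_emb, doc_rep, x1, x1_rep, x2, x2_rep, mcm_x1_rep, mcm_x2_rep),
+# and data/admission_pretrain_data.p is expected to hold tuples of
+# (ICD multi-hot, drug multi-hot, stay features, stay statics, text embedding).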