a b/stay_admission/admission_pretraining.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
5
from sklearn import preprocessing
6
import numpy as np
7
import model 
8
import torch
9
from torch import nn, optim
10
from torch.utils.data import DataLoader
11
import torch.nn.functional as F
12
import sparse
13
from tqdm import tqdm
14
import pandas as pd
15
import pickle
16
import ast
17
import copy
18
from pytorch_metric_learning import losses
19
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
20
from info_nce import InfoNCE
21
22
23
MODEL_PATH = "model/"
24
25
26
27
28
device = torch.device("cuda:0")
29
30
data = pickle.load(open("data/admission_pretrain_data.p", 'rb'))
31
32
33
def collate_batch(batch_data):
34
35
36
37
        icd = torch.tensor([i[0] for i in batch_data]).to(torch.float32).to(device)
38
        drug = torch.tensor([i[1] for i in batch_data]).to(torch.float32).to(device)
39
        X = torch.tensor(np.array([np.stack(i[2], axis = 0) for i in batch_data])).to(torch.float32).to(device)
40
        S = torch.tensor(np.array([np.stack(i[3], axis = 0) for i in batch_data])).to(torch.float32).to(device)
41
        text = torch.stack([i[4] for i in batch_data]).to(torch.float32).to(device) 
42
        return [icd, drug, X, S, text]
43
data__ = DataLoader(data, batch_size = 4096, shuffle = True, collate_fn=collate_batch)
44
45
EPOCHS = 300
46
47
HADM_AE = model.HADM_AE(vocab_size1 = 7686+1, vocab_size2 = 1701+1,  d_model = 256, dropout=0.1, dropout_emb=0.1, length=48).to(device)
48
criterion_mcp = torch.nn.MSELoss()
49
criterion_cl = InfoNCE()
50
optimizer = torch.optim.Adam(HADM_AE.parameters(),
51
                              lr = 2e-5,
52
                              weight_decay = 1e-8)
53
enc = torch.load(MODEL_PATH + "/stay_pretrained.p").state_dict()
54
model_dict = HADM_AE.state_dict()
55
56
state_dict = {k.replace("encoder.", "enc.ICU_Encoder."):v for k,v in enc.items() if k.replace("encoder.", "enc.ICU_Encoder.") in model_dict.keys()}
57
print(state_dict.keys())
58
59
HADM_AE.state_dict().update(state_dict)
60
61
HADM_AE.load_state_dict(state_dict, strict=False)
62
63
64
65
step = 0
66
67
for epoch in tqdm(range(EPOCHS)):
68
    
69
    loss = 0
70
    
71
    
72
    for batch_idx, batch_data in enumerate(data__):
73
        
74
        print(step)
75
76
        icd = batch_data[0]
77
        drug = batch_data[1]
78
        X = batch_data[2]
79
        S = batch_data[3]
80
        text = batch_data[4]
81
        mask_icd = (torch.rand(size=(icd.shape)) > 0.15).to(device)
82
        masked_icd = ~mask_icd*icd
83
        nums_masked_icd = (masked_icd).sum(dim=1).unsqueeze(1)
84
        unmasked_icd = icd*mask_icd
85
        mask_icd[:, -1] = 1
86
        
87
        
88
        mask_drug = (torch.rand(size=(drug.shape)) > 0.15).to(device)
89
        masked_drug = ~mask_drug*drug
90
        nums_masked_drug = (~mask_drug*drug).sum(dim=1).unsqueeze(1)
91
        unmasked_drug = drug*mask_drug
92
        masked_drug[:,-1] = 1
93
        optimizer.zero_grad()
94
            
95
96
97
        doc_emb, doc_rep, x1, x1_rep, x2, x2_rep, mcm_x1_rep, mcm_x2_rep = HADM_AE(icd,drug,nums_masked_icd, nums_masked_drug,  X,S, text, unmasked_icd,unmasked_drug)
98
        mcm_loss = (criterion_mcp(mcm_x1_rep, masked_icd) + criterion_mcp(mcm_x2_rep, masked_drug))/2
99
        cl_loss =  (criterion_cl(doc_emb, doc_rep) + criterion_cl(x1, x1_rep) + criterion_cl(x2, x2_rep))/3
100
 
101
        train_loss = mcm_loss + 0.1*cl_loss
102
        train_loss.backward()
103
104
        optimizer.step()
105
106
        loss += train_loss.item()
107
        
108
        step+=1
109
    print("one step mcm:", mcm_loss)
110
    print("one step cl:", cl_loss)
111
    print("Loss = ", loss)
112
113
    torch.save(HADM_AE,MODEL_PATH + 'admission_pretrained.p')