Diff of /stay_admission/train.py [000000] .. [0218cb]

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import pickle
import pandas as pd
import numpy as np
import sparse
import torch
import model
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import sklearn
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, cohen_kappa_score,
                             precision_recall_curve, auc)
import matplotlib.pyplot as plt
from baseline import *
from operations import *
from focal_loss.focal_loss import FocalLoss
import os
from datetime import datetime


def eval_metric_stay(eval_set, model, device, encoder='normal'):

    model.eval()
    criterion = torch.nn.BCELoss()
    with torch.no_grad():
        y_true = np.array([])
        y_pred = np.array([])
        y_score = np.array([])
        for i, batch_data in enumerate(eval_set):
            X = batch_data[0]
            if encoder == 'HMP':
                S = batch_data[1]
            elif encoder == 'BERT':
                S = batch_data[1]
                input_ids = batch_data[2]
                attention_mask = batch_data[3]
                token_type_ids = batch_data[4]
            labels = batch_data[-1]
            if encoder == 'normal':
                # NOTE: this path expects batches shaped (X, X2, S, ..., labels).
                X2 = batch_data[1]
                S = batch_data[2]
                outputs = model(X, X2, S).squeeze().to(device)
            elif encoder == 'HMP':
                outputs = model(X, S).squeeze().to(device)
            elif encoder == 'BERT':
                # Argument order matches the training call in icu_trainer.
                outputs = model(X, S, input_ids, attention_mask, token_type_ids).squeeze().to(device)
            score = outputs.data.cpu().numpy()
            labels = labels.data.cpu().numpy()

            pred = np.where(score >= 0.5, 1.0, 0.0)

            if labels.shape[0] != 1:
                y_true = np.concatenate((y_true, labels))
                y_pred = np.concatenate((y_pred, pred))
                y_score = np.concatenate((y_score, score))
            else:
                # A batch of size 1 squeezes to 0-d arrays, which
                # np.concatenate cannot handle, so append via lists.
                y_true = np.array(list(y_true) + list(labels))
                y_pred = np.array(list(y_pred) + [pred])
                y_score = np.array(list(y_score) + [score])
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        roc_auc = roc_auc_score(y_true, y_score)
        lr_precision, lr_recall, _ = precision_recall_curve(y_true, y_score)
        pr_auc = auc(lr_recall, lr_precision)
        kappa = cohen_kappa_score(y_true, y_pred)
        # BCELoss expects (input, target): scores first, labels second.
        loss = criterion(torch.from_numpy(y_score), torch.from_numpy(y_true))

    return f1, roc_auc, pr_auc, kappa, loss

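# Usage sketch (illustration only, not part of the original script):
# `stay_model`, `valid_loader`, and `device` are hypothetical names assumed
# to be built elsewhere. For encoder='HMP' each batch must look like
# (X, S, ..., labels); for encoder='BERT', like (X, S, input_ids,
# attention_mask, token_type_ids, ..., labels), with labels always last.
#
#     f1, roc_auc, pr_auc, kappa, val_loss = eval_metric_stay(
#         valid_loader, stay_model, device, encoder='HMP')
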
def eval_metric_admission(eval_set, model, device, encoder='normal'):

    model.eval()
    criterion = torch.nn.BCELoss()
    with torch.no_grad():
        y_true = np.array([])
        y_pred = np.array([])
        y_score = np.array([])
        for i, batch_data in enumerate(eval_set):
            icd = batch_data[0]
            drug = batch_data[1]
            X = batch_data[2]
            S = batch_data[3]
            input_ids = batch_data[4]
            attention_mask = batch_data[5]
            token_type_ids = batch_data[6]
            labels = batch_data[-1]
            outputs = model(icd, drug, X, S, input_ids, attention_mask, token_type_ids).squeeze().to(device)
            # squeeze() yields a 0-d array for a batch of size 1;
            # atleast_1d keeps the concatenations below valid.
            score = np.atleast_1d(outputs.data.cpu().numpy())
            labels = labels.data.cpu().numpy()

            pred = np.where(score >= 0.5, 1.0, 0.0)
            y_true = np.concatenate((y_true, labels))
            y_pred = np.concatenate((y_pred, pred))
            y_score = np.concatenate((y_score, score))
        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred)
        recall = recall_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred)
        roc_auc = roc_auc_score(y_true, y_score)
        lr_precision, lr_recall, _ = precision_recall_curve(y_true, y_score)
        pr_auc = auc(lr_recall, lr_precision)
        kappa = cohen_kappa_score(y_true, y_pred)
        # BCELoss expects (input, target): scores first, labels second.
        loss = criterion(torch.from_numpy(y_score), torch.from_numpy(y_true))

    return f1, roc_auc, pr_auc, kappa, loss

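# Usage sketch (illustration only): admission-level batches are 8-tuples of
# (icd, drug, X, S, input_ids, attention_mask, token_type_ids, labels),
# matching the indexing above. `adm_valid_loader` and `adm_model` are
# hypothetical names.
#
#     f1, roc_auc, pr_auc, kappa, val_loss = eval_metric_admission(
#         adm_valid_loader, adm_model, device)
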
def icu_trainer(model, train, valid, test, epochs, learn_rate, batch_size, seed, device, encoder='normal', patience=3):

    torch.manual_seed(seed)

    model.train()
    aupr_list = []

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                momentum=0.9,
                                lr=learn_rate,
                                weight_decay=1e-2)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    f1, roc_auc, pr_auc, kappa, valid_loss = eval_metric_stay(valid, model, device, encoder)
    best_dev = valid_loss
    best_epoch = 0
    model.train()
    # Save an initial checkpoint so there is always a best model to restore.
    dt = datetime.now()
    torch.save(model, "saved_model/hmp_model" + str(dt) + ".p")
    best_name = "saved_model/hmp_model" + str(dt) + ".p"

    for epoch in tqdm(range(epochs)):

        loss = 0
        model.train()
        for batch_idx, batch_data in enumerate(train):
            X = batch_data[0]
            if encoder == 'HMP':
                S = batch_data[1]
            elif encoder == 'BERT':
                S = batch_data[1]
                input_ids = batch_data[2]
                attention_mask = batch_data[3]
                token_type_ids = batch_data[4]
            label = batch_data[-1]
            optimizer.zero_grad()
            if encoder == 'normal':
                outputs = model(X).squeeze().to(device)
            elif encoder == 'HMP':
                outputs = model(X, S).squeeze().to(device)
            elif encoder == 'BERT':
                outputs = model(X, S, input_ids, attention_mask, token_type_ids).squeeze().to(device)
            train_loss = criterion(outputs, label)
            train_loss.backward()
            optimizer.step()
            loss += train_loss.item()
        scheduler.step()
        print("Training Loss = ", loss)

        model.eval()

        f1, roc_auc, pr_auc, kappa, valid_loss = eval_metric_stay(valid, model, device, encoder)
        print("Dev = ", valid_loss)

        dt = datetime.now()
        if best_dev > valid_loss:
            best_dev = valid_loss
            best_epoch = epoch
            torch.save(model, "saved_model/icu_model" + str(dt) + ".p")
            # Drop the previous checkpoint so only the best one is kept.
            os.remove(best_name)
            best_name = "saved_model/icu_model" + str(dt) + ".p"
        if epoch - best_epoch == patience:
            # Early stopping: no improvement for `patience` epochs.
            break
    model = torch.load(best_name)
    os.remove(best_name)
    return model, aupr_list

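# Usage sketch (illustration only; the loaders and model are assumed to be
# built elsewhere, and every name here is hypothetical):
#
#     best_model, _ = icu_trainer(stay_model, train_loader, valid_loader,
#                                 test_loader, epochs=20, learn_rate=0.1,
#                                 batch_size=32, seed=42, device=device,
#                                 encoder='HMP', patience=3)
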
def adm_trainer(model, train, valid, test, epochs, learn_rate, batch_size, seed, device, encoder='normal', patience=3):

    torch.manual_seed(seed)

    model.train()
    aupr_list = []

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                momentum=0.9,
                                lr=learn_rate,
                                weight_decay=1e-2)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    f1, roc_auc, pr_auc, kappa, valid_loss = eval_metric_admission(valid, model, device, encoder)
    best_dev = valid_loss
    best_epoch = 0
    model.train()
    # Save an initial checkpoint so the torch.load below always has a file
    # to read, even if validation loss never improves (mirrors icu_trainer).
    torch.save(model, "saved_model/adm_model.p")

    for epoch in tqdm(range(epochs)):

        loss = 0

        for batch_idx, batch_data in enumerate(train):
            icd = batch_data[0]
            drug = batch_data[1]
            X = batch_data[2]
            S = batch_data[3]
            input_ids = batch_data[4]
            attention_mask = batch_data[5]
            token_type_ids = batch_data[6]
            label = batch_data[-1]
            optimizer.zero_grad()
            outputs = model(icd, drug, X, S, input_ids, attention_mask, token_type_ids).squeeze().to(device)
            train_loss = criterion(outputs, label)
            train_loss.backward()
            optimizer.step()
            loss += train_loss.item()
        scheduler.step()
        print("Training Loss = ", loss)

        model.eval()

        f1, roc_auc, pr_auc, kappa, valid_loss = eval_metric_admission(valid, model, device, encoder)
        print("Dev = ", valid_loss)
        if best_dev > valid_loss:
            best_dev = valid_loss
            best_epoch = epoch
            torch.save(model, "saved_model/adm_model.p")
        if epoch - best_epoch == patience:
            # Early stopping: no improvement for `patience` epochs.
            break
        model.train()
    model = torch.load("saved_model/adm_model.p")
    return model, aupr_list

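# End-to-end wiring sketch (illustration only): `build_adm_loaders` and
# `AdmModel` are hypothetical stand-ins for the real data pipeline and the
# classes defined in model.py.
#
#     if __name__ == "__main__":
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         os.makedirs("saved_model", exist_ok=True)  # trainers checkpoint here
#         train_dl, valid_dl, test_dl = build_adm_loaders(batch_size=32)
#         adm_model = AdmModel().to(device)
#         adm_model, _ = adm_trainer(adm_model, train_dl, valid_dl, test_dl,
#                                    epochs=20, learn_rate=0.1, batch_size=32,
#                                    seed=42, device=device)
#         print(eval_metric_admission(test_dl, adm_model, device))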