Diff of /aggmap/aggmodel/cbks.py [000000] .. [9e8054]

Switch to unified view

a b/aggmap/aggmodel/cbks.py
1
from sklearn.metrics import roc_auc_score, precision_recall_curve
2
from sklearn.metrics import auc as calculate_auc
3
from sklearn.metrics import mean_squared_error
4
from sklearn.metrics import accuracy_score
5
import tensorflow as tf
6
import os
7
import numpy as np
8
9
10
from scipy.stats.stats import pearsonr
11
def r2_score(x, y):
    """Return the squared Pearson correlation coefficient of *x* and *y*.

    Note that this is ``pearsonr(x, y)[0] ** 2``, not the coefficient of
    determination computed by ``sklearn.metrics.r2_score``.

    Parameters
    ----------
    x, y : sequences of numbers with the same length.

    Returns
    -------
    float
        The squared correlation, or ``nan`` when fewer than two samples are
        supplied (the correlation is undefined there and ``pearsonr`` would
        raise), so that callers averaging with ``np.nanmean`` can skip it.
    """
    if min(len(x), len(y)) < 2:
        return np.nan
    pcc, _ = pearsonr(x, y)
    return pcc ** 2
14
15
def prc_auc_score(y_true, y_score):
    """Return the area under the precision-recall curve (PRC-AUC).

    ``y_true`` and ``y_score`` are forwarded unchanged to
    ``precision_recall_curve``; the thresholds it returns are not needed.
    """
    precision_values, recall_values, _ = precision_recall_curve(y_true, y_score)
    # Recall is used as the x-axis and precision as the y-axis.
    return calculate_auc(recall_values, precision_values)
19
20
'''
21
Early-stopping callbacks for regression and classification tasks.
22
'''
23
######## Regression ###############################
24
25
26
class Reg_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback):
    """Early stopping plus per-epoch RMSE / r2 tracking for regression models.

    At the end of every epoch the model predicts on the training and the
    validation inputs.  RMSE and r2 (as computed by the module-level
    ``r2_score``) are evaluated column by column, ignoring entries whose
    target equals ``MASK``, averaged with ``np.nanmean`` and appended to
    ``self.history`` together with the ``loss`` / ``val_loss`` found in the
    Keras logs.  When the monitored quantity has not improved for
    ``patience`` consecutive epochs, training is stopped and the best
    weights seen so far are restored.

    Parameters
    ----------
    train_data : tuple
        ``(x, y)`` used to compute the training metrics; ``y`` is indexed
        as a 2-D array (``y[:, i]``).
    valid_data : tuple
        ``(x_val, y_val)`` used to compute the validation metrics.
    MASK : scalar
        Target value marking entries to exclude from the metrics.
    patience : int
        Number of consecutive non-improving epochs tolerated before stopping.
    criteria : {'val_loss', 'val_r2'}
        Monitored quantity: ``val_loss`` is minimised (read from the Keras
        logs), ``val_r2`` is maximised (computed from ``valid_data``).
    verbose : int
        Forwarded to ``model.predict``; when truthy a summary line is
        printed after each epoch.
    """

    def __init__(self, train_data, valid_data, MASK = -1, patience=5, criteria = 'val_loss', verbose = 0):
        super().__init__()

        assert criteria in ['val_loss', 'val_r2'], 'not support %s ! only %s' % (criteria, ['val_loss', 'val_r2'])
        self.x, self.y = train_data
        self.x_val, self.y_val = valid_data

        self.history = {'loss': [],
                        'val_loss': [],

                        'rmse': [],
                        'val_rmse': [],

                        'r2': [],
                        'val_r2': [],

                        'epoch': []}
        self.MASK = MASK
        self.patience = patience
        # Weights captured at the best epoch seen so far (None until then).
        self.best_weights = None
        self.criteria = criteria
        self.best_epoch = 0
        self.verbose = verbose

    @staticmethod
    def _fmt(value):
        """Format a metric with four decimals; a missing value is shown as 'nan'."""
        if value is None:
            return 'nan'
        return '{0:.4f}'.format(value)

    def rmse(self, y_true, y_pred):
        """Return a list with the masked RMSE of every output column.

        A column without any unmasked target yields ``nan`` so that the
        ``np.nanmean`` aggregation simply skips it.
        """
        rmses = []
        for i in range(y_pred.shape[1]):
            y_true_one_class = y_true[:, i]
            mask = y_true_one_class != self.MASK
            if not mask.any():
                rmses.append(np.nan)
                continue
            mse = mean_squared_error(y_true_one_class[mask], y_pred[:, i][mask])
            rmses.append(np.sqrt(mse))
        return rmses

    def r2(self, y_true, y_pred):
        """Return a list with the masked r2 of every output column.

        A column with fewer than two unmasked targets yields ``nan``: the
        correlation behind ``r2_score`` is undefined for it.
        """
        r2s = []
        for i in range(y_pred.shape[1]):
            y_true_one_class = y_true[:, i]
            mask = y_true_one_class != self.MASK
            if mask.sum() < 2:
                r2s.append(np.nan)
                continue
            r2s.append(r2_score(y_true_one_class[mask], y_pred[:, i][mask]))
        return r2s

    def on_train_begin(self, logs=None):
        """Reset the early-stopping state at the start of training."""
        # Number of consecutive epochs without improvement.
        self.wait = 0
        # Epoch at which training was stopped early (0 if it never was).
        self.stopped_epoch = 0
        # A loss is minimised, r2 is maximised; start from the worst value.
        # ``np.inf`` is used because the ``np.Inf`` alias no longer exists
        # in NumPy 2.x.
        if self.criteria == 'val_loss':
            self.best = np.inf
        else:
            self.best = -np.inf

    def _update_early_stopping(self, epoch, current, improved):
        """Record an improvement, or count a stale epoch and stop if needed."""
        if improved:
            self.best = current
            self.wait = 0
            # Keep the weights of the best epoch so they can be restored.
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            # best_weights is still None if no epoch ever improved
            # (e.g. the monitored value was NaN throughout).
            if self.best_weights is not None:
                print('\nRestoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)

    def on_epoch_end(self, epoch, logs=None):
        """Compute the metrics, update the history and apply early stopping.

        Raises
        ------
        ValueError
            If ``criteria == 'val_loss'`` but the logs hold no ``val_loss``.
        """
        logs = logs or {}

        y_pred = self.model.predict(self.x, verbose=self.verbose)
        rmse_mean = np.nanmean(self.rmse(self.y, y_pred))
        r2_mean = np.nanmean(self.r2(self.y, y_pred))

        y_pred_val = self.model.predict(self.x_val, verbose=self.verbose)
        rmse_mean_val = np.nanmean(self.rmse(self.y_val, y_pred_val))
        r2_mean_val = np.nanmean(self.r2(self.y_val, y_pred_val))

        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))

        self.history['rmse'].append(rmse_mean)
        self.history['val_rmse'].append(rmse_mean_val)

        self.history['r2'].append(r2_mean)
        self.history['val_r2'].append(r2_mean_val)

        self.history['epoch'].append(epoch)

        if self.verbose:
            eph = str(epoch+1).zfill(4)
            print('\repoch: %s, loss: %s - val_loss: %s; rmse: %s - rmse_val: %s;  r2: %s - r2_val: %s' % (
                      eph,
                      self._fmt(logs.get('loss')), self._fmt(logs.get('val_loss')),
                      self._fmt(rmse_mean), self._fmt(rmse_mean_val),
                      self._fmt(r2_mean), self._fmt(r2_mean_val)),
                  end=100*' '+'\n')

        if self.criteria == 'val_loss':
            current = logs.get(self.criteria)
            if current is None:
                raise ValueError("criteria='val_loss' needs 'val_loss' in the epoch logs; "
                                 "pass validation data to model.fit()")
            # Lower is better.
            self._update_early_stopping(epoch, current, current <= self.best)
        else:
            # Higher is better.
            current = r2_mean_val
            self._update_early_stopping(epoch, current, current >= self.best)

    def on_train_end(self, logs=None):
        """Leave the model with the best weights and report early stopping."""
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
        if self.stopped_epoch > 0:
            print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1))

    def evaluate(self, testX, testY):
        """Predict on ``testX`` and return ``(rmse_list, r2_list)`` per column."""
        y_pred = self.model.predict(testX, verbose=self.verbose)
        rmse_list = self.rmse(testY, y_pred)
        r2_list = self.r2(testY, y_pred)
        return rmse_list, r2_list
188
189
190
191
    
192
    
193
######## classification ###############################
194
195
class CLA_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback):
    """Early stopping plus per-epoch metric tracking for classification models.

    At the end of every epoch the model predicts on the training and the
    validation inputs.  The selected metric is evaluated column by column,
    ignoring entries whose label equals ``MASK``, averaged with
    ``np.nanmean`` and appended to ``self.history`` together with the
    ``loss`` / ``val_loss`` found in the Keras logs.  When the monitored
    quantity has not improved for ``patience`` consecutive epochs, training
    is stopped and the best weights seen so far are restored.

    Parameters
    ----------
    train_data : tuple
        ``(x, y)`` used to compute the training metric; ``y`` is indexed
        as a 2-D array (``y[:, i]``).
    valid_data : tuple
        ``(x_val, y_val)`` used to compute the validation metric.
    MASK : scalar
        Label value marking entries to exclude from the metric.
    patience : int
        Number of consecutive non-improving epochs tolerated before stopping.
    criteria : {'val_loss', 'val_metric'}
        Monitored quantity: ``val_loss`` is minimised (read from the Keras
        logs), ``val_metric`` maximises the validation metric.
    metric : {'ROC', 'PRC', 'ACC'}
        Metric to track; stored in the history under ``roc_auc``,
        ``prc_auc`` or ``accuracy`` (and the ``val_``-prefixed key).
    last_avf : object or None
        When ``None`` a sigmoid is applied to the model outputs before
        scoring; otherwise the outputs are scored as they are.
        NOTE(review): presumably the model's last activation function -
        confirm against the caller.
    verbose : int
        Forwarded to ``model.predict``; when truthy a summary line is
        printed after each epoch.
    """

    def __init__(self, train_data, valid_data, MASK = -1, patience=5, criteria = 'val_loss', metric = 'ROC', last_avf = None, verbose = 0):
        super().__init__()

        sp = ['val_loss', 'val_metric']
        assert criteria in sp, 'not support %s ! only %s' % (criteria, sp)
        ms = ['ROC', 'PRC', 'ACC']
        assert metric in ms, 'not support %s ! only %s' % (metric, ms)
        ms_dict = {'ROC':'roc_auc', 'PRC':'prc_auc', 'ACC':'accuracy'}

        # Internal metric name, also used as the history key.
        self.metric = ms_dict[metric]
        self.val_metric = 'val_%s' % self.metric

        self.x, self.y = train_data
        self.x_val, self.y_val = valid_data
        self.last_avf = last_avf

        self.history = {'loss': [],
                        'val_loss': [],
                        self.metric: [],
                        self.val_metric: [],
                        'epoch': []}
        self.MASK = MASK
        self.patience = patience
        # Weights captured at the best epoch seen so far (None until then).
        self.best_weights = None
        self.criteria = criteria

        self.best_epoch = 0
        self.verbose = verbose

    @staticmethod
    def _fmt(value):
        """Format a metric with four decimals; a missing value is shown as 'nan'."""
        if value is None:
            return 'nan'
        return '{0:.4f}'.format(value)

    def sigmoid(self, x):
        """Return the element-wise logistic function ``1 / (1 + exp(-x))``."""
        s = 1/(1+np.exp(-x))
        return s

    def roc_auc(self, y_true, y_pred):
        """Return a list with the selected metric for every output column.

        Despite its name the method computes whichever metric was chosen at
        construction time (ROC-AUC, PRC-AUC or accuracy).  A column for
        which the metric cannot be computed yields ``nan``.
        """
        if self.last_avf is None:
            # Outputs are treated as raw scores and squashed with a sigmoid.
            y_prob = self.sigmoid(y_pred)
        else:
            y_prob = y_pred

        scores = []
        for i in range(y_prob.shape[1]):
            y_pred_one_class = y_prob[:, i]
            y_true_one_class = y_true[:, i]
            mask = y_true_one_class != self.MASK
            try:
                if self.metric == 'roc_auc':
                    score = roc_auc_score(y_true_one_class[mask], y_pred_one_class[mask], average='weighted') #ROC_AUC
                elif self.metric == 'prc_auc':
                    score = prc_auc_score(y_true_one_class[mask], y_pred_one_class[mask]) #PRC_AUC
                elif self.metric == 'accuracy':
                    score = accuracy_score(y_true_one_class[mask], np.round(y_pred_one_class[mask])) #ACC
                else:
                    score = np.nan
            except Exception:
                # Deliberate best effort: a column whose metric fails is
                # reported as NaN and skipped by np.nanmean.  Unlike a bare
                # ``except`` this lets KeyboardInterrupt / SystemExit through.
                score = np.nan
            scores.append(score)
        return scores

    def on_train_begin(self, logs=None):
        """Reset the early-stopping state at the start of training."""
        # Number of consecutive epochs without improvement.
        self.wait = 0
        # Epoch at which training was stopped early (0 if it never was).
        self.stopped_epoch = 0
        # A loss is minimised, a metric is maximised; start from the worst
        # value.  ``np.inf`` is used because the ``np.Inf`` alias no longer
        # exists in NumPy 2.x.
        if self.criteria == 'val_loss':
            self.best = np.inf
        else:
            self.best = -np.inf

    def _update_early_stopping(self, epoch, current, improved):
        """Record an improvement, or count a stale epoch and stop if needed."""
        if improved:
            self.best = current
            self.wait = 0
            # Keep the weights of the best epoch so they can be restored.
            self.best_weights = self.model.get_weights()
            self.best_epoch = epoch
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            # best_weights is still None if no epoch ever improved
            # (e.g. the monitored value was NaN throughout).
            if self.best_weights is not None:
                print('\nRestoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)

    def on_epoch_end(self, epoch, logs=None):
        """Compute the metric, update the history and apply early stopping.

        Raises
        ------
        ValueError
            If ``criteria == 'val_loss'`` but the logs hold no ``val_loss``.
        """
        logs = logs or {}

        y_pred = self.model.predict(self.x, verbose = self.verbose)
        metric_mean = np.nanmean(self.roc_auc(self.y, y_pred))

        y_pred_val = self.model.predict(self.x_val, verbose = self.verbose)
        metric_val_mean = np.nanmean(self.roc_auc(self.y_val, y_pred_val))

        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))
        self.history[self.metric].append(metric_mean)
        self.history[self.val_metric].append(metric_val_mean)
        self.history['epoch'].append(epoch)

        if self.verbose:
            eph = str(epoch+1).zfill(4)
            print('\repoch: %s, loss: %s - val_loss: %s; %s: %s - %s: %s' % (
                      eph,
                      self._fmt(logs.get('loss')),
                      self._fmt(logs.get('val_loss')),
                      self.metric,
                      self._fmt(metric_mean),
                      self.val_metric,
                      self._fmt(metric_val_mean)), end=100*' '+'\n')

        if self.criteria == 'val_loss':
            current = logs.get(self.criteria)
            if current is None:
                raise ValueError("criteria='val_loss' needs 'val_loss' in the epoch logs; "
                                 "pass validation data to model.fit()")
            # Lower is better.
            self._update_early_stopping(epoch, current, current <= self.best)
        else:
            # Higher is better.
            current = metric_val_mean
            self._update_early_stopping(epoch, current, current >= self.best)

    def on_train_end(self, logs=None):
        """Leave the model with the best weights and report early stopping."""
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
        if self.stopped_epoch > 0:
            print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1))

    def evaluate(self, testX, testY):
        """Predict on ``testX`` and return the per-column metric list."""
        y_pred = self.model.predict(testX, verbose = self.verbose)
        roc_list = self.roc_auc(testY, y_pred)
        return roc_list
352
353