--- a +++ b/aggmap/aggmodel/cbks.py @@ -0,0 +1,353 @@ +from sklearn.metrics import roc_auc_score, precision_recall_curve +from sklearn.metrics import auc as calculate_auc +from sklearn.metrics import mean_squared_error +from sklearn.metrics import accuracy_score +import tensorflow as tf +import os +import numpy as np + + +from scipy.stats.stats import pearsonr +def r2_score(x,y): + pcc, _ = pearsonr(x,y) + return pcc**2 + +def prc_auc_score(y_true, y_score): + precision, recall, threshold = precision_recall_curve(y_true, y_score) #PRC_AUC + auc = calculate_auc(recall, precision) + return auc + +''' +for early-stopping techniques in regression and classification task +''' +######## Regression ############################### + + +class Reg_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback): + + def __init__(self, train_data, valid_data, MASK = -1, patience=5, criteria = 'val_loss', verbose = 0): + super(Reg_EarlyStoppingAndPerformance, self).__init__() + + assert criteria in ['val_loss', 'val_r2'], 'not support %s ! only %s' % (criteria, ['val_loss', 'val_r2']) + self.x, self.y = train_data + self.x_val, self.y_val = valid_data + + self.history = {'loss':[], + 'val_loss':[], + + 'rmse':[], + 'val_rmse':[], + + 'r2':[], + 'val_r2':[], + + 'epoch':[]} + self.MASK = MASK + self.patience = patience + # best_weights to store the weights at which the minimum loss occurs. + self.best_weights = None + self.criteria = criteria + self.best_epoch = 0 + self.verbose = verbose + + def rmse(self, y_true, y_pred): + + N_classes = y_pred.shape[1] + rmses = [] + for i in range(N_classes): + y_pred_one_class = y_pred[:,i] + y_true_one_class = y_true[:, i] + mask = ~(y_true_one_class == self.MASK) + mse = mean_squared_error(y_true_one_class[mask], y_pred_one_class[mask]) + rmse = np.sqrt(mse) + rmses.append(rmse) + return rmses + + + def r2(self, y_true, y_pred): + N_classes = y_pred.shape[1] + r2s = [] + for i in range(N_classes): + y_pred_one_class = y_pred[:,i] + y_true_one_class = y_true[:, i] + mask = ~(y_true_one_class == self.MASK) + r2 = r2_score(y_true_one_class[mask], y_pred_one_class[mask]) + r2s.append(r2) + return r2s + + + def on_train_begin(self, logs=None): + # The number of epoch it has waited when loss is no longer minimum. + self.wait = 0 + # The epoch the training stops at. + self.stopped_epoch = 0 + # Initialize the best as infinity. + if self.criteria == 'val_loss': + self.best = np.Inf + else: + self.best = -np.Inf + + + + + + def on_epoch_end(self, epoch, logs={}): + + y_pred = self.model.predict(self.x, verbose=self.verbose) + rmse_list = self.rmse(self.y, y_pred) + rmse_mean = np.nanmean(rmse_list) + + r2_list = self.r2(self.y, y_pred) + r2_mean = np.nanmean(r2_list) + + + y_pred_val = self.model.predict(self.x_val, verbose=self.verbose) + rmse_list_val = self.rmse(self.y_val, y_pred_val) + rmse_mean_val = np.nanmean(rmse_list_val) + + r2_list_val = self.r2(self.y_val, y_pred_val) + r2_mean_val = np.nanmean(r2_list_val) + + self.history['loss'].append(logs.get('loss')) + self.history['val_loss'].append(logs.get('val_loss')) + + self.history['rmse'].append(rmse_mean) + self.history['val_rmse'].append(rmse_mean_val) + + self.history['r2'].append(r2_mean) + self.history['val_r2'].append(r2_mean_val) + + self.history['epoch'].append(epoch) + + + # logs is a dictionary + eph = str(epoch+1).zfill(4) + loss = '{0:.4f}'.format((logs.get('loss'))) + val_loss = '{0:.4f}'.format((logs.get('val_loss'))) + rmse = '{0:.4f}'.format(rmse_mean) + rmse_val = '{0:.4f}'.format(rmse_mean_val) + r2_mean = '{0:.4f}'.format(r2_mean) + r2_mean_val = '{0:.4f}'.format(r2_mean_val) + + if self.verbose: + print('\repoch: %s, loss: %s - val_loss: %s; rmse: %s - rmse_val: %s; r2: %s - r2_val: %s' % (eph, + loss, val_loss, + rmse,rmse_val, + r2_mean,r2_mean_val), + end=100*' '+'\n') + + + if self.criteria == 'val_loss': + current = logs.get(self.criteria) + if current <= self.best: + self.best = current + self.wait = 0 + # Record the best weights if current results is better (less). + self.best_weights = self.model.get_weights() + self.best_epoch = epoch + + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + print('\nRestoring model weights from the end of the best epoch.') + self.model.set_weights(self.best_weights) + + else: + current = np.nanmean(r2_list_val) + + if current >= self.best: + self.best = current + self.wait = 0 + # Record the best weights if current results is better (less). + self.best_weights = self.model.get_weights() + self.best_epoch = epoch + + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + print('\nRestoring model weights from the end of the best epoch.') + self.model.set_weights(self.best_weights) + + def on_train_end(self, logs=None): + self.model.set_weights(self.best_weights) + if self.stopped_epoch > 0: + print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1)) + + + + def evaluate(self, testX, testY): + """evalulate, return rmse and r2""" + y_pred = self.model.predict(testX, verbose=self.verbose) + rmse_list = self.rmse(testY, y_pred) + r2_list = self.r2(testY, y_pred) + return rmse_list, r2_list + + + + + +######## classification ############################### + +class CLA_EarlyStoppingAndPerformance(tf.keras.callbacks.Callback): + + def __init__(self, train_data, valid_data, MASK = -1, patience=5, criteria = 'val_loss', metric = 'ROC', last_avf = None, verbose = 0): + super(CLA_EarlyStoppingAndPerformance, self).__init__() + + sp = ['val_loss', 'val_metric'] + assert criteria in sp, 'not support %s ! only %s' % (criteria, sp) + ms = ['ROC', 'PRC', 'ACC'] + assert metric in ms, 'not support %s ! only %s' % (metric, ms) + ms_dict = {'ROC':'roc_auc', 'PRC':'prc_auc', 'ACC':'accuracy'} + + metric = ms_dict[metric] + val_metric = 'val_%s' % metric + self.metric = metric + self.val_metric = val_metric + + self.x, self.y = train_data + self.x_val, self.y_val = valid_data + self.last_avf = last_avf + + + + self.history = {'loss':[], + 'val_loss':[], + self.metric:[], + self.val_metric:[], + 'epoch':[]} + self.MASK = MASK + self.patience = patience + # best_weights to store the weights at which the minimum loss occurs. + self.best_weights = None + self.criteria = criteria + + + self.best_epoch = 0 + self.verbose = verbose + + def sigmoid(self, x): + s = 1/(1+np.exp(-x)) + return s + + + def roc_auc(self, y_true, y_pred): + if self.last_avf == None: + y_pred_logits = self.sigmoid(y_pred) + else: + y_pred_logits = y_pred + + N_classes = y_pred_logits.shape[1] + + aucs = [] + for i in range(N_classes): + y_pred_one_class = y_pred_logits[:,i] + y_true_one_class = y_true[:, i] + mask = ~(y_true_one_class == self.MASK) + try: + if self.metric == 'roc_auc': + auc = roc_auc_score(y_true_one_class[mask], y_pred_one_class[mask], average='weighted') #ROC_AUC + elif self.metric == 'prc_auc': + auc = prc_auc_score(y_true_one_class[mask], y_pred_one_class[mask]) #PRC_AUC + elif self.metric == 'accuracy': + auc = accuracy_score(y_true_one_class[mask], np.round(y_pred_one_class[mask])) #ACC + except: + auc = np.nan + aucs.append(auc) + return aucs + + + + def on_train_begin(self, logs=None): + # The number of epoch it has waited when loss is no longer minimum. + self.wait = 0 + # The epoch the training stops at. + self.stopped_epoch = 0 + # Initialize the best as infinity. + if self.criteria == 'val_loss': + self.best = np.Inf + else: + self.best = -np.Inf + + + def on_epoch_end(self, epoch, logs={}): + + y_pred = self.model.predict(self.x, verbose = self.verbose) + roc_list = self.roc_auc(self.y, y_pred) + roc_mean = np.nanmean(roc_list) + + y_pred_val = self.model.predict(self.x_val, verbose = self.verbose) + roc_val_list = self.roc_auc(self.y_val, y_pred_val) + roc_val_mean = np.nanmean(roc_val_list) + + self.history['loss'].append(logs.get('loss')) + self.history['val_loss'].append(logs.get('val_loss')) + self.history[self.metric].append(roc_mean) + self.history[self.val_metric].append(roc_val_mean) + self.history['epoch'].append(epoch) + + eph = str(epoch+1).zfill(4) + loss = '{0:.4f}'.format((logs.get('loss'))) + val_loss = '{0:.4f}'.format((logs.get('val_loss'))) + auc = '{0:.4f}'.format(roc_mean) + auc_val = '{0:.4f}'.format(roc_val_mean) + + if self.verbose: + print('\repoch: %s, loss: %s - val_loss: %s; %s: %s - %s: %s' % (eph, + loss, + val_loss, + self.metric, + auc, + self.val_metric, + auc_val), end=100*' '+'\n') + + if self.criteria == 'val_loss': + current = logs.get(self.criteria) + if current <= self.best: + self.best = current + self.wait = 0 + # Record the best weights if current results is better (less). + self.best_weights = self.model.get_weights() + self.best_epoch = epoch + + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + print('\nRestoring model weights from the end of the best epoch.') + self.model.set_weights(self.best_weights) + + else: + current = roc_val_mean + if current >= self.best: + self.best = current + self.wait = 0 + # Record the best weights if current results is better (less). + self.best_weights = self.model.get_weights() + self.best_epoch = epoch + + else: + self.wait += 1 + if self.wait >= self.patience: + self.stopped_epoch = epoch + self.model.stop_training = True + print('\nRestoring model weights from the end of the best epoch.') + self.model.set_weights(self.best_weights) + + def on_train_end(self, logs=None): + self.model.set_weights(self.best_weights) + if self.stopped_epoch > 0: + print('\nEpoch %05d: early stopping' % (self.stopped_epoch + 1)) + + + def evaluate(self, testX, testY): + + y_pred = self.model.predict(testX, verbose = self.verbose) + roc_list = self.roc_auc(testY, y_pred) + return roc_list + +