--- a
+++ b/callbacks/eval.py
@@ -0,0 +1,69 @@
+'''
+@Author: your name
+@Date: 2020-01-06 14:04:27
+@LastEditTime : 2020-01-06 17:28:15
+@LastEditors  : Please set LastEditors
+@Description: Keras callback reporting validation AUC-ROC, AUPR, accuracy and F1 after each epoch
+@FilePath: /KGCN_Keras-master/callbacks/eval.py
+'''
+# -*- coding: utf-8 -*-
+
+from collections import defaultdict
+
+import numpy as np
+from keras.callbacks import Callback
+from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, average_precision_score,precision_recall_curve
+import sklearn.metrics as m
+from utils import write_log
+
+# Added metrics: ACC, AUPR, AUC-ROC, F1 (+ std)
+
+class KGCNMetric(Callback):
+    """Keras callback that evaluates the model on the validation split after each epoch.
+
+    ``threshold`` binarizes scores for accuracy/F1; ``logs`` is passed to ``write_log(log_path, ...)``.
+    """
+    def __init__(self, x_train, y_train, x_valid, y_valid, aggregator_type, dataset, K_fold,
+                 threshold=0.5, log_path='log/train_history.txt'):
+        super(KGCNMetric, self).__init__()
+        self.x_train, self.y_train = x_train, y_train  # stored only; not read by this class
+        self.x_valid = x_valid
+        self.y_valid = y_valid
+        self.aggregator_type = aggregator_type
+        self.dataset = dataset
+        self.k = K_fold
+        self.threshold = threshold  # score cut-off used for accuracy/F1
+        self.log_path = log_path  # file handed to write_log in append mode
+
+    def on_epoch_end(self, epoch, logs=None):
+        """Score ``x_valid`` against ``y_valid``, add the metrics to ``logs`` and log them."""
+        logs = {} if logs is None else logs  # default is None; item assignment on it would raise
+        y_score = self.model.predict(self.x_valid).flatten()
+        y_true = self.y_valid.flatten()
+        auc = roc_auc_score(y_true=y_true, y_score=y_score)  # area under the ROC curve
+        precision, recall, _ = precision_recall_curve(y_true, y_score)  # positional: probas_pred kwarg renamed in sklearn 1.5
+        aupr = m.auc(recall, precision)  # trapezoidal PR area (see sklearn average_precision_score note)
+        y_label = (y_score >= self.threshold).astype(int)  # scores -> hard 0/1 predictions
+        acc = accuracy_score(y_true=y_true, y_pred=y_label)
+        f1 = f1_score(y_true=y_true, y_pred=y_label)
+        logs['val_aupr'] = float(aupr)
+        logs['val_auc'] = float(auc)
+        logs['val_acc'] = float(acc)
+        logs['val_f1'] = float(f1)
+        logs['dataset'] = self.dataset  # run metadata written alongside the metrics
+        logs['aggregator_type'] = self.aggregator_type
+        logs['kfold'] = self.k
+        logs['epoch_count'] = epoch + 1
+        print(f'Logging Info - epoch: {epoch+1}, val_auc: {auc}, val_aupr: {aupr}, val_acc: {acc}, val_f1: {f1}')
+        write_log(self.log_path, logs, mode='a')
+
+    @staticmethod
+    def get_user_record(data, is_train):
+        """Map interaction[0] -> set of interaction[1]; unless is_train, keep only rows with label == 1."""
+        user_history_dict = defaultdict(set)
+        for interaction in data:
+            user, item, label = interaction[:3]
+            if is_train or label == 1:
+                user_history_dict[user].add(item)
+        return user_history_dict
+