Diff of /callbacks/eval.py [000000] .. [c0da92]

Switch to unified view

a b/callbacks/eval.py
1
'''
2
@Author: your name
3
@Date: 2020-01-06 14:04:27
4
@LastEditTime : 2020-01-06 17:28:15
5
@LastEditors  : Please set LastEditors
6
@Description: In User Settings Edit
7
@FilePath: /KGCN_Keras-master/callbacks/eval.py
8
'''
9
# -*- coding: utf-8 -*-
10
11
from collections import defaultdict
12
13
import numpy as np
14
from keras.callbacks import Callback
15
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, average_precision_score,precision_recall_curve
16
import sklearn.metrics as m
17
from utils import write_log
18
19
#添加指标:ACC, AUPR, AUC-ROC, F1 +std
20
21
class KGCNMetric(Callback):
22
    def __init__(self, x_train, y_train, x_valid, y_valid,aggregator_type,dataset,K_fold):
23
        self.x_train = x_train
24
        self.y_train = y_train
25
        self.x_valid = x_valid
26
        self.y_valid = y_valid
27
        self.aggregator_type=aggregator_type
28
        self.dataset=dataset
29
        self.k=K_fold
30
        self.threshold=0.5
31
        # self.user_list, self.train_record, self.valid_record, \
32
        #     self.item_set, self.k_list = self.topk_settings()
33
34
        super(KGCNMetric, self).__init__()
35
36
    def on_epoch_end(self, epoch, logs=None):
37
        y_pred = self.model.predict(self.x_valid).flatten()
38
        y_true = self.y_valid.flatten()
39
        auc = roc_auc_score(y_true=y_true, y_score=y_pred)# roc曲线的auc
40
        precision, recall, _thresholds = precision_recall_curve(y_true=y_true, probas_pred=y_pred)
41
        aupr=m.auc(recall,precision)
42
        y_pred = [1 if prob >= self.threshold else 0 for prob in y_pred]
43
        acc = accuracy_score(y_true=y_true, y_pred=y_pred)
44
        f1 = f1_score(y_true=y_true, y_pred=y_pred)
45
        
46
        print(type(aupr))
47
        logs['val_aupr']=float(aupr)
48
        logs['val_auc'] = float(auc)
49
        logs['val_acc'] = float(acc)
50
        logs['val_f1'] = float(f1)
51
        
52
        logs['dataset']=self.dataset
53
        logs['aggregator_type']=self.aggregator_type
54
        logs['kfold']=self.k
55
        logs['epoch_count']=epoch+1
56
        print(f'Logging Info - epoch: {epoch+1}, val_auc: {auc}, val_aupr: {aupr}, val_acc: {acc}, val_f1: {f1}')
57
        write_log('log/train_history.txt',logs,mode='a')
58
59
    @staticmethod
60
    def get_user_record(data, is_train):
61
        user_history_dict = defaultdict(set)
62
        for interaction in data:
63
            user = interaction[0]
64
            item = interaction[1]
65
            label = interaction[2]
66
            if is_train or label == 1:
67
                user_history_dict[user].add(item)
68
        return user_history_dict
69