Diff of /util.py [000000] .. [48f029]

Switch to unified view

a b/util.py
1
import matplotlib.pyplot as plt
2
from sklearn.metrics import roc_curve, auc
3
4
5
def results(y_pred, y_test):
6
    """
7
    :param y_pred: the predicted y-values
8
    :param y_test: the actual y-values
9
    :return: the number of correct predictions, incorrect predictions, and the percent correct
10
    """
11
    num_right = 0
12
    num_wrong = 0
13
    for i in range(len(y_pred)):
14
        if y_pred[i] == y_test[i]:
15
            num_right += 1
16
        else:
17
            num_wrong += 1
18
    return num_right/(num_right + num_wrong)
19
20
21
def roc_results(y_pred, y_test, model_type):
22
    fpr, tpr, thresholds = roc_curve(y_test, y_pred, pos_label=1)
23
    plt.figure(figsize=(8, 6))
24
    lw = 2
25
    plt.plot(fpr, tpr,
26
             lw=lw, label=f'{model_type} (AUC = {round(auc(fpr, tpr), 3)})')
27
    plt.plot([0, 1], [0, 1], lw=lw, linestyle='--')
28
    plt.xlim([0.0, 1.0])
29
    plt.ylim([0.0, 1.05])
30
    plt.xlabel('False Positive Rate')
31
    plt.ylabel('True Positive Rate')
32
    plt.title('ROC Curve')
33
    plt.legend(loc="lower right")
34
    plt.show()