|
a |
|
b/util.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
from sklearn.metrics import roc_curve, auc |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def results(y_pred, y_test):
    """Return the fraction of predictions that match the actual labels.

    :param y_pred: the predicted y-values
    :param y_test: the actual y-values
    :return: the proportion of entries where ``y_pred[i] == y_test[i]``,
        as a float in [0, 1]
    :raises ZeroDivisionError: if ``y_pred`` is empty
    """
    # Pair each prediction with its ground-truth label and count agreements.
    num_right = sum(1 for pred, actual in zip(y_pred, y_test) if pred == actual)
    # Divisor is the total number of predictions (num_right + num_wrong
    # in the original formulation).
    return num_right / len(y_pred)
|
|
19 |
|
|
|
20 |
|
|
|
21 |
def roc_results(y_pred, y_test, model_type):
    """Plot and display the ROC curve for a set of predictions.

    :param y_pred: predicted scores for the positive class
    :param y_test: the actual binary labels (positive label is 1)
    :param model_type: name of the model, shown with its AUC in the legend
    """
    false_pos, true_pos, _ = roc_curve(y_test, y_pred, pos_label=1)
    line_width = 2
    # Area under the curve, rounded for the legend label.
    area = round(auc(false_pos, true_pos), 3)

    plt.figure(figsize=(8, 6))
    plt.plot(false_pos, true_pos, lw=line_width,
             label=f'{model_type} (AUC = {area})')
    # Diagonal reference line: the performance of random guessing.
    plt.plot([0, 1], [0, 1], lw=line_width, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc="lower right")
    plt.show()