|
a |
|
b/confusion_matrix.py |
|
|
1 |
import itertools |
|
|
2 |
import numpy as np |
|
|
3 |
import matplotlib.pyplot as plt |
|
|
4 |
from sklearn.metrics import confusion_matrix |
|
|
5 |
|
|
|
6 |
def plot_confusion_matrix(y_true, y_pred, sub, title = "Confusion matrix - 2a", |
|
|
7 |
cmap=plt.cm.Blues, save_flg=True): |
|
|
8 |
|
|
|
9 |
y_pred = y_pred.cpu().detach().numpy() |
|
|
10 |
y_true = y_true.cpu().detach().numpy() |
|
|
11 |
classes = [str(i) for i in range(4)] |
|
|
12 |
labels = range(4) |
|
|
13 |
|
|
|
14 |
cm = confusion_matrix(y_true, y_pred, labels=labels, normalize='true') |
|
|
15 |
plt.figure(figsize=(14, 12)) |
|
|
16 |
plt.imshow(cm, interpolation='nearest', cmap=cmap) |
|
|
17 |
plt.title(title, fontsize=40) |
|
|
18 |
plt.colorbar() |
|
|
19 |
tick_marks = np.arange(len(classes)) |
|
|
20 |
plt.xticks(tick_marks, classes, fontsize=20) |
|
|
21 |
plt.yticks(tick_marks, classes, fontsize=20) |
|
|
22 |
|
|
|
23 |
# print('Confusion matrix, without normalization') |
|
|
24 |
|
|
|
25 |
thresh = cm.max() / 2. |
|
|
26 |
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): |
|
|
27 |
plt.text(j, i, format(cm[i, j], '.2f'), |
|
|
28 |
horizontalalignment="center", |
|
|
29 |
color="white" if cm[i, j] > thresh else "black", |
|
|
30 |
fontsize=30) |
|
|
31 |
|
|
|
32 |
plt.ylabel('True label', fontsize=30) |
|
|
33 |
plt.xlabel('Predicted label', fontsize=30) |
|
|
34 |
|
|
|
35 |
if save_flg: |
|
|
36 |
plt.savefig("confusion_matrix" + str(sub) + ".png") |
|
|
37 |
|
|
|
38 |
# plt.show() |