Diff of /confusion_matrix.py [000000] .. [9c6ad1]

Switch to unified view

a b/confusion_matrix.py
1
import itertools
2
import numpy as np
3
import matplotlib.pyplot as plt
4
from sklearn.metrics import confusion_matrix
5
6
def plot_confusion_matrix(y_true, y_pred, sub, title = "Confusion matrix - 2a",
7
                          cmap=plt.cm.Blues, save_flg=True):
8
9
    y_pred = y_pred.cpu().detach().numpy()
10
    y_true = y_true.cpu().detach().numpy()
11
    classes = [str(i) for i in range(4)]
12
    labels = range(4)
13
14
    cm = confusion_matrix(y_true, y_pred, labels=labels, normalize='true')
15
    plt.figure(figsize=(14, 12))
16
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
17
    plt.title(title, fontsize=40)
18
    plt.colorbar()
19
    tick_marks = np.arange(len(classes))
20
    plt.xticks(tick_marks, classes, fontsize=20)
21
    plt.yticks(tick_marks, classes, fontsize=20)
22
23
    # print('Confusion matrix, without normalization')
24
25
    thresh = cm.max() / 2.
26
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
27
        plt.text(j, i, format(cm[i, j], '.2f'),
28
                 horizontalalignment="center",
29
                 color="white" if cm[i, j] > thresh else "black",
30
                 fontsize=30)
31
32
    plt.ylabel('True label', fontsize=30)
33
    plt.xlabel('Predicted label', fontsize=30)
34
35
    if save_flg:
36
        plt.savefig("confusion_matrix" + str(sub) + ".png")
37
38
    # plt.show()