--- a +++ b/confusion_matrix.py @@ -0,0 +1,38 @@ +import itertools +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix + +def plot_confusion_matrix(y_true, y_pred, sub, title = "Confusion matrix - 2a", + cmap=plt.cm.Blues, save_flg=True): + + y_pred = y_pred.cpu().detach().numpy() + y_true = y_true.cpu().detach().numpy() + classes = [str(i) for i in range(4)] + labels = range(4) + + cm = confusion_matrix(y_true, y_pred, labels=labels, normalize='true') + plt.figure(figsize=(14, 12)) + plt.imshow(cm, interpolation='nearest', cmap=cmap) + plt.title(title, fontsize=40) + plt.colorbar() + tick_marks = np.arange(len(classes)) + plt.xticks(tick_marks, classes, fontsize=20) + plt.yticks(tick_marks, classes, fontsize=20) + + # print('Confusion matrix, without normalization') + + thresh = cm.max() / 2. + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + plt.text(j, i, format(cm[i, j], '.2f'), + horizontalalignment="center", + color="white" if cm[i, j] > thresh else "black", + fontsize=30) + + plt.ylabel('True label', fontsize=30) + plt.xlabel('Predicted label', fontsize=30) + + if save_flg: + plt.savefig("confusion_matrix" + str(sub) + ".png") + + # plt.show()