a b/Draw_Photos/Draw_Confusion_Matrix.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
import pandas as pd
6
import seaborn as sns
7
from sklearn.metrics import confusion_matrix
8
import matplotlib.pyplot as plt
9
import matplotlib as mpl
10
11
# Load the ground-truth class labels and the predicted class labels.
# Both inputs are header-less CSV files; each is returned as a float32 array.
def _load_column_csv(path):
    """Read a header-less CSV file and return its contents as a float32 array."""
    frame = pd.read_csv(path, header=None)
    return np.array(frame).astype('float32')


labels = _load_column_csv('labels.csv')
prediction = _load_column_csv('prediction.csv')

# Figure styling: apply seaborn's default theme, then override fonts.
# Times New Roman (a serif face) is placed in the *sans-serif* font list so it
# is used while the sans-serif family is active, as it is under seaborn's
# default theme. axes.unicode_minus=False draws minus signs as an ASCII
# hyphen rather than U+2212, which some fonts lack.
sns.set()
mpl.rcParams.update({
    'font.sans-serif': 'Times New Roman',
    'axes.unicode_minus': False,
})

# Draw the normalized confusion matrix and save it as a 600-dpi PNG.
CLASS_IDS = [0, 1, 2, 3]            # class ids expected in labels/prediction
CLASS_NAMES = ['L', 'R', 'B', 'F']  # tick label for each id, in the same order
# Axis whose totals the counts are divided by:
#   0 -> each column (predicted class) sums to 1
#   1 -> each row (true class) sums to 1
# NOTE(review): the previous `C2/sum(C2)` used Python's builtin sum, which
# adds the rows together and so normalized per predicted class (axis 0); that
# behavior is kept here. Per-true-class normalization (axis 1) is the more
# common convention for confusion matrices -- confirm which is intended.
NORMALIZE_AXIS = 0

f, ax = plt.subplots()
counts = confusion_matrix(labels, prediction, labels=CLASS_IDS)
totals = counts.sum(axis=NORMALIZE_AXIS, keepdims=True)
# A class with no samples along the chosen axis has a zero total; leave its
# cells at 0 instead of dividing 0/0 (RuntimeWarning + NaN/blank cells).
C2 = np.divide(counts, totals, out=np.zeros(counts.shape), where=totals > 0)
C2 = np.around(C2, 4)

sns.heatmap(C2, annot=True, ax=ax, fmt='.4f')
ax.set(xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
ax.set_title('Confusion Matrix of Subject Nine', fontsize=16, fontweight='bold')
ax.set_xlabel('Predicted Class', fontsize=14, fontweight='bold')
ax.set_ylabel('True Class', fontsize=14, fontweight='bold')
plt.savefig('Confusion_Matrix.png', format='png', bbox_inches='tight', transparent=True, dpi=600)
plt.show()