|
a |
|
b/Draw_Photos/Draw_Confusion_Matrix.py |
|
|
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Draw a normalized 4-class confusion matrix and save it as a PNG.

Reads true class ids from ``labels.csv`` and predicted class ids from
``prediction.csv`` (both header-less, in the current working directory),
builds the confusion matrix for classes 0-3 (shown as L, R, B, F), normalizes
the counts along ``NORMALIZE_AXIS``, draws it as an annotated heatmap, saves
it to ``Confusion_Matrix.png`` and shows it.
"""

import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import matplotlib as mpl

# Display names for class ids 0, 1, 2, 3 (in that order); the class ids passed
# to confusion_matrix are derived from this list so the two never drift apart.
CLASS_NAMES = ['L', 'R', 'B', 'F']
CLASS_IDS = list(range(len(CLASS_NAMES)))

# Axis along which raw counts are normalized:
#   0 -> each column (predicted class) sums to 1 (what the original
#        `C2 / sum(C2)` computed, since builtin sum adds up rows);
#   1 -> each row (true class) sums to 1.
NORMALIZE_AXIS = 0

# Read true labels and predicted labels. The predictions must already be class
# ids, not raw scores/probabilities (confusion_matrix rejects continuous
# targets). Values are flattened to the 1-D vectors sklearn expects.
labels = pd.read_csv('labels.csv', header=None).to_numpy(dtype='float32').ravel()
prediction = pd.read_csv('prediction.csv', header=None).to_numpy(dtype='float32').ravel()

# Set photo parameters. seaborn's theme overwrites the font rcParams, so it is
# applied first and the font overrides come after it.
sns.set()
# The seaborn theme renders text with the sans-serif family; putting Times New
# Roman at the front of that list makes it the font actually used.
mpl.rcParams['font.sans-serif'] = 'Times New Roman'
# Render minus signs as ASCII hyphens (presumably because the chosen font lacks
# the Unicode minus glyph).
mpl.rcParams['axes.unicode_minus'] = False

# Draw Confusion Matrix
fig, ax = plt.subplots()
counts = confusion_matrix(labels, prediction, labels=CLASS_IDS)
totals = counts.sum(axis=NORMALIZE_AXIS, keepdims=True)
# A class with no samples along the normalization axis has a zero total; give
# its cells 0 instead of letting the division produce NaN.
normalized = np.divide(counts, totals, out=np.zeros(counts.shape), where=totals != 0)
normalized = np.around(normalized, 4)

sns.heatmap(normalized, annot=True, fmt='.4f', ax=ax,
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
ax.set_title('Confusion Matrix of Subject Nine', fontsize=16, fontweight='bold')
ax.set_xlabel('Predicted Class', fontsize=14, fontweight='bold')
ax.set_ylabel('True Class', fontsize=14, fontweight='bold')
fig.savefig('Confusion_Matrix.png', format='png', bbox_inches='tight', transparent=True, dpi=600)
plt.show()