|
a |
|
b/code/utils.py |
|
|
1 |
import os |
|
|
2 |
import pandas as pd |
|
|
3 |
import numpy as np |
|
|
4 |
from matplotlib import pyplot as plt |
|
|
5 |
|
|
|
6 |
|
|
|
7 |
def cifar10_test_dataset(dir_dataset, loader=None):
    """Collect samples and integer labels from a flat directory of test files.

    The label is parsed from each file name. Any name containing ``'Other'``
    gets label 0. Otherwise the label is the last character of the
    second-to-last ``'_'``-separated field, converted to ``int``
    (e.g. ``img_c3_01.png`` -> 3).

    Args:
        dir_dataset: Directory containing the test files.
        loader: Optional callable applied to each file's full path to produce
            the sample (e.g. an image reader). When omitted, the full path
            itself is stored in ``X``.

    Returns:
        ``(X, Y)``: parallel lists of samples and ``int`` labels, ordered by
        sorted file name so the result is deterministic.

    Raises:
        IndexError: if a non-'Other' file name lacks a usable '_'-separated
            field to take the label from.
        ValueError: if the parsed label character is not a digit.
    """
    X = []
    Y = []
    # os.listdir returns entries in arbitrary order; sort so X/Y are reproducible.
    for file_name in sorted(os.listdir(dir_dataset)):
        if 'Other' in file_name:
            y = 0
        else:
            # NOTE(review): assumes a single-digit class id at the end of the
            # second-to-last '_' field -- confirm against the dataset naming.
            y = int(file_name.split('_')[-2][-1])
        path = os.path.join(dir_dataset, file_name)
        X.append(loader(path) if loader is not None else path)
        Y.append(y)
    return X, Y
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def classification_report_csv(report, dir_out):
    """Parse the per-class rows of a text classification report and save them as CSV.

    The parser assumes the text layout of scikit-learn's
    ``classification_report``:
      1. a header line;
      2. a blank line;
      3. one line per class;
      4. a blank line followed by summary rows (``accuracy`` / ``macro avg`` /
         ``weighted avg``, or ``avg / total`` in older releases).
    Only the per-class rows are kept.

    Args:
        report: Text of the classification report.
        dir_out: Directory in which ``classification_report.csv`` is written.

    Returns:
        The parsed ``pandas.DataFrame`` with columns
        ``class, precision, recall, f1_score, support``.

    Raises:
        ValueError: if a per-class line does not end in four numeric fields.
    """
    columns = ['class', 'precision', 'recall', 'f1_score', 'support']
    report_data = []
    # Skip the header and the blank line after it. The per-class block ends at
    # the next blank line, so summary rows are never parsed (this works for
    # both the old and the new sklearn layouts).
    for line in report.split('\n')[2:]:
        if not line.strip():
            break
        # Columns are aligned with runs of spaces and class names may contain
        # single spaces (e.g. "class 0"), so the last four tokens are the
        # numbers and everything before them is the class name.
        tokens = line.split()
        if len(tokens) < 5:
            raise ValueError(f'Unexpected classification report line: {line!r}')
        precision, recall, f1_score, support = (float(t) for t in tokens[-4:])
        report_data.append({
            'class': ' '.join(tokens[:-4]),
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'support': support,
        })
    # Explicit columns keep the schema even when no class rows were found.
    dataframe = pd.DataFrame(report_data, columns=columns)
    dataframe.to_csv(os.path.join(dir_out, 'classification_report.csv'), index=False)
    return dataframe
|
|
36 |
|
|
|
37 |
|
|
|
38 |
def plot_image(x, norm_intensity=False):
    """Display an image with matplotlib, hiding the axes.

    Args:
        x: Image array in channel-first layout ``(C, H, W)``, or a 2-D
            ``(H, W)`` grayscale image. A single-channel ``(1, H, W)`` image
            is shown as grayscale.
        norm_intensity: If True, divide pixel values by 255 (e.g. to map
            8-bit intensities into ``[0, 1]``).
    """
    x = np.asarray(x)
    if x.ndim == 3:
        # channel first -> channel last, the layout plt.imshow expects
        x = np.transpose(x, (1, 2, 0))
        if x.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis
            x = x[..., 0]
    if norm_intensity:
        x = x / 255.0
    # the colormap applies only to 2-D (single-channel) data; RGB(A) is shown as-is
    plt.imshow(x, cmap='gray' if x.ndim == 2 else None)
    plt.axis('off')
    plt.show()