a b/code/utils.py
1
import os
2
import pandas as pd
3
import numpy as np
4
from matplotlib import pyplot as plt
5
6
7
def cifar10_test_dataset(dir_dataset, *, loader=None):
    """Collect integer labels (and optionally samples) for the files in a directory.

    Each file's label is derived from its name:

    * names containing ``'Other'`` get label ``0``;
    * every other name gets the last character of its second-to-last
      ``'_'``-separated field, converted to ``int``
      (e.g. ``'img_class3_001.png'`` -> ``3``).

    Args:
        dir_dataset: Directory whose entries are the dataset files.
        loader: Optional callable taking a file path and returning the sample
            for that file. When given, its results are collected into ``X``
            in the same order as ``Y``. When omitted, ``X`` is returned empty
            (no image decoding is done here — TODO: supply a loader).

    Returns:
        Tuple ``(X, Y)`` of lists, ordered by sorted file name. ``Y`` holds
        one ``int`` label per file.

    Raises:
        IndexError: if a non-'Other' file name contains no ``'_'`` or has an
            empty second-to-last field.
        ValueError: if the extracted label character is not a digit.
    """
    X = []
    Y = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for file_name in sorted(os.listdir(dir_dataset)):
        if 'Other' in file_name:
            # NOTE(review): 'Other' shares label 0 with a digit-0 class, if one
            # exists — confirm this merge is intended.
            label = 0
        else:
            label = int(file_name.split('_')[-2][-1])
        Y.append(label)
        if loader is not None:
            X.append(loader(os.path.join(dir_dataset, file_name)))
    return X, Y
20
21
22
def classification_report_csv(report, dir_out, *, file_name='classification_report.xlsx'):
    """Parse the per-class rows of a text classification report and save them.

    ``report`` is expected in the plain-text layout of scikit-learn's
    ``classification_report`` (assumed from the column names — confirm with
    callers): a header line, a blank line, one row per class
    (``<class> <precision> <recall> <f1-score> <support>``), then a blank line
    followed by summary rows. Only the per-class rows are kept; parsing stops
    at the first blank line after the header.

    Args:
        report: The text report.
        dir_out: Directory in which the output file is written.
        file_name: Output file name inside ``dir_out``. A name ending in
            ``.csv`` is written with ``DataFrame.to_csv``; anything else is
            written with ``DataFrame.to_excel`` (which needs an Excel writer
            engine such as openpyxl installed).

    Returns:
        The ``pandas.DataFrame`` that was written, with columns
        ``class, precision, recall, f1_score, support``.

    Raises:
        ValueError: if a per-class row does not end in four numeric fields.
    """
    columns = ['class', 'precision', 'recall', 'f1_score', 'support']
    rows = []
    # Line 0 is the column header and line 1 the blank line beneath it.
    for line in report.splitlines()[2:]:
        if not line.strip():
            break  # blank line: end of the per-class section
        tokens = line.split()
        if len(tokens) < 5:
            raise ValueError(f'unexpected classification report row: {line!r}')
        # Class names may contain spaces, so take the four numbers from the right.
        precision, recall, f1_score, support = (float(tok) for tok in tokens[-4:])
        rows.append({
            'class': ' '.join(tokens[:-4]),
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'support': support,
        })
    dataframe = pd.DataFrame(rows, columns=columns)

    path_out = os.path.join(dir_out, file_name)
    if file_name.lower().endswith('.csv'):
        dataframe.to_csv(path_out, index=False)
    else:
        dataframe.to_excel(path_out, index=False)
    return dataframe
36
37
38
def plot_image(x, norm_intensity=False):
    """Display an image with matplotlib, converting from channel-first layout.

    Args:
        x: Image array in channel-first layout ``(C, H, W)``; it is transposed
            to ``(H, W, C)`` for ``plt.imshow``. A single-channel image
            ``(1, H, W)`` and a plain 2-D array ``(H, W)`` are shown in
            grayscale.
        norm_intensity: If True, divide pixel values by 255.0 before plotting
            (presumably for 0-255 data held in a float array — confirm).

    Raises:
        ValueError: if ``x`` is neither 2-D nor 3-D.
    """
    x = np.asarray(x)
    if x.ndim not in (2, 3):
        raise ValueError(f'expected a (C, H, W) or (H, W) image, got shape {x.shape}')
    if x.ndim == 3:
        # channel first -> channel last, the layout imshow expects
        x = np.transpose(x, (1, 2, 0))
        if x.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel axis
            x = x[..., 0]
    if norm_intensity:
        x = x / 255.0

    # 2-D data would otherwise be colour-mapped with the default colormap.
    plt.imshow(x, cmap='gray' if x.ndim == 2 else None)
    plt.axis('off')
    plt.show()