--- a +++ b/code/utils.py @@ -0,0 +1,46 @@ +import os +import pandas as pd +import numpy as np +from matplotlib import pyplot as plt + + +def cifar10_test_dataset(dir_dataset): + + files = os.listdir(dir_dataset) + + Y = [] + X = [] + for iFile in files: + if 'Other' in iFile: + y = 0 + else: + y = iFile.split('_')[-2][-1] + + return X, Y + + +def classification_report_csv(report, dir_out): + report_data = [] + lines = report.split('\n') + for line in lines[2:-3]: + row = {} + row_data = line.split(' ') + row['class'] = row_data[0] + row['precision'] = float(row_data[1]) + row['recall'] = float(row_data[2]) + row['f1_score'] = float(row_data[3]) + row['support'] = float(row_data[4]) + report_data.append(row) + dataframe = pd.DataFrame.from_dict(report_data) + dataframe.to_xlsx(dir_out + 'classification_report.xlsx', index = False) + + +def plot_image(x, norm_intensity=False): + # channel first + x = np.transpose(x, (1, 2, 0)) + if norm_intensity: + x = x / 255.0 + + plt.imshow(x) + plt.axis('off') + plt.show()