--- a +++ b/run_exp_2.py @@ -0,0 +1,74 @@ +import os +import sys +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt +from run import run + +from matplotlib.colors import LinearSegmentedColormap + + +def run_experiment2(data_path): + ''' + Function that conducts a series of experiments with varying: + - number of segments used for training and testing (N) + - segment length (T) + + Parameters: + data_path (str): path to the preprocessed dataset. + + Saves: + - CSV files for accuracy, training time, and evaluation time grids, and + - a heatmap plot of model accuracy in the 'exp2_results' directory. + ''' + + segments_range = [500,350,200,100,50] + seconds_range = [1,2,5,7,10] + leads = [0,1] + + accuracy_grid = np.zeros((len(segments_range), len(seconds_range))) + train_time_grid = np.zeros_like(accuracy_grid) + eval_time_grid = np.zeros_like(accuracy_grid) + + + for i, num_segments in enumerate(segments_range): + for j, num_seconds in enumerate(seconds_range): + print(f'Training for {num_segments} segments, {num_seconds} s long') + accuracy, total_train_time, total_eval_time = run(NUM_SEGMENTS=num_segments, + NUM_SECONDS=num_seconds, + NUM_BATCH=16, + LEADS=leads, + NUM_EPOCHS=200, + DATA_PATH = data_path, + FS=128) + accuracy_grid[i, j] = accuracy + train_time_grid[i,j] = total_train_time + eval_time_grid[i,j] = total_eval_time + + + os.makedirs('exp2_results', exist_ok=True) + np.savetxt('exp2_results/accuracy_grid.csv', accuracy_grid, delimiter=',', fmt='%.2f') + np.savetxt('exp2_results/train_time_grid.csv', train_time_grid, delimiter=',', fmt='%.2f') + np.savetxt('exp2_results/eval_time_grid.csv', eval_time_grid, delimiter=',', fmt='%.2f') + + color_list = ["#FCDDD9", "#FABEB7", "#FAA99F", "#FA968A", "#F87F71", + "#DC7366", "#B15F56", "#884C45", "#593633", "#2E2322"] + custom_colormap = LinearSegmentedColormap.from_list("custom_salmon", color_list) + + yticklabels = [int(w/2) for w in segments_range] + plt.figure(figsize=(10, 8), dpi=100) + sns.heatmap(accuracy_grid, annot=True, fmt=".2f", xticklabels=seconds_range, yticklabels=yticklabels, + cmap=custom_colormap, cbar_kws={'label': 'Accuracy %'}) + plt.xlabel('NUMBER OF SECONDS (T)') + plt.ylabel('TRAIN SAMPLES PER SUBJECT (N)') + plt.savefig('exp2_results/fig_exp2.png') + plt.close() + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python run_exp_2.py <data_path>") + sys.exit(1) + + data_path = sys.argv[1] + + run_experiment2(data_path=data_path)