# run_exp_2.py
1
import os
2
import sys
3
import numpy as np
4
import seaborn as sns
5
import matplotlib.pyplot as plt
6
from run import run
7
8
from matplotlib.colors import LinearSegmentedColormap
9
10
11
def run_experiment2(data_path, *, segments_range=None, seconds_range=None,
                    leads=None, num_batch=16, num_epochs=200, fs=128,
                    output_dir='exp2_results'):
    '''
    Run a grid of training/evaluation experiments and save the results.

    The grid varies:
     - number of segments used for training and testing (N), one per row
     - segment length in seconds (T), one per column

    For every (N, T) pair, ``run`` is called and the accuracy, total training
    time and total evaluation time it returns are stored in three grids.

    Parameters:
    data_path (str): path to the preprocessed dataset, forwarded to ``run``.
    segments_range (iterable of int, optional): values of N (grid rows).
        Defaults to [500, 350, 200, 100, 50].
    seconds_range (iterable of int, optional): values of T (grid columns).
        Defaults to [1, 2, 5, 7, 10].
    leads (iterable of int, optional): forwarded to ``run`` as LEADS (always
        as a list). Defaults to [0, 1].
    num_batch (int): forwarded to ``run`` as NUM_BATCH (presumably the batch
        size). Defaults to 16.
    num_epochs (int): forwarded to ``run`` as NUM_EPOCHS. Defaults to 200.
    fs (int): forwarded to ``run`` as FS (presumably the sampling frequency
        in Hz). Defaults to 128.
    output_dir (str): directory receiving the result files; created if it
        does not exist. Defaults to 'exp2_results'.

    Saves (inside ``output_dir``):
    - accuracy_grid.csv, train_time_grid.csv and eval_time_grid.csv, each a
      len(segments_range) x len(seconds_range) comma-separated grid written
      with '%.2f', and
    - fig_exp2.png, an annotated heatmap of the accuracy grid.
    '''
    # Materialise the inputs as lists: they are iterated, measured with len(),
    # used as tick labels, and LEADS has always been handed to `run` as a list.
    segments_range = ([500, 350, 200, 100, 50] if segments_range is None
                      else list(segments_range))
    seconds_range = ([1, 2, 5, 7, 10] if seconds_range is None
                     else list(seconds_range))
    leads = [0, 1] if leads is None else list(leads)

    shape = (len(segments_range), len(seconds_range))
    accuracy_grid = np.zeros(shape)
    train_time_grid = np.zeros(shape)
    eval_time_grid = np.zeros(shape)

    for i, num_segments in enumerate(segments_range):
        for j, num_seconds in enumerate(seconds_range):
            print(f'Training for {num_segments} segments, {num_seconds} s long')
            accuracy, total_train_time, total_eval_time = run(
                NUM_SEGMENTS=num_segments,
                NUM_SECONDS=num_seconds,
                NUM_BATCH=num_batch,
                LEADS=leads,
                NUM_EPOCHS=num_epochs,
                DATA_PATH=data_path,
                FS=fs,
            )
            accuracy_grid[i, j] = accuracy
            train_time_grid[i, j] = total_train_time
            eval_time_grid[i, j] = total_eval_time

    os.makedirs(output_dir, exist_ok=True)
    for name, grid in (('accuracy_grid', accuracy_grid),
                       ('train_time_grid', train_time_grid),
                       ('eval_time_grid', eval_time_grid)):
        np.savetxt(os.path.join(output_dir, f'{name}.csv'), grid,
                   delimiter=',', fmt='%.2f')

    color_list = ["#FCDDD9", "#FABEB7", "#FAA99F", "#FA968A", "#F87F71",
                  "#DC7366", "#B15F56", "#884C45", "#593633", "#2E2322"]
    custom_colormap = LinearSegmentedColormap.from_list("custom_salmon", color_list)

    # The y axis is labelled "train samples per subject" and shows half of
    # each segment count. NOTE(review): presumably `run` splits the N
    # segments evenly between training and testing; confirm against run().
    yticklabels = [int(n / 2) for n in segments_range]

    fig = plt.figure(figsize=(10, 8), dpi=100)
    try:
        sns.heatmap(accuracy_grid, annot=True, fmt=".2f",
                    xticklabels=seconds_range, yticklabels=yticklabels,
                    cmap=custom_colormap, cbar_kws={'label': 'Accuracy %'})
        plt.xlabel('NUMBER OF SECONDS (T)')
        plt.ylabel('TRAIN SAMPLES PER SUBJECT (N)')
        fig.savefig(os.path.join(output_dir, 'fig_exp2.png'))
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)

if __name__ == "__main__":
    # Expect exactly one CLI argument: the path to the preprocessed dataset.
    if len(sys.argv) != 2:
        # Usage errors belong on stderr so they don't mix with normal output.
        print("Usage: python run_exp_2.py <data_path>", file=sys.stderr)
        sys.exit(1)

    data_path = sys.argv[1]

    run_experiment2(data_path=data_path)