Diff of /run_exp_1.py [000000] .. [6536f9]

Switch to unified view

a b/run_exp_1.py
1
import run
2
import sys
3
import os
4
import numpy as np
5
import matplotlib.pyplot as plt
6
7
def run_experiment1(data_path):
8
    os.makedirs('exp1_results', exist_ok=True)
9
10
    accuracies1 = []
11
    accuracies2 = []
12
    accuracies3 = []
13
    accuracies4 = []
14
    accuracies5 = []
15
16
    for num_subjects in range(20, 150, 20):
17
        print(f"Training for {num_subjects} subjects with III+V3+V5")
18
        accuracy, total_train_time, total_eval_time  = run.run(NUM_SEGMENTS=500, 
19
                                                              NUM_SECONDS=2, 
20
                                                              NUM_BATCH=16, 
21
                                                              LEADS=[0,1,2], 
22
                                                              NUM_EPOCHS=200, 
23
                                                              DATA_PATH = data_path, 
24
                                                              FS=128,
25
                                                              NUM_PATIENTS=num_subjects)
26
        accuracies1.append(accuracy)
27
        with open('exp1_results/result-III-V3-V5.txt', 'a') as f:
28
            f.write(f"With number subjects: {num_subjects}:\n")
29
            f.write(f"Accuracy: {accuracy}, Time train: {total_train_time}, Time eval: {total_eval_time}\n\n")
30
    
31
32
    for num_subjects in range(20, 150, 20):
33
        print(f"Training for {num_subjects} subjects with III+V3")
34
        accuracy, total_train_time, total_eval_time  = run.run(NUM_SEGMENTS=500, 
35
                                                              NUM_SECONDS=2, 
36
                                                              NUM_BATCH=16, 
37
                                                              LEADS=[0,1], 
38
                                                              NUM_EPOCHS=200, 
39
                                                              DATA_PATH = data_path, 
40
                                                              FS=128,
41
                                                              NUM_PATIENTS=num_subjects)
42
        accuracies2.append(accuracy)
43
        with open('exp1_results/result-III-V3.txt', 'a') as f:
44
            f.write(f"With number subjects: {num_subjects}:\n")
45
            f.write(f"Accuracy: {accuracy}, Time train: {total_train_time}, Time eval: {total_eval_time}\n\n")
46
47
    for num_subjects in range(20, 150, 20):
48
        print(f"Training for {num_subjects} subjects with V5")
49
        accuracy, total_train_time, total_eval_time  = run.run(NUM_SEGMENTS=500, 
50
                                                              NUM_SECONDS=2, 
51
                                                              NUM_BATCH=16, 
52
                                                              LEADS=[2], 
53
                                                              NUM_EPOCHS=200, 
54
                                                              DATA_PATH = data_path, 
55
                                                              FS=128,
56
                                                              NUM_PATIENTS=num_subjects)
57
        accuracies3.append(accuracy)
58
        with open('exp1_results/result-V5.txt', 'a') as f:
59
            f.write(f"With number subjects: {num_subjects}:\n")
60
            f.write(f"Accuracy: {accuracy}, Time train: {total_train_time}, Time eval: {total_eval_time}\n\n")
61
62
    for num_subjects in range(20, 150, 20):
63
        print(f"Training for {num_subjects} subjects with V3")
64
        accuracy, total_train_time, total_eval_time  = run.run(NUM_SEGMENTS=500, 
65
                                                              NUM_SECONDS=2, 
66
                                                              NUM_BATCH=16, 
67
                                                              LEADS=[1], 
68
                                                              NUM_EPOCHS=200, 
69
                                                              DATA_PATH = data_path, 
70
                                                              FS=128,
71
                                                              NUM_PATIENTS=num_subjects)
72
        accuracies4.append(accuracy)
73
        with open('exp1_results/result-V3.txt', 'a') as f:
74
            f.write(f"With number subjects: {num_subjects}:\n")
75
            f.write(f"Accuracy: {accuracy}, Time train: {total_train_time}, Time eval: {total_eval_time}\n\n")
76
77
    for num_subjects in range(20, 150, 20):
78
        print(f"Training for {num_subjects} subjects with III")
79
        accuracy, total_train_time, total_eval_time  = run.run(NUM_SEGMENTS=500, 
80
                                                              NUM_SECONDS=2, 
81
                                                              NUM_BATCH=16, 
82
                                                              LEADS=[0], 
83
                                                              NUM_EPOCHS=200, 
84
                                                              DATA_PATH = data_path, 
85
                                                              FS=128,
86
                                                              NUM_PATIENTS=num_subjects)
87
        accuracies5.append(accuracy)
88
        with open('exp1_results/result-III.txt', 'a') as f:
89
            f.write(f"With number subjects: {num_subjects}:\n")
90
            f.write(f"Accuracy: {accuracy}, Time train: {total_train_time}, Time eval: {total_eval_time}\n\n")
91
92
93
    #plotting             
94
    num_subjects = list(range(20, 150, 20))
95
    plt.figure(figsize=(10, 8), dpi=100)
96
    plt.plot(num_subjects, accuracies1, '#8C2155',marker='.', label='III+V3+V5')
97
    plt.plot(num_subjects, accuracies2, '#F99083',marker='.', label='III+V3')
98
    plt.plot(num_subjects, accuracies3, '#8AA29E',marker='.', label='V5')
99
    plt.plot(num_subjects, accuracies4, '#98CE00',marker='.', label='V3')
100
    plt.plot(num_subjects, accuracies5, '#FFC857',marker='.', label='III')
101
102
    plt.xlabel('Number of subjects', fontsize=16)
103
    plt.ylabel('Accuracy (%)', fontsize=16)
104
    plt.tick_params(axis='both', which='major', labelsize=12)
105
    plt.legend()
106
    plt.xlim((np.min(num_subjects),np.max(num_subjects)))
107
    plt.grid(True, 'both')
108
    plt.savefig('exp1_results/fig_exp1.png')
109
    plt.show()
110
    plt.close()
111
112
113
if __name__ == "__main__":
114
    if len(sys.argv) != 2:
115
        print("Usage: python run_exp_1.py <data_path>")
116
        sys.exit(1)
117
    data_path = sys.argv[1]
118
119
    run_experiment1(data_path)
120
121