import os
import torch
import random
from torch import nn
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

import neural
from class_ecgdataset import ECGDataset
from class_deepdataset import DeepDatasetV2
from train_and_eval import train_model, evaluate_model


def run(NUM_SEGMENTS, NUM_SECONDS, NUM_BATCH, LEADS, NUM_EPOCHS, DATA_PATH, FS, NUM_PATIENTS=None):
    '''
    Function that:
    - loads ECG data, prepares the dataset, and performs a stratified train-test split
    - trains a DeepECG model and evaluates its performance

    It operates on a specified number of patients (if NUM_PATIENTS is given), segments, and leads.

    Returns:
    - accuracy (float): the accuracy of the model on the test set
    - total_train_time (int): the total training time in seconds
    - total_eval_time (int): the total evaluation time in seconds
    '''
    torch.manual_seed(44)
    random.seed(44)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f'Used device: {device}')
    print('Loading signals...')

    patient_ids_path = os.path.join(DATA_PATH, 'patient_ids.txt')
    with open(patient_ids_path, 'r') as f:
        PATIENT_IDS = [line.strip() for line in f.readlines()]

    if NUM_PATIENTS is not None:
        PATIENT_IDS = PATIENT_IDS[:NUM_PATIENTS]

    full_dataset = ECGDataset(DATA_PATH, PATIENT_IDS, FS, NUM_SEGMENTS, NUM_SECONDS, LEADS)

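    # note: the torch/random seeds above do not make this split deterministic;
    # scikit-learn's train_test_split draws from NumPy's RNG unless random_state is set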
    # stratified split to maintain equal distribution of each patient's data
    train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.5, stratify=full_dataset.labels)

    train_dataset = Subset(full_dataset, train_indices)
    test_dataset = Subset(full_dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=NUM_BATCH, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=NUM_BATCH, shuffle=False)

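    # shape probe: run one real batch through DeepECG_DUMMY to obtain final_features,
    # which appears to be the flattened feature size that the classifier head of the
    # real DeepECG model needs at construction time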
    single_batch_segments, _ = next(iter(train_loader))
    output_shape = len(PATIENT_IDS)
    hidden_units = 32

    dummy_network = neural.DeepECG_DUMMY(len(LEADS), hidden_units, output_shape).to(device)
    single_batch_segments = single_batch_segments.to(device)
    final_features = dummy_network(single_batch_segments)

    model = neural.DeepECG(len(LEADS), hidden_units, output_shape, final_features).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

    total_train_time = train_model(model=model,
                                   data_loader=train_loader,
                                   loss_fn=loss_fn,
                                   optimizer=optimizer,
                                   device=device,
                                   num_epochs=NUM_EPOCHS,
                                   output_shape=output_shape
                                   )

    accuracy, total_eval_time = evaluate_model(model=model,
                                               test_loader=test_loader,
                                               loss_fn=loss_fn,
                                               device=device,
                                               output_shape=output_shape
                                               )
    return accuracy, total_train_time, total_eval_time


def run_deepecg(NUM_BATCH, LEAD, NUM_EPOCHS, DATA_PATH, FS):
    '''
    Function that:
    - loads ECG data, prepares the dataset using the DeepECG methodology, and performs a stratified train-test split
    - trains a DeepECG model and evaluates its performance

    It operates on a single specified lead of the ECG recordings.

    Returns:
    - accuracy (float): the accuracy of the model on the test set
    - total_train_time (int): the total training time in seconds
    - total_eval_time (int): the total evaluation time in seconds
    '''
    torch.manual_seed(44)
    random.seed(44)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f'Used device: {device}')
    print('Loading signals...')

    patient_ids_path = os.path.join(DATA_PATH, 'patient_ids.txt')
    with open(patient_ids_path, 'r') as f:
        PATIENT_IDS = [line.strip() for line in f.readlines()]

    full_dataset = DeepDatasetV2(DATA_PATH, PATIENT_IDS, FS, LEAD)

    # stratified split to maintain equal distribution of each patient's data
    train_indices, test_indices = train_test_split(range(len(full_dataset)), test_size=0.5, stratify=full_dataset.labels)

    train_dataset = Subset(full_dataset, train_indices)
    test_dataset = Subset(full_dataset, test_indices)

    train_loader = DataLoader(train_dataset, batch_size=NUM_BATCH, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=NUM_BATCH, shuffle=False)

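    # same shape probe as in run(): one batch through DeepECG_DUMMY yields
    # final_features for the single-lead DeepECG model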
    single_batch_segments, _ = next(iter(train_loader))
    num_leads = 1
    output_shape = len(PATIENT_IDS)
    hidden_units = 32

    dummy_network = neural.DeepECG_DUMMY(num_leads, hidden_units, output_shape).to(device)
    single_batch_segments = single_batch_segments.to(device)
    final_features = dummy_network(single_batch_segments)

    model = neural.DeepECG(num_leads, hidden_units, output_shape, final_features).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

    total_train_time = train_model(model=model,
                                   data_loader=train_loader,
                                   loss_fn=loss_fn,
                                   optimizer=optimizer,
                                   device=device,
                                   num_epochs=NUM_EPOCHS,
                                   output_shape=output_shape
                                   )

    accuracy, total_eval_time = evaluate_model(model=model,
                                               test_loader=test_loader,
                                               loss_fn=loss_fn,
                                               device=device,
                                               output_shape=output_shape
                                               )

    return accuracy, total_train_time, total_eval_time
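

# Minimal usage sketch. The argument values below are placeholders, not values taken
# from the project; DATA_PATH must point at a directory containing patient_ids.txt
# plus the per-patient signal files that ECGDataset and DeepDatasetV2 expect.
if __name__ == '__main__':
    acc, t_train, t_eval = run(NUM_SEGMENTS=10,
                               NUM_SECONDS=2,
                               NUM_BATCH=32,
                               LEADS=[0, 1],   # lead selection as understood by ECGDataset
                               NUM_EPOCHS=5,
                               DATA_PATH='data',
                               FS=500,
                               NUM_PATIENTS=20)
    print(f'run(): accuracy={acc:.3f}, train={t_train}s, eval={t_eval}s')

    acc, t_train, t_eval = run_deepecg(NUM_BATCH=32,
                                       LEAD=0,   # single lead as understood by DeepDatasetV2
                                       NUM_EPOCHS=5,
                                       DATA_PATH='data',
                                       FS=500)
    print(f'run_deepecg(): accuracy={acc:.3f}, train={t_train}s, eval={t_eval}s')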