Diff of /training_pipeline.py [000000] .. [4dadda]

Switch to unified view

a b/training_pipeline.py
1
from sklearn.model_selection import train_test_split
2
import pandas as pd
3
import numpy as np
4
import torch
5
from joblib import load
6
import statistics as stats
7
from sklearn import preprocessing
8
9
import torch.backends.cudnn as cudnn
10
# cuDNN configuration. Per the PyTorch reproducibility notes, turning the
# benchmark autotuner off and requesting deterministic kernels trades some
# speed for repeatable results across runs.
cudnn.enabled = True
11
cudnn.benchmark = False
12
cudnn.deterministic = True
13
14
from code_psd_shallow_eeg_gcnn.EEGGraphDataset import EEGGraphDataset
15
from code_psd_shallow_eeg_gcnn.EEGGraphConvNet import EEGGraphConvNet
16
from torch_geometric.data import DataLoader
17
from torch.utils.data import WeightedRandomSampler
18
from sklearn.metrics import make_scorer
19
from sklearn.metrics import balanced_accuracy_score, auc, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
20
from torchvision.transforms import Compose, ToTensor
21
22
# Module-level accumulator filled as a side effect of the metric helpers:
#   "probs_0_fold_{k}" / "probs_1_fold_{k}" -> patient-level class probabilities
#   "subject_id" / "label"                  -> written only when fold_idx == 0
stats_test_data = { }
23
24
# after each epoch, record all the metrics on both train and validation sets
25
def collect_metrics(y_probs_test, y_true_test, y_pred_test, sample_indices_test,
                    fold_idx, experiment_name):
    """Compute window-level and patient-level metrics for one set of predictions.

    Despite the ``*_test`` parameter names, the script calls this for both the
    training and the validation split after every epoch.

    Args:
        y_probs_test: 2-D array indexed as ``[i, 0]`` / ``[i, 1]`` holding the
            per-window probabilities of class 0 and class 1.
        y_true_test: per-window true labels, positionally aligned with
            ``y_probs_test``.
        y_pred_test: per-window predicted labels, same alignment.
        sample_indices_test: per-window row labels into
            ``master_metadata_index.csv``; used to look up ``patient_ID``.
        fold_idx: fold number; used in the ``stats_test_data`` key names and
            forwarded to ``get_patient_prediction``.
        experiment_name: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(auroc_patient, auroc_window, precision_patient,
        recall_patient, f1_patient, balanced_accuracy_patient)``.

    Side effects:
        Reads ``master_metadata_index.csv`` from the working directory on every
        call and stores the patient-level probabilities in the module-level
        ``stats_test_data`` dict.
    """
    dataset_index = pd.read_csv("master_metadata_index.csv", dtype={"patient_ID": str})

    # One row per window, tagged with the patient it belongs to. Indexing the
    # prediction containers by position keeps the original alignment rules.
    rows = [
        {
            "patient_ID": str(dataset_index.loc[idx, "patient_ID"]),
            "sample_idx": idx,
            "y_true": y_true_test[i],
            "y_probs_0": y_probs_test[i, 0],
            "y_probs_1": y_probs_test[i, 1],
            "y_pred": y_pred_test[i],
        }
        for i, idx in enumerate(sample_indices_test)
    ]
    test_patient_df = pd.DataFrame(rows)

    # Collapse window-level rows into one prediction per patient. The
    # patient-level majority-vote labels are not needed below.
    y_probs_test_patient, y_true_test_patient, _ = get_patient_prediction(test_patient_df, fold_idx)

    stats_test_data[f"probs_0_fold_{fold_idx}"] = y_probs_test_patient[:, 0]
    stats_test_data[f"probs_1_fold_{fold_idx}"] = y_probs_test_patient[:, 1]

    # PATIENT-LEVEL ROC: used to pick the probability threshold.
    fpr, tpr, thresholds = roc_curve(y_true_test_patient, y_probs_test_patient[:, 1], pos_label=1)

    # Threshold with the largest |TPR - FPR|; ties resolve to the first such
    # threshold, exactly as the previous sort-and-take-first did.
    # NOTE(review): Youden's J is TPR - FPR without the absolute value; the
    # abs() only matters for a worse-than-chance ROC curve. Confirm intent.
    optimal_proba_cutoff = thresholds[np.argmax(np.abs(tpr - fpr))]

    # Class predictions at the chosen threshold.
    roc_predictions = [1 if prob >= optimal_proba_cutoff else 0 for prob in y_probs_test_patient[:, 1]]

    # NOTE(review): the ROC above treats label 1 as positive, while
    # precision / recall / F1 are reported for label 0. Kept as in the
    # original; confirm this is the intended reporting convention.
    precision_patient_test = precision_score(y_true_test_patient, roc_predictions, pos_label=0)
    recall_patient_test = recall_score(y_true_test_patient, roc_predictions, pos_label=0)
    f1_patient_test = f1_score(y_true_test_patient, roc_predictions, pos_label=0)
    bal_acc_patient_test = balanced_accuracy_score(y_true_test_patient, roc_predictions)

    # PATIENT-LEVEL AUROC
    auroc_patient_test = roc_auc_score(y_true_test_patient, y_probs_test_patient[:, 1])

    # WINDOW-LEVEL AUROC
    # CAUTION - the binary case expects shape (n_samples,), and the scores must
    # be the scores of the class with the greater label.
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
    auroc_test = roc_auc_score(y_true_test, y_probs_test[:, 1])

    return auroc_patient_test, auroc_test, precision_patient_test, recall_patient_test, f1_patient_test, bal_acc_patient_test
92
93
# create patient-level metrics
94
def get_patient_prediction(df, fold_idx):
    """Aggregate window-level predictions into one prediction per patient.

    Args:
        df: DataFrame with columns ``patient_ID``, ``y_true``, ``y_pred``,
            ``y_probs_0`` and ``y_probs_1`` (one row per window).
        fold_idx: when equal to 0, the patient IDs and labels are also stored
            in the module-level ``stats_test_data`` dict.

    Returns:
        Tuple ``(probs, y_true, y_pred)`` where ``probs`` is an ``(n, 2)``
        array of per-patient mean class probabilities, and ``y_true`` /
        ``y_pred`` are lists with one entry per patient. Patients appear in
        order of first occurrence in ``df``. ``y_pred`` is the most frequent
        window-level prediction for the patient.

    Raises:
        ValueError: if the windows of one patient carry more than one distinct
            ``y_true`` value.
    """
    columns = ["patient_ID", "y_true", "y_pred", "y_probs_0", "y_probs_1"]
    rows = []
    # sort=False keeps groups in order of first appearance.
    for patient, patient_df in df.groupby("patient_ID", sort=False):
        labels = patient_df["y_true"].unique()
        # Validate before the label is used; an explicit raise also survives
        # `python -O`, unlike an assert.
        if len(labels) != 1:
            raise ValueError(
                f"patient {patient!r} has inconsistent y_true values: {list(labels)}"
            )
        rows.append({
            "patient_ID": patient,
            "y_true": labels[0],
            "y_pred": patient_df["y_pred"].mode()[0],
            "y_probs_0": patient_df["y_probs_0"].mean(),
            "y_probs_1": patient_df["y_probs_1"].mean(),
        })
    # Explicit columns so an empty input still yields a frame with the
    # expected schema instead of a column-less one.
    return_df = pd.DataFrame(rows, columns=columns)

    # need subject names and labels for comparisons testing
    if fold_idx == 0:
        stats_test_data["subject_id"] = list(return_df["patient_ID"][:])
        stats_test_data["label"] = return_df["y_true"][:]

    # reshape(-1, 2) is a no-op for non-empty input and turns the empty case
    # into shape (0, 2) so callers can still index columns.
    probs = np.array(list(zip(return_df["y_probs_0"], return_df["y_probs_1"]))).reshape(-1, 2)
    return probs, list(return_df["y_true"]), list(return_df["y_pred"])
116
117
118
if __name__ == "__main__":
    # Training entry point: splits subjects into train / validation / held-out
    # test, trains EEGGraphConvNet on the PSD features with a class-balanced
    # sampler, and prints window- and patient-level metrics after every epoch.
    # A checkpoint is written every 20 epochs.

    GPU_IDX = 0
    EXPERIMENT_NAME = "psd_gnn_shallow"
    BATCH_SIZE = 512
    SFREQ = 250.0
    NUM_EPOCHS = 100
    NUM_WORKERS = 6
    PIN_MEMORY = True

    # ensure reproducibility of results
    SEED = 42
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    print("[MAIN] Numpy and PyTorch seed set to {} for reproducibility.".format(SEED))

    MASTER_DATASET_INDEX = pd.read_csv("master_metadata_index.csv", dtype={"patient_ID":str, })
    subjects = MASTER_DATASET_INDEX["patient_ID"].astype("str").unique()
    print("[MAIN] Subject list fetched! Total subjects are {}...".format(len(subjects)))

    # NOTE: splitting whole subjects into train+validation and heldout test,
    # so no subject contributes windows to more than one split.
    train_val_subjects, test_subjects = train_test_split(subjects, test_size=0.30, random_state=SEED)
    print("[MAIN] (Train + validation) and (heldout test) split made at subject level. 30 percent subjects held out for testing.")
    # 20 percent of the remaining subjects are used for validation.
    train_subjects, val_subjects = train_test_split(train_val_subjects, test_size=0.20, random_state=SEED)
    train_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(train_subjects)].tolist()
    val_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(val_subjects)].tolist()

    # use GPU when available
    DEVICE = torch.device('cuda:{}'.format(GPU_IDX) if torch.cuda.is_available() else 'cpu')
    # The CUDA-only calls must be skipped on a CPU device, otherwise the
    # CPU fallback above crashes before training starts.
    if DEVICE.type == 'cuda':
        torch.cuda.set_device(DEVICE)
        print('[MAIN] Using device:', DEVICE, torch.cuda.get_device_name(DEVICE))
    else:
        print('[MAIN] Using device:', DEVICE)

    # NOTE(security): joblib.load unpickles its input; only load files from a
    # trusted source.
    X = load("psd_features_data_X")
    y = load("labels_y")

    # normalize psd_features_data_X: each sample is flattened and scaled with
    # sklearn's preprocessing.normalize; the reshape enforces 48 features per
    # sample.
    normd_x = []
    for i in range(len(y)):
        arr = X[i, :]
        arr = arr.reshape(1, -1)
        arr2 = preprocessing.normalize(arr)
        arr2 = arr2.reshape(48)
        normd_x.append(arr2)

    norm = np.array(normd_x)
    X = norm.reshape(len(y), 48)

    # get 0/1 labels for pytorch, ensure mapping is the same between train and test
    label_mapping, y = np.unique(y, return_inverse = True)
    print("[MAIN] unique labels to [0 1] mapping:", label_mapping)

    model = EEGGraphConvNet(reduced_sensors=False)
    model = model.to(DEVICE).double()

    labels_unique, counts = np.unique(y, return_counts=True)

    # inverse-frequency class weights, computed over all labels
    class_weights = np.array([1.0/x for x in counts])
    # provide weights for samples in the training set only
    sample_weights = class_weights[y[train_indices]]
    # sampler needs to come up with training set size number of samples
    weighted_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_indices), replacement=True)

    # define training set
    train_dataset = EEGGraphDataset(X=X, y=y, indices=train_indices, loader_type="train",
                                    sfreq=SFREQ, transform=Compose([ToTensor()]))
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, sampler=weighted_sampler,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # define validation set
    val_dataset = EEGGraphDataset(X=X, y=y, indices=val_indices, loader_type="validation",
                                  sfreq=SFREQ, transform=Compose([ToTensor()]))
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    # define loss function
    loss_function = torch.nn.CrossEntropyLoss()
    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # define scheduler
    # NOTE(review): with gamma=0.1 the learning rate is divided by 10 every
    # 10 epochs; milestones at or beyond NUM_EPOCHS are never reached.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i*10 for i in range(1, 26)], gamma=0.1)

    # start training
    for epoch in range(NUM_EPOCHS):

        model.train()
        train_loss = []
        val_loss = []

        # raw model outputs for the epoch; softmax is applied after the loop
        y_probs_train = torch.empty(0, 2).to(DEVICE)

        y_true_train = [ ]
        y_pred_train = [ ]
        window_indices_train = [ ]

        for batch_idx, batch in enumerate(train_loader):

            # send batch to GPU
            X_batch = batch.to(device=DEVICE, non_blocking=True)
            y_batch = torch.tensor(batch.y)
            y_batch = y_batch.to(device=DEVICE, non_blocking=True)
            # the sampler draws with replacement, so the indices actually seen
            # must be recorded per batch
            window_indices_train += X_batch.dataset_idx.cpu().numpy().tolist()
            optimizer.zero_grad()

            # forward pass
            outputs = model(X_batch.x, X_batch.edge_index, X_batch.edge_attr, X_batch.batch).float()
            loss = loss_function(outputs, y_batch)
            train_loss.append(loss.item())
            # backward pass
            loss.backward()

            _, predicted = torch.max(outputs.data, 1)
            y_pred_train += predicted.cpu().numpy().tolist()

            # concatenate along 0th dimension
            y_probs_train = torch.cat((y_probs_train, outputs.data), 0)
            y_true_train += y_batch.cpu().numpy().tolist()

            optimizer.step()
        scheduler.step()

        # returning prob distribution over target classes, take softmax across the 1st dimension
        y_probs_train = torch.nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
        y_true_train = np.array(y_true_train)

        # calculate training set metrics
        auroc_patient_train, auroc_train, precision_patient_train, recall_patient_train, f1_patient_train, bal_acc_patient_train = collect_metrics(y_probs_test=y_probs_train,
                        y_true_test=y_true_train,
                        y_pred_test=y_pred_train,
                        sample_indices_test = window_indices_train,
                        fold_idx=0,
                        experiment_name=EXPERIMENT_NAME)

        # evaluate on validation set
        model.eval()
        with torch.no_grad():
            y_probs_val = torch.empty(0, 2).to(DEVICE)

            y_true_val = [ ]
            y_pred_val = [ ]
            window_indices_val = [ ]

            for i, batch in enumerate(val_loader):
                X_batch = batch.to(device=DEVICE, non_blocking=True)
                y_batch = torch.tensor(batch.y)
                y_batch = y_batch.to(device=DEVICE, non_blocking=True)
                window_indices_val += X_batch.dataset_idx.cpu().numpy().tolist()
                outputs = model(X_batch.x, X_batch.edge_index, X_batch.edge_attr, X_batch.batch).float()

                loss = loss_function(outputs, y_batch)
                val_loss.append(loss.item())

                _, predicted = torch.max(outputs.data, 1)
                y_pred_val += predicted.cpu().numpy().tolist()

                # concatenate along 0th dimension
                y_probs_val = torch.cat((y_probs_val, outputs.data), 0)
                y_true_val += y_batch.cpu().numpy().tolist()

        # returning prob distribution over target classes, take softmax across the 1st dimension
        y_probs_val = torch.nn.functional.softmax(y_probs_val, dim=1).cpu().numpy()
        y_true_val = np.array(y_true_val)

        # get validation set metrics; use the indices reported by the loader
        # (as the training path does) so predictions and patient lookups stay
        # aligned regardless of how the dataset orders its samples
        auroc_patient_val, auroc_val, precision_patient_val, recall_patient_val, f1_patient_val, bal_acc_patient_val = collect_metrics(y_probs_test=y_probs_val,
                        y_true_test=y_true_val,
                        y_pred_test=y_pred_val,
                        sample_indices_test = window_indices_val,
                        fold_idx=0,
                        experiment_name=EXPERIMENT_NAME)

        # save the model every 20 epochs
        if epoch % 20 == 0:
            state = {
                'model_description': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }

            torch.save(state, f"model_{epoch}.ckpt")

        print(f'Epoch: {epoch}-----------------------------------------------------------')
        print(f"Train loss: {np.mean(train_loss):.3f}; Validation loss: {np.mean(val_loss):.3f}")
        print(f"Train AUROC:{auroc_train:.3f}; Validation AUROC: {auroc_val:.3f}")
        print(f"Train patient metrics: AUROC{auroc_patient_train:.3f}, precision: {precision_patient_train:.3f}, recall: {recall_patient_train:.3f}, f1: {f1_patient_train:.3f}, bal acc: {bal_acc_patient_train:.3f}")
        print(f"Validation patient metrics: AUROC{auroc_patient_val:.3f}, precision: {precision_patient_val:.3f}, recall: {recall_patient_val:.3f}, f1: {f1_patient_val:.3f}, bal acc: {bal_acc_patient_val:.3f}")