Diff of /training_pipeline.py [000000] .. [4dadda]

Switch to side-by-side view

--- a
+++ b/training_pipeline.py
@@ -0,0 +1,302 @@
+from sklearn.model_selection import train_test_split
+import pandas as pd
+import numpy as np
+import torch
+from joblib import load
+import statistics as stats
+from sklearn import preprocessing
+
+import torch.backends.cudnn as cudnn
+cudnn.enabled = True
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+from code_psd_shallow_eeg_gcnn.EEGGraphDataset import EEGGraphDataset
+from code_psd_shallow_eeg_gcnn.EEGGraphConvNet import EEGGraphConvNet
+from torch_geometric.data import DataLoader
+from torch.utils.data import WeightedRandomSampler
+from sklearn.metrics import make_scorer
+from sklearn.metrics import balanced_accuracy_score, auc, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
+from torchvision.transforms import Compose, ToTensor
+
# Module-level accumulator filled by collect_metrics()/get_patient_prediction():
# per-fold patient-level class probabilities, plus subject IDs and labels
# (recorded on fold 0) for later statistical comparison tests.
stats_test_data = { }
+
+# after each epoch, record all the metrics on both train and validation sets
# after each epoch, record all the metrics on both train and validation sets
def collect_metrics(y_probs_test, y_true_test, y_pred_test, sample_indices_test,
					fold_idx, experiment_name):
	"""Compute window-level and patient-level metrics for one evaluation pass.

	Parameters
	----------
	y_probs_test : np.ndarray, shape (n_windows, 2)
		Softmax class probabilities per window.
	y_true_test : array-like of int
		Ground-truth 0/1 label per window.
	y_pred_test : array-like of int
		Argmax class prediction per window.
	sample_indices_test : list of int
		Row indices into master_metadata_index.csv, aligned element-wise
		with the window-level arrays above.
	fold_idx : int
		Fold identifier used to key entries in the stats_test_data dict.
	experiment_name : str
		Unused here; kept for interface compatibility with callers.

	Returns
	-------
	tuple
		(patient AUROC, window AUROC, patient precision, patient recall,
		patient F1, patient balanced accuracy).
	"""
	# map window indices back to patient IDs via the metadata index
	dataset_index = pd.read_csv("master_metadata_index.csv", dtype={"patient_ID": str, })

	# build the window-level dataframe used for patient-level aggregation
	rows = []
	for i, idx in enumerate(sample_indices_test):
		rows.append({
			"patient_ID": str(dataset_index.loc[idx, "patient_ID"]),
			"sample_idx": idx,
			"y_true": y_true_test[i],
			"y_probs_0": y_probs_test[i, 0],
			"y_probs_1": y_probs_test[i, 1],
			"y_pred": y_pred_test[i],
		})
	test_patient_df = pd.DataFrame(rows)

	# aggregate window-level rows into one prediction per patient
	y_probs_test_patient, y_true_test_patient, y_pred_test_patient = get_patient_prediction(test_patient_df, fold_idx)

	# stash per-fold patient probabilities for later statistical comparisons
	stats_test_data[f"probs_0_fold_{fold_idx}"] = y_probs_test_patient[:, 0]
	stats_test_data[f"probs_1_fold_{fold_idx}"] = y_probs_test_patient[:, 1]

	window_csv_dict = { }
	patient_csv_dict = { }

	# WINDOW-LEVEL ROC curve (class 1 treated as positive)
	fpr, tpr, thresholds = roc_curve(y_true_test, y_probs_test[:, 1], pos_label=1)
	window_csv_dict[f"fpr_fold_{fold_idx}"] = fpr
	window_csv_dict[f"tpr_fold_{fold_idx}"] = tpr
	window_csv_dict[f"thres_fold_{fold_idx}"] = thresholds

	# PATIENT-LEVEL ROC curve - used below to pick the operating threshold
	fpr, tpr, thresholds = roc_curve(y_true_test_patient, y_probs_test_patient[:, 1], pos_label=1)
	patient_csv_dict[f"fpr_fold_{fold_idx}"] = fpr
	patient_csv_dict[f"tpr_fold_{fold_idx}"] = tpr
	patient_csv_dict[f"thres_fold_{fold_idx}"] = thresholds

	# Youden's J statistic: J = TPR - FPR, maximized over thresholds; gives
	# equal weight to false positives and false negatives.
	# BUG FIX: the original maximized np.abs(tpr - fpr), which can select a
	# threshold where FPR exceeds TPR. The signed difference is the correct
	# form, and agrees with the original whenever the classifier does better
	# than chance. Ties resolve to the largest threshold in both versions,
	# since roc_curve returns thresholds in decreasing order.
	optimal_proba_cutoff = thresholds[np.argmax(tpr - fpr)]

	# recompute patient-level class predictions at the optimal threshold
	roc_predictions = [1 if p >= optimal_proba_cutoff else 0 for p in y_probs_test_patient[:, 1]]

	# confusion-based patient metrics; pos_label=0 scores class 0 -
	# NOTE(review): assumes class 0 is the clinically relevant class, confirm
	# against the label mapping printed in __main__
	precision_patient_test = precision_score(y_true_test_patient, roc_predictions, pos_label=0)
	recall_patient_test = recall_score(y_true_test_patient, roc_predictions, pos_label=0)
	f1_patient_test = f1_score(y_true_test_patient, roc_predictions, pos_label=0)
	bal_acc_patient_test = balanced_accuracy_score(y_true_test_patient, roc_predictions)

	# AUROCs. CAUTION - the binary case of roc_auc_score expects shape
	# (n_samples,) scores of the class with the greater label, hence column 1.
	# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
	# (redundant in-function imports of roc_auc_score removed; it is already
	# imported at module level)
	auroc_patient_test = roc_auc_score(y_true_test_patient, y_probs_test_patient[:, 1])
	auroc_test = roc_auc_score(y_true_test, y_probs_test[:, 1])

	return auroc_patient_test, auroc_test, precision_patient_test, recall_patient_test, f1_patient_test, bal_acc_patient_test
+
+# create patient-level metrics
# collapse window-level rows into one prediction per patient
def get_patient_prediction(df, fold_idx):
	"""Aggregate window-level predictions to the patient level.

	For each patient: the label is taken from the windows (which must all
	agree), the hard prediction is the majority vote (mode) over windows,
	and the class probabilities are the means over windows.

	Parameters
	----------
	df : pd.DataFrame
		Window-level frame with columns patient_ID, y_true, y_pred,
		y_probs_0, y_probs_1.
	fold_idx : int
		On fold 0, subject IDs and labels are also recorded in the
		module-level stats_test_data dict for comparison testing.

	Returns
	-------
	tuple
		(probs array of shape (n_patients, 2), list of true labels,
		list of predicted labels), in order of first appearance.
	"""
	aggregated = []
	# groupby with sort=False yields groups in order of first appearance,
	# matching iteration over df["patient_ID"].unique()
	for patient_id, patient_windows in df.groupby("patient_ID", sort=False):
		true_labels = patient_windows["y_true"].unique()
		# every window of one patient must carry the same ground-truth label
		assert len(true_labels) == 1
		aggregated.append({
			"patient_ID": patient_id,
			"y_true": true_labels[0],
			"y_pred": patient_windows["y_pred"].mode()[0],
			"y_probs_0": patient_windows["y_probs_0"].mean(),
			"y_probs_1": patient_windows["y_probs_1"].mean(),
		})
	return_df = pd.DataFrame(aggregated)

	# need subject names and labels for comparisons testing
	if fold_idx == 0:
		stats_test_data["subject_id"] = list(return_df["patient_ID"][:])
		stats_test_data["label"] = return_df["y_true"][:]

	probs = np.stack(
		[return_df["y_probs_0"].to_numpy(), return_df["y_probs_1"].to_numpy()],
		axis=1,
	)
	return probs, list(return_df["y_true"]), list(return_df["y_pred"])
+
+
if __name__ == "__main__":

	# ---- experiment configuration ----
	GPU_IDX = 0
	EXPERIMENT_NAME = "psd_gnn_shallow"
	BATCH_SIZE = 512
	SFREQ = 250.0  # sampling frequency (Hz) forwarded to the graph dataset
	NUM_EPOCHS = 100
	NUM_WORKERS = 6
	PIN_MEMORY = True

	# ensure reproducibility of results
	SEED = 42
	np.random.seed(SEED)
	torch.manual_seed(SEED)
	print("[MAIN] Numpy and PyTorch seed set to {} for reproducibility.".format(SEED))

	# master index maps every windowed sample (row) to its patient
	MASTER_DATASET_INDEX = pd.read_csv("master_metadata_index.csv", dtype={"patient_ID":str, })
	subjects = MASTER_DATASET_INDEX["patient_ID"].astype("str").unique()
	print("[MAIN] Subject list fetched! Total subjects are {}...".format(len(subjects)))

	# NOTE: splitting whole subjects into train+validation and heldout test,
	# so that no patient's windows leak across splits
	train_val_subjects, test_subjects = train_test_split(subjects, test_size=0.30, random_state=SEED)
	print("[MAIN] (Train + validation) and (heldout test) split made at subject level. 30 percent subjects held out for testing.")
	train_subjects, val_subjects = train_test_split(train_val_subjects, test_size=0.20, random_state=SEED)
	# window-level row indices belonging to each subject split
	train_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(train_subjects)].tolist()
	val_indices = MASTER_DATASET_INDEX.index[MASTER_DATASET_INDEX["patient_ID"].astype("str").isin(val_subjects)].tolist()

	# use GPU when available
	DEVICE = torch.device('cuda:{}'.format(GPU_IDX) if torch.cuda.is_available() else 'cpu')
	if torch.cuda.is_available():
		# BUG FIX: torch.cuda.set_device()/get_device_name() raise on a CPU
		# device, which defeated the CPU fallback above - only call them
		# when CUDA is actually present.
		torch.cuda.set_device(DEVICE)
		print('[MAIN] Using device:', DEVICE, torch.cuda.get_device_name(DEVICE))
	else:
		print('[MAIN] Using device:', DEVICE)

	# pre-computed PSD features and labels (joblib-serialized)
	X = load("psd_features_data_X")
	y = load("labels_y")

	# normalize psd_features_data_X: unit-norm each sample's 48-dim vector
	normd_x = []
	for i in range(len(y)):
		arr = X[i, :]
		arr = arr.reshape(1, -1)
		arr2 = preprocessing.normalize(arr)
		arr2 = arr2.reshape(48)
		normd_x.append(arr2)

	norm = np.array(normd_x)
	X = norm.reshape(len(y), 48)

	# get 0/1 labels for pytorch, ensure mapping is the same between train and test
	label_mapping, y = np.unique(y, return_inverse = True)
	print("[MAIN] unique labels to [0 1] mapping:", label_mapping)

	model = EEGGraphConvNet(reduced_sensors=False)
	model = model.to(DEVICE).double()

	# inverse-frequency class weights so the weighted sampler balances classes
	labels_unique, counts = np.unique(y, return_counts=True)
	class_weights = np.array([1.0/x for x in counts])
	# provide weights for samples in the training set only
	sample_weights = class_weights[y[train_indices]]
	# sampler needs to come up with training set size number of samples
	weighted_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_indices), replacement=True)

	# define training set
	train_dataset = EEGGraphDataset(X=X, y=y, indices=train_indices, loader_type="train",
									sfreq=SFREQ, transform=Compose([ToTensor()]))
	train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, sampler=weighted_sampler,
							 num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

	# define validation set
	val_dataset = EEGGraphDataset(X=X, y=y, indices=val_indices, loader_type="validation",
									sfreq=SFREQ, transform=Compose([ToTensor()]))
	val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
							  shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

	# define loss function
	loss_function = torch.nn.CrossEntropyLoss()
	# define optimizer
	optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
	# define scheduler: decay LR by 10x every 10 epochs up to epoch 250
	scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i*10 for i in range(1, 26)], gamma=0.1)

	# start training
	for epoch in range(NUM_EPOCHS):

		model.train()
		train_loss = []
		val_loss = []

		# accumulate raw (pre-softmax) outputs over all training batches
		y_probs_train = torch.empty(0, 2).to(DEVICE)

		y_true_train = [ ]
		y_pred_train = [ ]
		window_indices_train = [ ]

		for batch_idx, batch in enumerate(train_loader):

			# send batch to GPU
			X_batch = batch.to(device=DEVICE, non_blocking=True)
			y_batch = torch.tensor(batch.y)
			y_batch = y_batch.to(device=DEVICE, non_blocking=True)
			# dataset_idx is assumed to carry the master-index row of each
			# window - TODO confirm against EEGGraphDataset
			window_indices_train += X_batch.dataset_idx.cpu().numpy().tolist()
			optimizer.zero_grad()

			# forward pass
			outputs = model(X_batch.x, X_batch.edge_index, X_batch.edge_attr, X_batch.batch).float()
			loss = loss_function(outputs, y_batch)
			train_loss.append(loss.item())
			# backward pass
			loss.backward()

			_, predicted = torch.max(outputs.data, 1)
			y_pred_train += predicted.cpu().numpy().tolist()

			# concatenate along 0th dimension
			y_probs_train = torch.cat((y_probs_train, outputs.data), 0)
			y_true_train += y_batch.cpu().numpy().tolist()

			optimizer.step()
		scheduler.step()

		# returning prob distribution over target classes, take softmax across the 1st dimension
		y_probs_train = torch.nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
		y_true_train = np.array(y_true_train)

		# calculate training set metrics (loader-order window indices align
		# the metric arrays with the metadata index)
		auroc_patient_train, auroc_train, precision_patient_train, recall_patient_train, f1_patient_train, bal_acc_patient_train = collect_metrics(y_probs_test=y_probs_train,
						y_true_test=y_true_train,
						y_pred_test=y_pred_train,
						sample_indices_test=window_indices_train,
						fold_idx=0,
						experiment_name=EXPERIMENT_NAME)

		# evaluate on validation set
		model.eval()
		with torch.no_grad():
			y_probs_val = torch.empty(0, 2).to(DEVICE)

			y_true_val = [ ]
			y_pred_val = [ ]
			window_indices_val = [ ]

			for i, batch in enumerate(val_loader):
				X_batch = batch.to(device=DEVICE, non_blocking=True)
				y_batch = torch.tensor(batch.y)
				y_batch = y_batch.to(device=DEVICE, non_blocking=True)
				window_indices_val += X_batch.dataset_idx.cpu().numpy().tolist()
				outputs = model(X_batch.x, X_batch.edge_index, X_batch.edge_attr, X_batch.batch).float()

				loss = loss_function(outputs, y_batch)
				val_loss.append(loss.item())

				_, predicted = torch.max(outputs.data, 1)
				y_pred_val += predicted.cpu().numpy().tolist()

				# concatenate along 0th dimension
				y_probs_val = torch.cat((y_probs_val, outputs.data), 0)
				y_true_val += y_batch.cpu().numpy().tolist()

		# returning prob distribution over target classes, take softmax across the 1st dimension
		y_probs_val = torch.nn.functional.softmax(y_probs_val, dim=1).cpu().numpy()
		y_true_val = np.array(y_true_val)

		# get validation set metrics
		# BUG FIX: the metric arrays above are in loader order, so pass the
		# loader-order window indices (mirroring the training call) instead of
		# the raw val_indices list, which is not guaranteed to be aligned.
		# window_indices_val was collected for exactly this but never used.
		auroc_patient_val, auroc_val, precision_patient_val, recall_patient_val, f1_patient_val, bal_acc_patient_val = collect_metrics(y_probs_test=y_probs_val,
						y_true_test=y_true_val,
						y_pred_test=y_pred_val,
						sample_indices_test=window_indices_val,
						fold_idx=0,
						experiment_name=EXPERIMENT_NAME)

		# save the model every 20 epochs
		if epoch % 20 == 0:
			state = {
				'model_description': str(model),
				'state_dict': model.state_dict(),
				'optimizer': optimizer.state_dict()
			}

			torch.save(state, f"model_{epoch}.ckpt")

		print(f'Epoch: {epoch}-----------------------------------------------------------')
		print(f"Train loss: {np.mean(train_loss):.3f}; Validation loss: {np.mean(val_loss):.3f}")
		print(f"Train AUROC:{auroc_train:.3f}; Validation AUROC: {auroc_val:.3f}")
		print(f"Train patient metrics: AUROC{auroc_patient_train:.3f}, precision: {precision_patient_train:.3f}, recall: {recall_patient_train:.3f}, f1: {f1_patient_train:.3f}, bal acc: {bal_acc_patient_train:.3f}")
		print(f"Validation patient metrics: AUROC{auroc_patient_val:.3f}, precision: {precision_patient_val:.3f}, recall: {recall_patient_val:.3f}, f1: {f1_patient_val:.3f}, bal acc: {bal_acc_patient_val:.3f}")