--- a +++ b/train_test.py @@ -0,0 +1,156 @@ +""" Training and testing of the model +""" +import os +import numpy as np +from sklearn.metrics import accuracy_score, f1_score, roc_auc_score +import torch +import torch.nn.functional as F +from models import init_model_dict, init_optim +from utils import one_hot_tensor, cal_sample_weight, gen_adj_mat_tensor, gen_test_adj_mat_tensor, cal_adj_mat_parameter + +cuda = True if torch.cuda.is_available() else False + + +def prepare_trte_data(data_folder, view_list): + num_view = len(view_list) + labels_tr = np.loadtxt(os.path.join(data_folder, "labels_tr.csv"), delimiter=',') + labels_te = np.loadtxt(os.path.join(data_folder, "labels_te.csv"), delimiter=',') + labels_tr = labels_tr.astype(int) + labels_te = labels_te.astype(int) + data_tr_list = [] + data_te_list = [] + for i in view_list: + data_tr_list.append(np.loadtxt(os.path.join(data_folder, str(i)+"_tr.csv"), delimiter=',')) + data_te_list.append(np.loadtxt(os.path.join(data_folder, str(i)+"_te.csv"), delimiter=',')) + num_tr = data_tr_list[0].shape[0] + num_te = data_te_list[0].shape[0] + data_mat_list = [] + for i in range(num_view): + data_mat_list.append(np.concatenate((data_tr_list[i], data_te_list[i]), axis=0)) + data_tensor_list = [] + for i in range(len(data_mat_list)): + data_tensor_list.append(torch.FloatTensor(data_mat_list[i])) + if cuda: + data_tensor_list[i] = data_tensor_list[i].cuda() + idx_dict = {} + idx_dict["tr"] = list(range(num_tr)) + idx_dict["te"] = list(range(num_tr, (num_tr+num_te))) + data_train_list = [] + data_all_list = [] + for i in range(len(data_tensor_list)): + data_train_list.append(data_tensor_list[i][idx_dict["tr"]].clone()) + data_all_list.append(torch.cat((data_tensor_list[i][idx_dict["tr"]].clone(), + data_tensor_list[i][idx_dict["te"]].clone()),0)) + labels = np.concatenate((labels_tr, labels_te)) + + return data_train_list, data_all_list, idx_dict, labels + + +def gen_trte_adj_mat(data_tr_list, data_trte_list, trte_idx, adj_parameter): + adj_metric = "cosine" # cosine distance + adj_train_list = [] + adj_test_list = [] + for i in range(len(data_tr_list)): + adj_parameter_adaptive = cal_adj_mat_parameter(adj_parameter, data_tr_list[i], adj_metric) + adj_train_list.append(gen_adj_mat_tensor(data_tr_list[i], adj_parameter_adaptive, adj_metric)) + adj_test_list.append(gen_test_adj_mat_tensor(data_trte_list[i], trte_idx, adj_parameter_adaptive, adj_metric)) + + return adj_train_list, adj_test_list + + +def train_epoch(data_list, adj_list, label, one_hot_label, sample_weight, model_dict, optim_dict, train_VCDN=True): + loss_dict = {} + criterion = torch.nn.CrossEntropyLoss(reduction='none') + for m in model_dict: + model_dict[m].train() + num_view = len(data_list) + for i in range(num_view): + optim_dict["C{:}".format(i+1)].zero_grad() + ci_loss = 0 + ci = model_dict["C{:}".format(i+1)](model_dict["E{:}".format(i+1)](data_list[i],adj_list[i])) + ci_loss = torch.mean(torch.mul(criterion(ci, label),sample_weight)) + ci_loss.backward() + optim_dict["C{:}".format(i+1)].step() + loss_dict["C{:}".format(i+1)] = ci_loss.detach().cpu().numpy().item() + if train_VCDN and num_view >= 2: + optim_dict["C"].zero_grad() + c_loss = 0 + ci_list = [] + for i in range(num_view): + ci_list.append(model_dict["C{:}".format(i+1)](model_dict["E{:}".format(i+1)](data_list[i],adj_list[i]))) + c = model_dict["C"](ci_list) + c_loss = torch.mean(torch.mul(criterion(c, label),sample_weight)) + c_loss.backward() + optim_dict["C"].step() + loss_dict["C"] = c_loss.detach().cpu().numpy().item() + + return loss_dict + + +def test_epoch(data_list, adj_list, te_idx, model_dict): + for m in model_dict: + model_dict[m].eval() + num_view = len(data_list) + ci_list = [] + for i in range(num_view): + ci_list.append(model_dict["C{:}".format(i+1)](model_dict["E{:}".format(i+1)](data_list[i],adj_list[i]))) + if num_view >= 2: + c = model_dict["C"](ci_list) + else: + c = ci_list[0] + c = c[te_idx,:] + prob = F.softmax(c, dim=1).data.cpu().numpy() + + return prob + + +def train_test(data_folder, view_list, num_class, + lr_e_pretrain, lr_e, lr_c, + num_epoch_pretrain, num_epoch): + test_inverval = 50 + num_view = len(view_list) + dim_hvcdn = pow(num_class,num_view) + if data_folder == 'ROSMAP': + adj_parameter = 2 + dim_he_list = [200,200,100] + if data_folder == 'BRCA': + adj_parameter = 10 + dim_he_list = [400,400,200] + data_tr_list, data_trte_list, trte_idx, labels_trte = prepare_trte_data(data_folder, view_list) + labels_tr_tensor = torch.LongTensor(labels_trte[trte_idx["tr"]]) + onehot_labels_tr_tensor = one_hot_tensor(labels_tr_tensor, num_class) + sample_weight_tr = cal_sample_weight(labels_trte[trte_idx["tr"]], num_class) + sample_weight_tr = torch.FloatTensor(sample_weight_tr) + if cuda: + labels_tr_tensor = labels_tr_tensor.cuda() + onehot_labels_tr_tensor = onehot_labels_tr_tensor.cuda() + sample_weight_tr = sample_weight_tr.cuda() + adj_tr_list, adj_te_list = gen_trte_adj_mat(data_tr_list, data_trte_list, trte_idx, adj_parameter) + dim_list = [x.shape[1] for x in data_tr_list] + model_dict = init_model_dict(num_view, num_class, dim_list, dim_he_list, dim_hvcdn) + for m in model_dict: + if cuda: + model_dict[m].cuda() + + print("\nPretrain GCNs...") + optim_dict = init_optim(num_view, model_dict, lr_e_pretrain, lr_c) + for epoch in range(num_epoch_pretrain): + train_epoch(data_tr_list, adj_tr_list, labels_tr_tensor, + onehot_labels_tr_tensor, sample_weight_tr, model_dict, optim_dict, train_VCDN=False) + print("\nTraining...") + optim_dict = init_optim(num_view, model_dict, lr_e, lr_c) + for epoch in range(num_epoch+1): + train_epoch(data_tr_list, adj_tr_list, labels_tr_tensor, + onehot_labels_tr_tensor, sample_weight_tr, model_dict, optim_dict) + if epoch % test_inverval == 0: + te_prob = test_epoch(data_trte_list, adj_te_list, trte_idx["te"], model_dict) + print("\nTest: Epoch {:d}".format(epoch)) + if num_class == 2: + print("Test ACC: {:.3f}".format(accuracy_score(labels_trte[trte_idx["te"]], te_prob.argmax(1)))) + print("Test F1: {:.3f}".format(f1_score(labels_trte[trte_idx["te"]], te_prob.argmax(1)))) + print("Test AUC: {:.3f}".format(roc_auc_score(labels_trte[trte_idx["te"]], te_prob[:,1]))) + else: + print("Test ACC: {:.3f}".format(accuracy_score(labels_trte[trte_idx["te"]], te_prob.argmax(1)))) + print("Test F1 weighted: {:.3f}".format(f1_score(labels_trte[trte_idx["te"]], te_prob.argmax(1), average='weighted'))) + print("Test F1 macro: {:.3f}".format(f1_score(labels_trte[trte_idx["te"]], te_prob.argmax(1), average='macro'))) + print() \ No newline at end of file