a b/train_test.py
1
""" Training and testing of the model
2
"""
3
import os
4
import numpy as np
5
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
6
import torch
7
import torch.nn.functional as F
8
from models import init_model_dict, init_optim
9
from utils import one_hot_tensor, cal_sample_weight, gen_adj_mat_tensor, gen_test_adj_mat_tensor, cal_adj_mat_parameter
10
11
# Module-wide flag: True when a CUDA device is usable, consulted before moving tensors/models to the GPU.
cuda = torch.cuda.is_available()
12
13
14
def prepare_trte_data(data_folder, view_list):
    """Load the train/test CSV files of every view and build the model inputs.

    Reads ``labels_tr.csv`` and ``labels_te.csv`` plus, for every entry ``v``
    of ``view_list``, ``<v>_tr.csv`` and ``<v>_te.csv`` from ``data_folder``.
    All files are parsed as comma-delimited numeric text.

    Args:
        data_folder: Directory that contains the CSV files.
        view_list: Identifiers of the views to load; each one is converted
            with ``str`` to form the file-name prefix.

    Returns:
        A tuple ``(data_train_list, data_all_list, idx_dict, labels)``:

        * ``data_train_list`` -- one float tensor per view holding the
          training rows only.
        * ``data_all_list`` -- one float tensor per view holding the training
          rows followed by the test rows.
        * ``idx_dict`` -- ``{"tr": [...], "te": [...]}`` row indices of the
          training / test samples inside the ``data_all_list`` tensors.
        * ``labels`` -- integer numpy array, training labels followed by
          test labels.

        Tensors are moved to the GPU when the module-level ``cuda`` flag is
        set.

    Raises:
        ValueError: If a view does not have the same number of training or
            test rows as the first view.
        IndexError: If ``view_list`` is empty.
    """
    def _load(file_name, ndmin):
        # ndmin stops loadtxt from squeezing single-row / single-column files
        # down to a lower rank, which would break the concatenations below.
        return np.loadtxt(os.path.join(data_folder, file_name), delimiter=',', ndmin=ndmin)

    labels_tr = _load("labels_tr.csv", 1).astype(int)
    labels_te = _load("labels_te.csv", 1).astype(int)

    tr_mats = []
    te_mats = []
    for view in view_list:
        tr_mats.append(_load(str(view) + "_tr.csv", 2))
        te_mats.append(_load(str(view) + "_te.csv", 2))

    # The first view defines how many train / test samples there are.
    num_tr = tr_mats[0].shape[0]
    num_te = te_mats[0].shape[0]
    idx_dict = {
        "tr": list(range(num_tr)),
        "te": list(range(num_tr, num_tr + num_te)),
    }

    data_train_list = []
    data_all_list = []
    for view, tr_mat, te_mat in zip(view_list, tr_mats, te_mats):
        if tr_mat.shape[0] != num_tr or te_mat.shape[0] != num_te:
            # A shared index dict is only meaningful if all views are aligned.
            raise ValueError(
                "view {!r} has {:d} train / {:d} test rows, expected {:d} / {:d}".format(
                    view, tr_mat.shape[0], te_mat.shape[0], num_tr, num_te))
        all_tensor = torch.FloatTensor(np.concatenate((tr_mat, te_mat), axis=0))
        if cuda:
            all_tensor = all_tensor.cuda()
        # clone() so the training tensor owns its storage instead of being a
        # view into the train+test tensor.
        data_train_list.append(all_tensor[:num_tr].clone())
        data_all_list.append(all_tensor)

    labels = np.concatenate((labels_tr, labels_te))

    return data_train_list, data_all_list, idx_dict, labels
47
48
49
def gen_trte_adj_mat(data_tr_list, data_trte_list, trte_idx, adj_parameter):
    """Build, for every view, the training and the train+test adjacency matrices.

    For each view the helper ``cal_adj_mat_parameter`` is called on the
    training data only; its result is then passed both to
    ``gen_adj_mat_tensor`` (training data) and to ``gen_test_adj_mat_tensor``
    (train+test data together with ``trte_idx``).

    Args:
        data_tr_list: Per-view training data.
        data_trte_list: Per-view train+test data, same order as
            ``data_tr_list``.
        trte_idx: Index dictionary forwarded unchanged to
            ``gen_test_adj_mat_tensor``.
        adj_parameter: Value forwarded to ``cal_adj_mat_parameter``
            (presumably controls graph sparsity -- confirm in ``utils``).

    Returns:
        ``(adj_train_list, adj_test_list)``, one entry per view each.
    """
    metric = "cosine"  # cosine distance
    train_adjs, test_adjs = [], []
    for view, tr_data in enumerate(data_tr_list):
        # The per-view parameter is derived from the training data alone and
        # reused for the train+test graph.
        view_parameter = cal_adj_mat_parameter(adj_parameter, tr_data, metric)
        train_adjs.append(gen_adj_mat_tensor(tr_data, view_parameter, metric))
        test_adjs.append(
            gen_test_adj_mat_tensor(data_trte_list[view], trte_idx, view_parameter, metric))

    return train_adjs, test_adjs
59
60
61
def train_epoch(data_list, adj_list, label, one_hot_label, sample_weight, model_dict, optim_dict, train_VCDN=True):
    """Run one training step for every view-specific model and, optionally, the fusion model.

    Each view ``i`` (1-based) is pushed through ``model_dict["E<i>"]`` and
    then ``model_dict["C<i>"]``; the sample-weighted cross-entropy of the
    result is back-propagated and ``optim_dict["C<i>"]`` takes a step.
    When ``train_VCDN`` is true and there are at least two views, the
    per-view outputs are recomputed, passed as a list to ``model_dict["C"]``
    and ``optim_dict["C"]`` takes a step on the resulting loss.

    Args:
        data_list: Per-view input tensors.
        adj_list: Per-view adjacency matrices, same order as ``data_list``.
        label: Class-index tensor used as the cross-entropy target.
        one_hot_label: Not used by this function; kept so existing callers
            keep working.
        sample_weight: Per-sample weights multiplied into the unreduced loss.
        model_dict: Mapping with keys ``"E<i>"``, ``"C<i>"`` and (for two or
            more views) ``"C"``. All models are switched to train mode.
        optim_dict: Mapping with keys ``"C<i>"`` and (when the fusion model
            is trained) ``"C"``.
        train_VCDN: Whether to also train ``model_dict["C"]``.

    Returns:
        Dict mapping each optimised key (``"C<i>"``, optionally ``"C"``) to
        its loss as a Python float, measured before the optimiser step.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    for model in model_dict.values():
        model.train()

    def _view_output(view):
        # Encoder followed by the view-specific classifier.
        encoder = model_dict["E{:}".format(view + 1)]
        classifier = model_dict["C{:}".format(view + 1)]
        return classifier(encoder(data_list[view], adj_list[view]))

    def _weighted_loss(logits):
        # Unreduced cross-entropy so each sample can be re-weighted.
        return torch.mean(torch.mul(criterion(logits, label), sample_weight))

    loss_dict = {}
    num_view = len(data_list)
    for view in range(num_view):
        key = "C{:}".format(view + 1)
        optim_dict[key].zero_grad()
        view_loss = _weighted_loss(_view_output(view))
        view_loss.backward()
        optim_dict[key].step()
        loss_dict[key] = view_loss.item()

    if train_VCDN and num_view >= 2:
        optim_dict["C"].zero_grad()
        # Fresh forward passes: the view models were just updated above.
        fused = model_dict["C"]([_view_output(view) for view in range(num_view)])
        fused_loss = _weighted_loss(fused)
        fused_loss.backward()
        optim_dict["C"].step()
        loss_dict["C"] = fused_loss.item()

    return loss_dict
88
    
89
90
def test_epoch(data_list, adj_list, te_idx, model_dict):
91
    for m in model_dict:
92
        model_dict[m].eval()
93
    num_view = len(data_list)
94
    ci_list = []
95
    for i in range(num_view):
96
        ci_list.append(model_dict["C{:}".format(i+1)](model_dict["E{:}".format(i+1)](data_list[i],adj_list[i])))
97
    if num_view >= 2:
98
        c = model_dict["C"](ci_list)    
99
    else:
100
        c = ci_list[0]
101
    c = c[te_idx,:]
102
    prob = F.softmax(c, dim=1).data.cpu().numpy()
103
    
104
    return prob
105
106
107
def train_test(data_folder, view_list, num_class,
               lr_e_pretrain, lr_e, lr_c, 
               num_epoch_pretrain, num_epoch):
    """Pretrain the view-specific models, train the full model and report test metrics.

    The data are loaded with ``prepare_trte_data``, adjacency matrices are
    built with ``gen_trte_adj_mat`` and the models / optimisers come from
    ``init_model_dict`` / ``init_optim``. Training runs in two phases:

    1. ``num_epoch_pretrain`` epochs with ``train_VCDN=False`` using
       optimisers created from ``lr_e_pretrain`` and ``lr_c``.
    2. ``num_epoch + 1`` epochs (epoch numbers ``0 .. num_epoch``) with
       optimisers created from ``lr_e`` and ``lr_c``. Every 50th epoch the
       test samples are evaluated and the metrics are printed.

    Args:
        data_folder: Folder with the dataset CSV files. Its final path
            component selects the dataset-specific settings and must be
            ``'ROSMAP'`` or ``'BRCA'``.
        view_list: Identifiers of the views to use.
        num_class: Number of classes. With exactly two classes accuracy, F1
            and AUC are printed; otherwise accuracy, weighted F1 and macro F1.
        lr_e_pretrain: Learning rate handed to ``init_optim`` for pretraining.
        lr_e: Learning rate handed to ``init_optim`` for the main training.
        lr_c: Second learning rate handed to ``init_optim`` in both phases.
        num_epoch_pretrain: Number of pretraining epochs.
        num_epoch: Number of the last main-training epoch.

    Raises:
        ValueError: If ``data_folder`` does not name a known dataset.
    """
    test_interval = 50
    # Dataset name -> (adj_parameter, dim_he_list).
    dataset_config = {
        'ROSMAP': (2, [200,200,100]),
        'BRCA': (10, [400,400,200]),
    }
    # Match on the folder name so that e.g. "data/BRCA/" is accepted as well.
    dataset_name = os.path.basename(os.path.normpath(data_folder))
    if dataset_name not in dataset_config:
        raise ValueError(
            "unknown dataset {!r}; expected one of {}".format(data_folder, sorted(dataset_config)))
    adj_parameter, dim_he_list = dataset_config[dataset_name]

    num_view = len(view_list)
    dim_hvcdn = pow(num_class,num_view)
    data_tr_list, data_trte_list, trte_idx, labels_trte = prepare_trte_data(data_folder, view_list)
    labels_tr = labels_trte[trte_idx["tr"]]
    labels_te = labels_trte[trte_idx["te"]]
    labels_tr_tensor = torch.LongTensor(labels_tr)
    onehot_labels_tr_tensor = one_hot_tensor(labels_tr_tensor, num_class)
    sample_weight_tr = torch.FloatTensor(cal_sample_weight(labels_tr, num_class))
    if cuda:
        labels_tr_tensor = labels_tr_tensor.cuda()
        onehot_labels_tr_tensor = onehot_labels_tr_tensor.cuda()
        sample_weight_tr = sample_weight_tr.cuda()
    adj_tr_list, adj_te_list = gen_trte_adj_mat(data_tr_list, data_trte_list, trte_idx, adj_parameter)
    dim_list = [x.shape[1] for x in data_tr_list]
    model_dict = init_model_dict(num_view, num_class, dim_list, dim_he_list, dim_hvcdn)
    if cuda:
        for model in model_dict.values():
            model.cuda()

    print("\nPretrain GCNs...")
    optim_dict = init_optim(num_view, model_dict, lr_e_pretrain, lr_c)
    for _ in range(num_epoch_pretrain):
        train_epoch(data_tr_list, adj_tr_list, labels_tr_tensor,
                    onehot_labels_tr_tensor, sample_weight_tr, model_dict, optim_dict, train_VCDN=False)

    print("\nTraining...")
    # Fresh optimisers for the main phase, now built with lr_e.
    optim_dict = init_optim(num_view, model_dict, lr_e, lr_c)
    # num_epoch + 1 iterations so that epoch number num_epoch itself is run
    # (and evaluated when it is a multiple of the test interval).
    for epoch in range(num_epoch+1):
        train_epoch(data_tr_list, adj_tr_list, labels_tr_tensor,
                    onehot_labels_tr_tensor, sample_weight_tr, model_dict, optim_dict)
        if epoch % test_interval != 0:
            continue
        te_prob = test_epoch(data_trte_list, adj_te_list, trte_idx["te"], model_dict)
        te_pred = te_prob.argmax(1)
        print("\nTest: Epoch {:d}".format(epoch))
        print("Test ACC: {:.3f}".format(accuracy_score(labels_te, te_pred)))
        if num_class == 2:
            print("Test F1: {:.3f}".format(f1_score(labels_te, te_pred)))
            print("Test AUC: {:.3f}".format(roc_auc_score(labels_te, te_prob[:,1])))
        else:
            print("Test F1 weighted: {:.3f}".format(f1_score(labels_te, te_pred, average='weighted')))
            print("Test F1 macro: {:.3f}".format(f1_score(labels_te, te_pred, average='macro')))
        print()